handle k-nearest neighbours search

update DocstrumBB with kd-tree
This commit is contained in:
BobLd
2020-03-10 13:36:44 +00:00
committed by Eliot Jones
parent 8cafda3577
commit 5b8a2f2e38
6 changed files with 296 additions and 99 deletions

View File

@@ -13,6 +13,7 @@
{
/// <summary>
/// Algorithm to group elements using nearest neighbours.
/// <para>Uses the nearest neighbour as candidate.</para>
/// </summary>
/// <typeparam name="T">Letter, Word, TextLine, etc.</typeparam>
/// <param name="elements">Elements to group.</param>
@@ -61,7 +62,7 @@
if (filterPivot(pivot))
{
var paired = kdTree.FindNearestNeighbours(pivot, pivotPoint, distMeasure, out int index, out double dist);
var paired = kdTree.FindNearestNeighbour(pivot, pivotPoint, distMeasure, out int index, out double dist);
if (index != -1)
{
@@ -77,6 +78,77 @@
return GroupIndexes(indexes);
}
/// <summary>
/// Algorithm to group elements using nearest neighbours.
/// <para>Uses the k-nearest neighbours as candidates.</para>
/// </summary>
/// <typeparam name="T">Letter, Word, TextLine, etc.</typeparam>
/// <param name="elements">Elements to group.</param>
/// <param name="k">The k-nearest neighbours to consider as candidates.</param>
/// <param name="distMeasure">The distance measure between two points.</param>
/// <param name="maxDistanceFunction">The function that determines the maximum distance between two points in the same cluster.</param>
/// <param name="pivotPoint">The pivot's point to use for pairing, e.g. BottomLeft, TopLeft.</param>
/// <param name="candidatesPoint">The candidates' point to use for pairing, e.g. BottomLeft, TopLeft.</param>
/// <param name="filterPivot">Filter to apply to the pivot point. If false, point will not be paired at all, e.g. is white space.</param>
/// <param name="filterFinal">Filter to apply to both the pivot and the paired point. If false, point will not be paired at all, e.g. pivot and paired point have same font.</param>
/// <param name="maxDegreeOfParallelism">Sets the maximum number of concurrent tasks enabled.
/// <para>A positive property value limits the number of concurrent operations to the set value.
/// If it is -1, there is no limit on the number of concurrently running operations.</para></param>
internal static IEnumerable<HashSet<int>> ClusterNearestNeighbours<T>(IReadOnlyList<T> elements, int k,
Func<PdfPoint, PdfPoint, double> distMeasure,
Func<T, T, double> maxDistanceFunction,
Func<T, PdfPoint> pivotPoint, Func<T, PdfPoint> candidatesPoint,
Func<T, bool> filterPivot, Func<T, T, bool> filterFinal,
int maxDegreeOfParallelism)
{
/*************************************************************************************
* Algorithm steps
* 1. Find nearest neighbours indexes (done in parallel)
* Iterate every point (pivot) and put its nearest neighbour's index in an array
* e.g. if nearest neighbour of point i is point j, then indexes[i] = j.
* Only conciders a neighbour if it is within the maximum distance.
* If not within the maximum distance, index will be set to -1.
* Each element has only one connected neighbour.
* NB: Given the possible asymmetry in the relationship, it is possible
* that if indexes[i] = j then indexes[j] != i.
*
* 2. Group indexes
* Group indexes if share neighbours in common - Depth-first search
* e.g. if we have indexes[i] = j, indexes[j] = k, indexes[m] = n and indexes[n] = -1
* (i,j,k) will form a group and (m,n) will form another group.
*************************************************************************************/
int[] indexes = Enumerable.Repeat(-1, elements.Count).ToArray();
KdTree<T> kdTree = new KdTree<T>(elements, candidatesPoint);
ParallelOptions parallelOptions = new ParallelOptions() { MaxDegreeOfParallelism = maxDegreeOfParallelism };
// 1. Find nearest neighbours indexes
Parallel.For(0, elements.Count, parallelOptions, e =>
{
var pivot = elements[e];
if (filterPivot(pivot))
{
var paired = kdTree.FindNearestNeighbours(pivot, k, pivotPoint, distMeasure);
foreach (var c in paired)
{
var filter = filterFinal(pivot, c.Item1);
var maxDist = maxDistanceFunction(pivot, c.Item1);
if (filter && c.Item3 < maxDist)
{
indexes[e] = c.Item2;
break;
}
}
}
});
// 2. Group indexes
return GroupIndexes(indexes);
}
/// <summary>
/// Algorithm to group elements using nearest neighbours.
/// </summary>

View File

@@ -14,13 +14,20 @@
public PdfPoint FindNearestNeighbours(PdfPoint pivot, Func<PdfPoint, PdfPoint, double> distanceMeasure, out int index, out double distance)
{
return FindNearestNeighbours(pivot, p => p, distanceMeasure, out index, out distance);
return FindNearestNeighbour(pivot, p => p, distanceMeasure, out index, out distance);
}
public IReadOnlyList<(PdfPoint, int, double)> FindNearestNeighbours(PdfPoint pivot, int k, Func<PdfPoint, PdfPoint, double> distanceMeasure)
{
return FindNearestNeighbours(pivot, k, p => p, distanceMeasure);
}
}
internal class KdTree<T>
{
private KdTreeNode<T> Root;
private readonly KdTreeNode<T> Root;
public readonly int Count;
public KdTree(IReadOnlyList<T> candidates, Func<T, PdfPoint> candidatesPointFunc)
{
@@ -29,6 +36,7 @@
throw new ArgumentException("KdTree(): candidates cannot be null or empty.", nameof(candidates));
}
Count = candidates.Count;
Root = BuildTree(Enumerable.Range(0, candidates.Count).Zip(candidates, (e, p) => (e, candidatesPointFunc(p), p)).ToArray(), 0);
}
@@ -67,23 +75,23 @@
#region NN
/// <summary>
///
/// Get the nearest neighbour to the pivot element.
/// </summary>
/// <param name="pivot"></param>
/// <param name="pivot">The element for which to find the nearest neighbour.</param>
/// <param name="pivotPointFunc"></param>
/// <param name="distanceMeasure"></param>
/// <param name="index">The nearest neighbour's index (returns -1 if not found).</param>
/// <param name="distance">The distance between the pivot and the nearest neighbour (returns <see cref="double.NaN"/> if not found).</param>
/// <returns>The nearest neighbour's element.</returns>
public T FindNearestNeighbours(T pivot, Func<T, PdfPoint> pivotPointFunc, Func<PdfPoint, PdfPoint, double> distanceMeasure, out int index, out double distance)
public T FindNearestNeighbour(T pivot, Func<T, PdfPoint> pivotPointFunc, Func<PdfPoint, PdfPoint, double> distanceMeasure, out int index, out double distance)
{
var result = FindNearestNeighbours(Root, pivot, pivotPointFunc, distanceMeasure);
var result = FindNearestNeighbour(Root, pivot, pivotPointFunc, distanceMeasure);
index = result.Item1 != null ? result.Item1.Index : -1;
distance = result.Item2.HasValue ? result.Item2.Value : double.NaN;
distance = result.Item2 ?? double.NaN;
return result.Item1 != null ? result.Item1.Element : default;
}
private static (KdTreeNode<T>, double?) FindNearestNeighbours(KdTreeNode<T> node, T pivot, Func<T, PdfPoint> pivotPointFunc, Func<PdfPoint, PdfPoint, double> distance)
private static (KdTreeNode<T>, double?) FindNearestNeighbour(KdTreeNode<T> node, T pivot, Func<T, PdfPoint> pivotPointFunc, Func<PdfPoint, PdfPoint, double> distance)
{
if (node == null)
{
@@ -111,7 +119,7 @@
if (pointValue < node.L)
{
// start left
(newNode, newDist) = FindNearestNeighbours(node.LeftChild, pivot, pivotPointFunc, distance);
(newNode, newDist) = FindNearestNeighbour(node.LeftChild, pivot, pivotPointFunc, distance);
if (newDist.HasValue && newDist <= currentDistance && !newNode.Element.Equals(pivot))
{
@@ -121,13 +129,13 @@
if (node.RightChild != null && pointValue + currentDistance >= node.L)
{
(newNode, newDist) = FindNearestNeighbours(node.RightChild, pivot, pivotPointFunc, distance);
(newNode, newDist) = FindNearestNeighbour(node.RightChild, pivot, pivotPointFunc, distance);
}
}
else
{
// start right
(newNode, newDist) = FindNearestNeighbours(node.RightChild, pivot, pivotPointFunc, distance);
(newNode, newDist) = FindNearestNeighbour(node.RightChild, pivot, pivotPointFunc, distance);
if (newDist.HasValue && newDist <= currentDistance && !newNode.Element.Equals(pivot))
{
@@ -137,7 +145,7 @@
if (node.LeftChild != null && pointValue - currentDistance <= node.L)
{
(newNode, newDist) = FindNearestNeighbours(node.LeftChild, pivot, pivotPointFunc, distance);
(newNode, newDist) = FindNearestNeighbour(node.LeftChild, pivot, pivotPointFunc, distance);
}
}
@@ -152,6 +160,170 @@
}
#endregion
#region k-NN
/*****************************************************************************
* WARNING: k-nearest neighbours algo will need more checks and tests.
*****************************************************************************/
/// <summary>
/// Get the k nearest neighbours to the pivot element. If elements are equidistant, they are counted as one.
/// </summary>
/// <param name="pivot">The element for which to find the k nearest neighbours.</param>
/// <param name="k">The number of neighbours to return. If elements are equidistant, they are counted as one.</param>
/// <param name="pivotPointFunc"></param>
/// <param name="distanceMeasure"></param>
/// <returns>Returns a list of tuples of the k nearest neighbours. Tuples are (element, index, distance).</returns>
public IReadOnlyList<(T, int, double)> FindNearestNeighbours(T pivot, int k, Func<T, PdfPoint> pivotPointFunc, Func<PdfPoint, PdfPoint, double> distanceMeasure)
{
if (k == 1)
{
// if only 1 neighbour required, use default to avoid creating KNearestNeighboursQueue
var nn = FindNearestNeighbour(pivot, pivotPointFunc, distanceMeasure, out int index, out double distance);
if (index == -1)
{
return EmptyArray<(T, int, double)>.Instance;
}
return new List<(T, int, double)>() { (nn, index, distance) };
}
else
{
var kdTreeNodes = new KNearestNeighboursQueue(k);
FindNearestNeighbours(Root, pivot, k, pivotPointFunc, distanceMeasure, kdTreeNodes);
return kdTreeNodes.SelectMany(n => n.Value.Select(e => (e.Element, e.Index, n.Key))).ToList();
}
}
private static (KdTreeNode<T>, double) FindNearestNeighbours(KdTreeNode<T> node, T pivot, int k,
Func<T, PdfPoint> pivotPointFunc, Func<PdfPoint, PdfPoint, double> distance, KNearestNeighboursQueue queue)
{
if (node == null)
{
return (null, double.NaN);
}
else if (node.IsLeaf)
{
if (node.Element.Equals(pivot))
{
return (null, double.NaN);
}
var currentDistance = distance(node.Value, pivotPointFunc(pivot));
var currentNearestNode = node;
if (!queue.IsFull || currentDistance <= queue.LastDistance)
{
queue.Add(currentDistance, currentNearestNode);
currentDistance = queue.LastDistance;
currentNearestNode = queue.LastElement;
}
return (currentNearestNode, currentDistance);
}
else
{
var point = pivotPointFunc(pivot);
var currentNearestNode = node;
var currentDistance = distance(node.Value, point);
if (!queue.IsFull || currentDistance <= queue.LastDistance)
{
queue.Add(currentDistance, currentNearestNode);
currentDistance = queue.LastDistance;
currentNearestNode = queue.LastElement;
}
KdTreeNode<T> newNode = null;
double newDist = double.NaN;
var pointValue = node.IsAxisCutX ? point.X : point.Y;
if (pointValue < node.L)
{
// start left
(newNode, newDist) = FindNearestNeighbours(node.LeftChild, pivot, k, pivotPointFunc, distance, queue);
if (!double.IsNaN(newDist) && newDist <= currentDistance && !newNode.Element.Equals(pivot))
{
queue.Add(newDist, newNode);
currentDistance = queue.LastDistance;
currentNearestNode = queue.LastElement;
}
if (node.RightChild != null && pointValue + currentDistance >= node.L)
{
(newNode, newDist) = FindNearestNeighbours(node.RightChild, pivot, k, pivotPointFunc, distance, queue);
}
}
else
{
// start right
(newNode, newDist) = FindNearestNeighbours(node.RightChild, pivot, k, pivotPointFunc, distance, queue);
if (!double.IsNaN(newDist) && newDist <= currentDistance && !newNode.Element.Equals(pivot))
{
queue.Add(newDist, newNode);
currentDistance = queue.LastDistance;
currentNearestNode = queue.LastElement;
}
if (node.LeftChild != null && pointValue - currentDistance <= node.L)
{
(newNode, newDist) = FindNearestNeighbours(node.LeftChild, pivot, k, pivotPointFunc, distance, queue);
}
}
if (!double.IsNaN(newDist) && newDist <= currentDistance && !newNode.Element.Equals(pivot))
{
queue.Add(newDist, newNode);
currentDistance = queue.LastDistance;
currentNearestNode = queue.LastElement;
}
return (currentNearestNode, currentDistance);
}
}
private class KNearestNeighboursQueue : SortedList<double, HashSet<KdTreeNode<T>>>
{
public readonly int K;
public KdTreeNode<T> LastElement { get; private set; }
public double LastDistance { get; private set; }
public bool IsFull => Count >= K;
public KNearestNeighboursQueue(int k) : base(k)
{
K = k;
LastDistance = double.PositiveInfinity;
}
public void Add(double key, KdTreeNode<T> value)
{
if (key > LastDistance && IsFull)
{
return;
}
if (!ContainsKey(key))
{
base.Add(key, new HashSet<KdTreeNode<T>>());
if (Count > K)
{
RemoveAt(Count - 1);
}
}
if (this[key].Add(value))
{
var last = this.Last();
LastElement = last.Value.Last();
LastDistance = last.Key;
}
}
}
#endregion
private class KdTreeLeaf<Q> : KdTreeNode<Q>
{
public override bool IsLeaf => true;

View File

@@ -108,38 +108,39 @@
ParallelOptions parallelOptions = new ParallelOptions() { MaxDegreeOfParallelism = maxDegreeOfParallelism };
// 1. Estimate in line and between line spacing
// 1. Estimate within line and between line spacing
KdTree<Word> kdTreeWL = new KdTree<Word>(wordsList, w => w.BoundingBox.BottomLeft);
KdTree<Word> kdTreeBL = new KdTree<Word>(wordsList, w => w.BoundingBox.TopLeft);
Parallel.For(0, wordsList.Count, parallelOptions, i =>
{
var word = wordsList[i];
// Within-line distance
var pointsWithinLine = GetNearestPointDistance(wordsList, word,
bb => bb.BottomRight, bb => bb.BottomRight,
bb => bb.BottomLeft, bb => bb.BottomLeft,
withinLine, Distances.Horizontal);
if (pointsWithinLine != null)
var neighbourWL = kdTreeWL.FindNearestNeighbours(word, 2, w => w.BoundingBox.BottomRight, (p1, p2) => Distances.WeightedEuclidean(p1, p2, 0.5));
foreach (var n in neighbourWL)
{
withinLineDistList.Add(pointsWithinLine.Value);
if (withinLine.Contains(Distances.Angle(word.BoundingBox.BottomRight, n.Item1.BoundingBox.BottomLeft)))
{
withinLineDistList.Add(Distances.Horizontal(word.BoundingBox.BottomRight, n.Item1.BoundingBox.BottomLeft));
}
}
// Between-line distance
var pointsBetweenLine = GetNearestPointDistance(wordsList, word,
bb => bb.BottomLeft, bb => bb.Centroid,
bb => bb.TopLeft, bb => bb.Centroid,
betweenLine, Distances.Vertical);
if (pointsBetweenLine != null)
var neighbourBL = kdTreeBL.FindNearestNeighbours(word, 2, w => w.BoundingBox.BottomLeft, (p1, p2) => Distances.WeightedEuclidean(p1, p2, 50));
foreach (var n in neighbourBL)
{
betweenLineDistList.Add(pointsBetweenLine.Value);
if (betweenLine.Contains(Distances.Angle(word.BoundingBox.Centroid, n.Item1.BoundingBox.Centroid)))
{
betweenLineDistList.Add(Distances.Vertical(word.BoundingBox.BottomLeft, n.Item1.BoundingBox.TopLeft));
}
}
});
double? withinLineDistance = GetPeakAverageDistance(withinLineDistList);
double? betweenLineDistance = GetPeakAverageDistance(betweenLineDistList);
if (withinLineDistance == null || betweenLineDistance == null)
if (!withinLineDistance.HasValue || !betweenLineDistance.HasValue)
{
return new[] { new TextBlock(new[] { new TextLine(wordsList) }) };
}
@@ -193,68 +194,14 @@
return blocks.Where(b => b != null).ToList();
}
/// <summary>
/// Get information on the nearest point, filtered for angle.
/// </summary>
private double? GetNearestPointDistance(List<Word> words, Word pivot, Func<PdfRectangle,
PdfPoint> funcPivotDist, Func<PdfRectangle, PdfPoint> funcPivotAngle,
Func<PdfRectangle, PdfPoint> funcPointsDist, Func<PdfRectangle, PdfPoint> funcPointsAngle,
AngleBounds angleBounds,
Func<PdfPoint, PdfPoint, double> finalDistanceMeasure)
{
var pointR = funcPivotDist(pivot.BoundingBox);
var pivotPoint = funcPivotAngle(pivot.BoundingBox);
var wordsWithinAngleBoundDistancePoints = new List<PdfPoint>();
// Filter to words within the angle range.
foreach (var word in words)
{
// Ignore the pivot word.
if (ReferenceEquals(word, pivot))
{
continue;
}
var angle = Distances.Angle(pivotPoint, funcPointsAngle(word.BoundingBox));
if (angleBounds.Contains(angle))
{
wordsWithinAngleBoundDistancePoints.Add(funcPointsDist(word.BoundingBox));
}
}
if (wordsWithinAngleBoundDistancePoints.Count == 0)
{
return null;
}
var closestWordIndex = Distances.FindIndexNearest(pointR, wordsWithinAngleBoundDistancePoints, p => p,
p => p, Distances.Euclidean, out _);
if (closestWordIndex < 0 || closestWordIndex >= wordsWithinAngleBoundDistancePoints.Count)
{
return null;
}
return finalDistanceMeasure(pointR, wordsWithinAngleBoundDistancePoints[closestWordIndex]);
}
private static IEnumerable<TextLine> GetLines(List<Word> words, double maxDist, AngleBounds withinLine, int maxDegreeOfParallelism)
{
TextDirection textDirection = words[0].TextDirection;
var groupedIndexes = ClusteringAlgorithms.ClusterNearestNeighbours(words, Distances.Euclidean,
var groupedIndexes = ClusteringAlgorithms.ClusterNearestNeighbours(words, 2, Distances.Euclidean,
(pivot, candidate) => maxDist,
pivot => pivot.BoundingBox.BottomRight, candidate => candidate.BoundingBox.BottomLeft,
pivot => true,
(pivot, candidate) =>
{
// Compare bottom right with bottom left for angle
var withinLineAngle = Distances.Angle(pivot.BoundingBox.BottomRight, candidate.BoundingBox.BottomLeft);
return (withinLineAngle >= withinLine.Lower && withinLineAngle <= withinLine.Upper);
},
(pivot, candidate) => withinLine.Contains(Distances.Angle(pivot.BoundingBox.BottomRight, candidate.BoundingBox.BottomLeft)),
maxDegreeOfParallelism).ToList();
Func<IEnumerable<Word>, IReadOnlyList<Word>> orderFunc = l => l.OrderBy(x => x.BoundingBox.Left).ToList();
@@ -287,7 +234,7 @@
* If the two lines are not overlapping, the distance is set to the max distance.
**************************************************************************************************/
Func<PdfLine, PdfLine, double> euclidianOverlappingMiddleDistance = (l1, l2) =>
double euclidianOverlappingMiddleDistance(PdfLine l1, PdfLine l2)
{
var left = Math.Max(l1.Point1.X, l2.Point1.X);
var d = (Math.Min(l1.Point2.X, l2.Point2.X) - left);
@@ -297,7 +244,7 @@
return Distances.Euclidean(
new PdfPoint(left + d / 2, l1.Point1.Y),
new PdfPoint(left + d / 2, l2.Point1.Y));
};
}
var groupedIndexes = ClusteringAlgorithms.ClusterNearestNeighbours(lines,
euclidianOverlappingMiddleDistance,

View File

@@ -254,7 +254,7 @@
#region Sorted Queue
private class QueueEntries : SortedSet<QueueEntry>
{
int bound;
readonly int bound;
public QueueEntries(int maximumBound)
{

View File

@@ -31,14 +31,14 @@
/// <param name="letters">The letters in the page.</param>
public IEnumerable<Word> GetWords(IReadOnlyList<Letter> letters)
{
Func<Letter, Letter, double> baseMaxFunc = (l1, l2) =>
double baseMaxFunc(Letter l1, Letter l2)
{
return Math.Max(Math.Max(Math.Max(
Math.Abs(l1.GlyphRectangle.Width),
Math.Abs(l2.GlyphRectangle.Width)),
Math.Abs(l1.Width)),
Math.Abs(l2.Width));
};
}
List<Word> wordsH = GetWords(
letters.Where(l => l.TextDirection == TextDirection.Horizontal).ToList(),

View File

@@ -382,10 +382,16 @@
}
else
{
if (!rectangle.Normalise().IntersectsWith(other.Normalise()))
var r1 = rectangle.Normalise();
var r2 = other.Normalise();
if (Math.Abs(r1.Rotation) < epsilon && Math.Abs(r2.Rotation) < epsilon)
{
// check rotation to avoid stackoverflow
if (!r1.IntersectsWith(r2))
{
return false;
}
}
if (rectangle.Contains(other.BottomLeft)) return true;
if (rectangle.Contains(other.TopRight)) return true;