K closest points

Given an array of N points in the 2D plane, find the K points closest to the origin. You can assume that K is much smaller than N and that N is very large.

Notice the key requirement here: “K is much smaller than N. N is very large”. The brute-force solution computes the distance of every point and sorts all of them, which takes O(n lg n) time. That does more work than necessary: we only care about the first k points of the sorted order, so there is no need to sort all n points.
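For reference, the sort-based baseline can be sketched as follows. This is a minimal sketch, assuming points are stored as `double[]{x, y}` instead of the Point class used below; the class name `SortBaseline` is hypothetical. It compares squared distances, which give the same order as true distances.

```java
import java.util.Arrays;
import java.util.Comparator;

// Sort-based baseline: O(n lg n) time. Points are double[]{x, y} (assumed representation).
public class SortBaseline {
    // Squared distance from the origin; enough for ordering, so no Math.sqrt.
    public static double dist2(double[] p) {
        return p[0] * p[0] + p[1] * p[1];
    }

    // Returns the k closest points, nearest first, without modifying the input array.
    public static double[][] closestK(double[][] points, int k) {
        double[][] sorted = points.clone();   // shallow copy of the row references
        Arrays.sort(sorted, Comparator.comparingDouble(SortBaseline::dist2));
        return Arrays.copyOf(sorted, Math.max(0, Math.min(k, sorted.length)));
    }

    public static void main(String[] args) {
        double[][] pts = {{3, 3}, {1, 1}, {-2, 0}, {5, -1}};
        // prints [[1.0, 1.0], [-2.0, 0.0]]
        System.out.println(Arrays.deepToString(closestK(pts, 2)));
    }
}
```

Every call pays for sorting all n points, even though only k of them are returned.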

A more efficient solution uses a max-heap of size k that holds the k closest points seen so far. Keep adding points until the heap holds k of them. After that, if a new point is closer than the farthest point in the heap (the heap's root), remove the root and insert the new point; each of these operations takes O(lg k). Once all points have been processed, the heap contains exactly the k closest points, so no second pass over the array is needed. This takes O(n lg k) time and O(k) extra space.

Below is an O(n lg k) implementation that uses a Java PriorityQueue with reversed ordering as the max-heap.

 

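import java.util.Collections;
import java.util.PriorityQueue;

// Point and closestk are static members of an enclosing class.
public static class Point implements Comparable<Point> {
    public double x;
    public double y;

    public Point(final double x, final double y) {
        this.x = x;
        this.y = y;
    }

    // Squared distance from the origin; the square root is not needed for comparisons.
    public double getDist() {
        return x * x + y * y;
    }

    // Order by distance, breaking ties by x and then y for a deterministic order.
    @Override
    public int compareTo(final Point o) {
        int c = Double.compare(getDist(), o.getDist());
        if (c == 0) {
            c = Double.compare(x, o.x);
            if (c == 0) {
                c = Double.compare(y, o.y);
            }
        }
        return c;
    }

    @Override
    public String toString() {
        return "(" + x + "," + y + ")";
    }
}

public static Point[] closestk(final Point[] points, final int k) {
    if (k <= 0) {
        return new Point[0];   // PriorityQueue rejects an initial capacity below 1
    }
    // Max heap: reverseOrder() keeps the farthest of the current k candidates at the head.
    final PriorityQueue<Point> kClosest = new PriorityQueue<>(k, Collections.reverseOrder());

    for (final Point p : points) {
        if (kClosest.size() < k) {
            kClosest.add(p);                       // still filling the heap
        } else if (p.getDist() < kClosest.peek().getDist()) {
            kClosest.remove();                     // evict the farthest candidate, O(lg k)
            kClosest.add(p);                       // O(lg k)
        }
    }

    // The heap already holds the k closest points (unsorted); no second pass is needed.
    return kClosest.toArray(new Point[0]);
}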
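The same max-heap idea can also be written without making the point type Comparable, by handing PriorityQueue a comparator on squared distance. This is a minimal sketch, assuming points stored as `double[]{x, y}`; the class name `HeapKClosest` is hypothetical.

```java
import java.util.PriorityQueue;

// Max-heap of size k on squared distance; points are double[]{x, y} (assumed representation).
public class HeapKClosest {
    public static double dist2(double[] p) {
        return p[0] * p[0] + p[1] * p[1];
    }

    public static double[][] closestK(double[][] points, int k) {
        if (k <= 0) {
            return new double[0][];              // PriorityQueue rejects a capacity below 1
        }
        // Reversed comparison: the farthest of the current candidates sits at the head.
        PriorityQueue<double[]> heap =
                new PriorityQueue<double[]>(k, (a, b) -> Double.compare(dist2(b), dist2(a)));
        for (double[] p : points) {
            if (heap.size() < k) {
                heap.offer(p);                   // still filling the heap, O(lg k)
            } else if (dist2(p) < dist2(heap.peek())) {
                heap.poll();                     // drop the farthest candidate, O(lg k)
                heap.offer(p);                   // O(lg k)
            }
        }
        return heap.toArray(new double[0][]);    // the k closest, in heap order (not sorted)
    }
}
```

A comparator keeps the ordering local to this one use, so the point type does not need a tie-breaking compareTo.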

 

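We can do better with QuickSelect (selection by order statistics): find the k-th smallest distance in O(n) expected time, then make one more pass over the points and output the k points whose distance does not exceed it, taking care that points tied at the k-th distance do not crowd out strictly closer ones. With a random pivot the expected running time is O(n), although the worst case is O(n^2). The version below copies the distances into an auxiliary array, so it needs O(n) extra space; it becomes in-place with constant extra space only if the points array itself is partitioned.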
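A rough sketch of why the expected time is linear, under the idealized assumption that each random pivot lands near the middle so the surviving range roughly halves every round (the standard expected-case analysis of randomized selection reaches the same O(n) bound with a larger constant):

```latex
% Idealized case: each partition roughly halves the remaining range
T(n) \approx n + \frac{n}{2} + \frac{n}{4} + \cdots \le n \sum_{i=0}^{\infty} 2^{-i} = 2n = O(n)

% Worst case: every pivot is the smallest or largest remaining distance
T(n) = T(n-1) + \Theta(n) = \Theta(n^2)
```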

public static class Point {
    public double x;
    public double y;

    public Point(final double x, final double y) {
        this.x = x;
        this.y = y;
    }
}

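// Returns the k-th smallest (1-based) value in A[p..r], reordering A in the process.
// Assumes 1 <= k <= r - p + 1.
public static double kthSmallest(final double[] A, final int p, final int r, final int k) {
    if (p == r) {
        return A[p];                   // a single element must be the answer (k == 1)
    }
    final int q = RandomizedPartition(A, p, r);

    final int n = q - p + 1;           // rank of the pivot within A[p..r]
    if (k == n) {
        return A[q];
    } else if (k < n) {
        return kthSmallest(A, p, q - 1, k);
    } else {
        return kthSmallest(A, q + 1, r, k - n);
    }
}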

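public static Point[] closestkWithOrderStatistics(final Point[] points, final int k) {
    final int n = points.length;
    if (k <= 0 || n == 0) {
        return new Point[0];
    }
    final int m = Math.min(k, n);
    // Squared distances preserve the ordering, so Math.sqrt is unnecessary.
    final double[] dist = new double[n];
    for (int i = 0; i < n; i++) {
        dist[i] = squaredDist(points[i]);
    }
    // kthSmallest reorders dist, which is why distances are recomputed below.
    final double kthMin = kthSmallest(dist, 0, n - 1, m);

    final Point[] result = new Point[m];
    int j = 0;
    // First take every point strictly closer than the k-th distance...
    for (int i = 0; i < n && j < m; i++) {
        if (squaredDist(points[i]) < kthMin) {
            result[j++] = points[i];
        }
    }
    // ...then fill the remaining slots with points exactly at that distance (ties).
    for (int i = 0; i < n && j < m; i++) {
        if (squaredDist(points[i]) == kthMin) {
            result[j++] = points[i];
        }
    }
    return result;
}

private static double squaredDist(final Point pt) {
    return pt.x * pt.x + pt.y * pt.y;
}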

private static void swap(final double input[], final int i, final int j) {
    final double temp = input[i];
    input[i] = input[j];
    input[j] = temp;
}

// Lomuto partition: moves the pivot A[r] to its sorted position in A[p..r] and returns that index.
private static int partition(final double[] A, final int p, final int r) {
    final double pivot = A[r];
    int i = p - 1;

    for (int j = p; j < r; j++) {
        if (A[j] <= pivot) {
            swap(A, ++i, j);
        }
    }

    swap(A, i + 1, r);
    return i + 1;
}

private static int RandomizedPartition(final double[] A, final int p, final int r) {
    // Uniform random pivot index in [p, r]; Math.random() returns a value in [0, 1).
    final int i = p + (int) (Math.random() * (r - p + 1));
    swap(A, i, r);
    return partition(A, p, r);
}
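Partitioning the points themselves, instead of a copied distance array, removes the O(n) auxiliary array, and an iterative loop avoids recursion, so the working space is O(1) beyond the k-element result. This is a minimal sketch, assuming points stored as `double[]{x, y}` and a hypothetical class name `InPlaceKClosest`; note that it reorders the caller's array.

```java
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;

// Quickselect on the points themselves: afterwards points[0..k-1] hold the k closest.
// Points are double[]{x, y} (assumed representation). The input array IS reordered.
public class InPlaceKClosest {
    public static double dist2(double[] p) {
        return p[0] * p[0] + p[1] * p[1];
    }

    private static void swap(double[][] a, int i, int j) {
        double[] t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    // Lomuto partition of a[lo..hi] around a uniformly random pivot; returns its final index.
    private static int partition(double[][] a, int lo, int hi) {
        swap(a, ThreadLocalRandom.current().nextInt(lo, hi + 1), hi);
        double pivot = dist2(a[hi]);
        int store = lo;
        for (int j = lo; j < hi; j++) {
            if (dist2(a[j]) < pivot) {
                swap(a, store++, j);
            }
        }
        swap(a, store, hi);
        return store;
    }

    public static double[][] closestK(double[][] points, int k) {
        int m = Math.max(0, Math.min(k, points.length));
        if (m == 0) {
            return new double[0][];
        }
        int target = m - 1;                   // 0-based index of the k-th closest point
        int lo = 0, hi = points.length - 1;
        // Iterative loop instead of recursion, so no call stack grows with n.
        while (lo < hi) {
            int q = partition(points, lo, hi);
            if (q == target) {
                break;
            } else if (q < target) {
                lo = q + 1;                   // the k-th closest lies right of the pivot
            } else {
                hi = q - 1;                   // the k-th closest lies left of the pivot
            }
        }
        return Arrays.copyOf(points, m);      // the k closest, in no particular order
    }
}
```

Lomuto partitioning degrades toward quadratic time when many points share the same distance; a three-way partition is the usual remedy.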
