Nth element (Java)

From LiteratePrograms

Other implementations: Java | Python

Find the nth element in sorted order (the nth-smallest, counting from 0) of an unsorted list.

theory

The obvious way to find the nth element in sorted order is to sort the list first, then examine index n. However, that approach does more work than is required: we only care about the element at position n, not about the order of the elements before or after it. By modifying Quicksort to recurse only on the sublist that contains position n, we avoid much of the work of a full sort. On randomly ordered input this takes linear time on average, rather than the O(m log m) needed to sort a list of m elements.
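For comparison, the obvious sort-then-index approach can be sketched as follows. The class name SortNth and the method sortNth are hypothetical, chosen only for this illustration; the method copies the list so the caller's list is not reordered, and uses the same 0-based, ascending convention as qnth.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class SortNth {
    // Hypothetical baseline: sort a copy, then read index n (0-based, ascending).
    // Costs O(m log m) for m elements even though only one position matters.
    public static <T extends Comparable<? super T>> T sortNth(List<T> sample, int n) {
        List<T> copy = new ArrayList<T>(sample); // leave the caller's list untouched
        Collections.sort(copy);
        return copy.get(n);
    }

    public static void main(String[] args) {
        List<Integer> xs = Arrays.asList(5, 2, 8, 1, 9);
        System.out.println(sortNth(xs, 0)); // smallest: 1
        System.out.println(sortNth(xs, 2)); // middle: 5
    }
}
```

This is the result qnth must reproduce; the main function later on this page performs exactly this comparison.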

(For an alternative approach, see An efficient implementation of Blum, Floyd, Pratt, Rivest, and Tarjan’s worst-case linear selection algorithm.)

Question: Suppose the nth element is the minimum or the maximum. What relation does this algorithm bear to a one-dimensional version of Quickhull (Python, arrays)?

practice

We start off in classic quicksort style: we choose a pivot (here simply the first element) and build lists of the elements that sort below and above it. Elements equal to the pivot go into neither list.

<<define qnth>>=
public static <T extends Comparable<? super T>> T qnth(List<T> sample, int n) {
    // Use the first element as the pivot.
    T pivot = sample.get(0);
    // Three-way split: elements equal to the pivot are counted, not stored.
    List<T> below = new ArrayList<T>(),
            above = new ArrayList<T>();
    for (T s : sample) {
        int c = s.compareTo(pivot);
        if (c < 0)
            below.add(s);
        else if (c > 0)
            above.add(s);
    }
    int i = below.size(),
        j = sample.size() - above.size();

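At this point, if sample were sorted, i and j would split it into three segments: indices [0, i) would hold the elements of below, indices [i, j) would hold the copies of the pivot, and indices [j, sample.size()) would hold the elements of above.

Hence we need only recurse on the segment containing index n, subtracting j from n when we descend into above. If n falls in the middle segment, the answer is the pivot itself.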
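To make the roles of i and j concrete, here is a small trace on a list with a repeated pivot value. PartitionTrace and bounds are hypothetical names; bounds performs the same first-element split as qnth but only counts the two sides and returns {i, j} instead of recursing.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class PartitionTrace {
    // Hypothetical helper: qnth's split (pivot = first element), returning {i, j}.
    public static <T extends Comparable<? super T>> int[] bounds(List<T> sample) {
        T pivot = sample.get(0);
        int below = 0, above = 0;
        for (T s : sample) {
            int c = s.compareTo(pivot);
            if (c < 0) below++;
            else if (c > 0) above++;
        }
        return new int[] { below, sample.size() - above };
    }

    public static void main(String[] args) {
        List<Integer> sample = Arrays.asList(3, 1, 4, 1, 5, 9, 2, 6, 3);
        int[] b = bounds(sample);
        System.out.println("i = " + b[0] + ", j = " + b[1]); // i = 3, j = 5
        List<Integer> sorted = new ArrayList<Integer>(sample);
        Collections.sort(sorted);
        System.out.println(sorted); // [1, 1, 2, 3, 3, 4, 5, 6, 9]
    }
}
```

In the sorted list, indices 3 and 4 (that is, [i, j)) hold the two copies of the pivot 3, so qnth(sample, 3) and qnth(sample, 4) both return the pivot without recursing, while n = 5 recurses into above with index 5 - j = 0.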

<<define qnth>>=
    if (n < i)       return qnth(below, n);
    else if (n >= j) return qnth(above, n-j);
    else             return pivot;
}
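Because qnth always takes the first element as its pivot, already-sorted input is its worst case: each call discards only the pivot band, so the running time becomes quadratic and the recursion depth grows with the length of the list. One possible variation, sketched below under the assumption that a random pivot is acceptable, keeps the same three-way split but picks the pivot uniformly at random and replaces the recursion with a loop. QNthRandom and qnthRandom are hypothetical names; the expected running time is linear for any input order, although the worst case is still quadratic.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

public class QNthRandom {
    // Hypothetical variant of qnth: random pivot, loop instead of recursion.
    public static <T extends Comparable<? super T>> T qnthRandom(List<T> sample, int n) {
        if (n < 0 || n >= sample.size())
            throw new IndexOutOfBoundsException("n = " + n + ", size = " + sample.size());
        List<T> current = sample;
        while (true) {
            T pivot = current.get(ThreadLocalRandom.current().nextInt(current.size()));
            List<T> below = new ArrayList<T>(), above = new ArrayList<T>();
            for (T s : current) {
                int c = s.compareTo(pivot);
                if (c < 0) below.add(s);
                else if (c > 0) above.add(s);
            }
            int i = below.size(), j = current.size() - above.size();
            if (n < i) {
                current = below;          // answer lies among the smaller elements
            } else if (n >= j) {
                current = above;          // answer lies among the larger elements
                n -= j;                   // re-index relative to the start of "above"
            } else {
                return pivot;             // n falls in the band of pivot copies
            }
        }
    }

    public static void main(String[] args) {
        List<Integer> xs = new ArrayList<Integer>();
        for (int k = 0; k < 100000; k++) xs.add(k); // sorted: worst case for a first-element pivot
        System.out.println(qnthRandom(xs, 50000)); // 50000
    }
}
```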

wrapping up

Finally, we wrap the function in a class whose main method, when run from the command line, picks an element from the middle of a random list and checks that qnth returns the same result as sorting the list and then selecting the nth element.

<<QNth.java>>=
import java.util.List;
import java.util.ArrayList;
import java.util.Collections;
public class QNth {
    <<define qnth>>
    public static void main(String[] args) {
        int n = 64, mid = 32;
        List<Double> sample = new ArrayList<Double>();
        for (int i = 0; i < n; i++)
            sample.add(Math.random());
        double partial = qnth(sample, mid);
        Collections.sort(sample);
        double sorted = sample.get(mid);
        System.out.println("" + partial + " " + sorted + " " + (partial == sorted));
    }
}

The output should be similar to the following (the numbers vary from run to run, but the last value should be true):

0.579906140785445 0.579906140785445 true
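The check in main tests a single position on distinct random doubles, so it never exercises the band of elements equal to the pivot. A more thorough check, sketched below, compares qnth with sort-then-index at every position of a list drawn from a small integer range, so that duplicates are common. QNthCheck and checkAll are hypothetical names, and qnth is repeated verbatim from above so the class compiles on its own.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class QNthCheck {
    // qnth as defined above, repeated so this class is self-contained.
    public static <T extends Comparable<? super T>> T qnth(List<T> sample, int n) {
        T pivot = sample.get(0);
        List<T> below = new ArrayList<T>(), above = new ArrayList<T>();
        for (T s : sample) {
            if (s.compareTo(pivot) < 0) below.add(s);
            else if (s.compareTo(pivot) > 0) above.add(s);
        }
        int i = below.size(), j = sample.size() - above.size();
        if (n < i) return qnth(below, n);
        else if (n >= j) return qnth(above, n - j);
        else return pivot;
    }

    // Hypothetical exhaustive check: values in [0, range) force many duplicates.
    public static boolean checkAll(long seed, int size, int range) {
        Random rng = new Random(seed);
        List<Integer> sample = new ArrayList<Integer>();
        for (int k = 0; k < size; k++) sample.add(rng.nextInt(range));
        List<Integer> sorted = new ArrayList<Integer>(sample);
        Collections.sort(sorted);
        for (int n = 0; n < size; n++)
            if (!qnth(sample, n).equals(sorted.get(n))) return false;
        return true;
    }

    public static void main(String[] args) {
        System.out.println(checkAll(42L, 200, 10)); // true
    }
}
```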