#!/usr/bin/python3
"""Quickselect test program.

This is adapted from a Codility exercise.

"""
import random  # just for testing


def median(A):
    """Return the median of the nonempty iterable A.

    A is copied into a list first, so the caller's object is never
    modified.  For an even number of items the result is the upper of
    the two middle values, i.e. the item found at index ``len // 2`` of
    the sorted sequence.

    The running time depends on how evenly partition() splits the
    data: it is linear when the splits are balanced, but quadratic in
    the worst case (for example when most of the items are equal).

    Raises IndexError if A is empty.
    """
    items = list(A)
    median_idx = len(items) // 2
    quickselect(items, 0, len(items), median_idx)
    return items[median_idx]


def quickselect(items, start, end, rank):
    """Rearrange items[start:end] so that items[rank] is in sorted position.

    Works in place and returns None.  rank is an index into items and
    is expected to lie inside the range, start <= rank < end.
    """
    # Iterate instead of recursing: a long chain of unbalanced
    # partitions would otherwise run into Python's recursion limit.
    while end > start + 1:
        pivot_idx = partition(items, start, end)
        if pivot_idx == rank:
            return
        if pivot_idx < rank:
            # The wanted item lies to the right of the pivot.
            start = pivot_idx + 1
        else:
            # The wanted item lies to the left of the pivot.
            end = pivot_idx


def partition(items, start, end):
    """Quicksort-style partition function for items[start:end], end > start.

    Chooses a partitioning element and partitions the range around it;
    returns the index i of the partitioning element.  Once the function
    is done, items[start:i] contains all the items (and only the items)
    from items[start:end] that are less than or equal to items[i], and
    items[i+1:end] contains all the items (and only the items) from
    items[start:end] that are greater than items[i].

    This is Lomuto's algorithm.  Within the loop, it uses items[end-1]
    as the partitioning element.  At the beginning of each iteration of
    the loop, the subrange is divided into three parts: items[start:i]
    (initially empty) are all less than or equal to items[end-1],
    items[i:j] (also initially empty) are all greater than items[end-1],
    and items[j:end] (initially the whole subrange) haven't yet been
    examined.
    """
    # Use the middle element for the partition element to avoid the
    # O(N^2) worst case on already-sorted or reverse-sorted input.  You
    # can still hit the O(N^2) worst case on sufficiently pathological
    # input (including input with many equal items, which all land on
    # the "less than or equal" side), but it's less likely to happen by
    # accident.
    middle = start + (end - start) // 2
    items[end - 1], items[middle] = items[middle], items[end - 1]

    # items[end-1] is not moved until the final iteration, so the pivot
    # can be read once up front.
    pivot = items[end - 1]
    i = start
    for j in range(start, end):
        if items[j] <= pivot:
            items[j], items[i] = items[i], items[j]
            i += 1

    # The final iteration (j == end-1) moved the pivot itself to index
    # i and then advanced i, so the pivot now sits at i - 1.
    return i - 1


def test_median(items):
    """Check median(items) against a sort-based reference.

    Raises AssertionError, carrying (items, actual, expected), when
    they disagree.
    """
    m = median(items)
    expected = sorted(items)[len(items) // 2]
    assert m == expected, (items, m, expected)


def test_n_medians(n, m, p):
    """Simple generative test for median.

    Runs test_median on n random lists, each holding between 1 and m
    (inclusive) integers q with 0 <= q < p.
    """
    for _ in range(n):
        size = random.randrange(1, m + 1)
        test_median([random.randrange(p) for _ in range(size)])


# Both helpers above need explicit arguments.  Test runners that collect
# every ``test_*`` function (pytest, nose) would otherwise try to call
# them as standalone tests and fail on the missing parameters.
test_median.__test__ = False
test_n_medians.__test__ = False