sort 源码剖析

我们在日常对于数据的排序过程中,会接触到几种排序的算法,他们各自都有优势。那么如果把这四大排序合起来,充分利用各自的优点的话,那么会不会让整体的排序变得更加高效?

四大排序的优缺点汇总

为了能够更好的了解,如何将各个排序合起来使用。先来回顾一下,四大排序各有什么优缺点:

快速排序

快速排序的时间复杂度是 $O(n*\log n)$ 。它有几个待优化的点:

  1. 分割数挑选:极度依赖挑选的分割数。如果挑选的分割数不好,那么会造成算法性能急剧退化;
  2. 相等的元素处理:相等的元素如果不做处理的话,并不会影响快排的性能,但如果处理,会极大的提升性能;
  3. 小数组处理:小数组处理上,因为要分治、递归,所以快排的效率并不比插入排序快多少;
  4. 栈使用严重:因为是使用递归来进行排序,所以当数据集很大时,递归对于栈的消耗比较大;

上面几个点,是高效使用快速排序必须要解决的。

插入排序

插入排序非常依赖原始数据的顺序,最坏的时间复杂度可以达到 $O(n^2)$。所以,在使用插入排序的时候,不能是太大的数据量,其次不能是倒序。

希尔排序

希尔排序相当于插入排序的改进版,优化降低了对于源数据顺序的依赖。也是在小数据组更加适合。

堆排序

同样是 $O(n * \log n)$ 的时间复杂度,但是没有了递归,也就没有了栈的消耗。缺点是维护堆的有序性比较麻烦,并且在排序过程中元素的移动频繁,实际上复杂度中常量会比较大。这也是没能坐上排序龙头位置的原因吧。

Golang 中的排序怎么实现

上面我们总结了排序的各个优缺点。那么,在实际的过程中,怎么能够扬长避短写出一个高效的排序算法呢?我们来看看Golang标准库中的 sort 是怎么做的。

定义

这里有两个比较重要的定义:一是 Interface ,二是 reverse 。reverse 是 Interface 的一个包装,用来实现倒序排序。

Interface 是一个接口,定义了三个方法。分别用来获取数据集长度、比较两个位置的值,交换两个位置的值。这样的接口定义,也就是说,被排序的数据集必须是可以使用 index 索引的。不能是链表一类的数据结构,只能是数组,或者是实现了 index 索引的数据结构才能进行排序。

type Interface interface {
    // 数据集元素的数量
    Len() int
    // 如果是 true i 会排在 j 前面
    Less(i, j int) bool
    // 交换两个位置的值
    Swap(i, j int)
}

type reverse struct {
    // 一个 Interface 的包装
    Interface
}

// 反向实现了 Interface 中的 Less
func (r reverse) Less(i, j int) bool {
    return r.Interface.Less(j, i)
}

// 对于数据集 data 返回一个 reverse 的 Interface
func Reverse(data Interface) Interface {
    return &reverse{data}
}

各个排序算法的使用时间

因为四种排序各有优劣,所以,我们对什么时候使用什么样的排序也要有一个清晰的了解。

排序的入口是 Sort 方法。源码注释写道,这个算法是 $O(n * \log n)$ 的时间复杂度,并且是不稳定的。实际的排序是使用了 quickSort 这个快排的实现。传入四个参数:数据集、起始索引,终止索引,最大深度

// 排序入口
func Sort(data Interface) {
    n := data.Len()
    quickSort(data, 0, n, maxDepth(n))
}

最大深度是什么意思呢?这个是指递归调用的深度,也就是递归树的高度。我们知道,快速排序使用递归时,如果递归深度太深,那么就意味着对栈的要求极高,而栈的大小是有限的。一般来说,解决这个问题有两个办法:一是通过自建栈来保存中间态,二是限制递归深度。Golang 这里选择了第二个办法。

我们来看一下这个 maxDepth 的实现,没什么太特别要讲的:

// 2*ceil(lg(n+1))
func maxDepth(n int) int {
    var depth int
    for i := n; i > 0; i >>= 1 {
        depth++
    }
    return depth * 2
}

现在我们可以看一下各个算法的使用时间了:

  1. 当数据集元素 小等于 12 的时候
    1. 当 数据集 大于 6 的时候,先做一次数据间隔为 6 的 希尔排序
    2. 做一次 插入排序
  2. 当数据集元素 大于12 的时候
    1. 快速排序
    2. 当 快排递归深度大于 最大深度 时,转为 堆排序
func quickSort(data Interface, a, b, maxDepth int) {
    for b-a > 12 { // 大于 12 个元素,使用 快排
        if maxDepth == 0 {  // 快排达到最大深度,使用 堆排序
            heapSort(data, a, b)
            return
        }
        maxDepth--
        mlo, mhi := doPivot(data, a, b)
        if mlo-a < b-mhi {
            quickSort(data, a, mlo, maxDepth)
            a = mhi // i.e., quickSort(data, mhi, b)
        } else {
            quickSort(data, mhi, b, maxDepth)
            b = mlo // i.e., quickSort(data, a, mlo)
        }
    }
    if b-a > 1 { // 元素 <= 12 个

        // 间隔为 6 的希尔排序
        for i := a + 6; i < b; i++ {
            if data.Less(i, i-6) {
                data.Swap(i, i-6)
            }
        }
        // 插入排序
        insertionSort(data, a, b)
    }
}

因为快速排序中,做了比较多的优化,我们放到后面再讲,先将其他三种排序说一下

希尔排序和插入排序

当 元素小等于 12 个的时候,我们使用的是 希尔排序和插入排序。

我们知道,插入排序是对元素顺序极其依赖的,这里先使用希尔排序就是为了让性能不至于退化到最差的情况。

    if b-a > 1 { // 元素 <= 12 个

        // 间隔为 6 的希尔排序
        // 从第6个元素开始,依次相后
        for i := a + 6; i < b; i++ {
            // 同 6个之前的元素相比,
            if data.Less(i, i-6) {
                // 将小的元素放到前面
                data.Swap(i, i-6)
            }
        }
        // 插入排序
        insertionSort(data, a, b)
    }
// 插入排序
func insertionSort(data Interface, a, b int) {
    // 从第1个元素开始,依次相后处理
    for i := a + 1; i < b; i++ {
        // 依次与自己之前的元素比较,将元素放到合适的位置
        for j := i; j > a && data.Less(j, j-1); j-- {
            data.Swap(j, j-1)
        }
    }
}

堆排序

当递归深度,大于最大的深度时,就会抛弃快排,转而使用堆排序。这里的堆是大顶堆,排序的时候依次弹出元素从后向前放到数组中,达到正序排序的效果。

func heapSort(data Interface, a, b int) {
    
    // lo 和 hi 都是相对位置,first 相当于相对位置的锚点
    first := a
    lo := 0
    hi := b - a

    // 从最后一个满叶子的节点开始构建大顶堆
    for i := (hi - 1) / 2; i >= 0; i-- {
        // 当前节点向下堆化
        siftDown(data, i, hi, first)
    }

    // 此时,在给定区间里,堆已经做好,我们要做的就是
    // 1. i 等于最后一个元素
    // 2. 将堆顶(first)和 i 互换
    // 3. i --
    // 4. 将新的 first 元素向下堆化
    // 5. 重复 i ,直到 i == 0
    for i := hi - 1; i >= 0; i-- {
        data.Swap(first, first+i)
        siftDown(data, lo, i, first)
    }
}

// 向下堆化
func siftDown(data Interface, lo, hi, first int) {
    root := lo  // 当前根节点
    for {
        child := 2*root + 1  // 左节点
        if child >= hi { // 没有左节点,退出
            break
        }
        // 有右节点,左节点小于右节点
        if child+1 < hi && data.Less(first+child, first+child+1) {
            // child 变为 右节点
            child++
        }
        
        // 如果当前根节点比 左右节点 都大,那么完成堆化
        if !data.Less(first+root, first+child) {
            return
        }
        
        // 交换 根节点和子节点
        data.Swap(first+root, first+child)
        
        // 子节点作为新的根节点,重复上面动作
        root = child
    }
}

这里需要注意的是,我们使用大顶堆,并通过交换堆顶和最后一个叶子节点,然后右边界左移一个,这种形式来达成正序排序的效果。

堆排序没有做过多难以理解的优化,基本遵循平常的堆排序原理,不做过多讲解了。

快速排序

我们回顾一下上面对于快速排序的理解。他有4个缺点,现在我们通过其他三种算法将3、4两个缺点解决掉了。那么剩下 分割数挑选相等元素处理 亟待解决。

那么,去掉其他几种算法的使用,现在整个快速排序就变成了下面这样:

func quickSort(data Interface, a, b, maxDepth int) {
    for b-a > 12 { 
        
        mlo, mhi := doPivot(data, a, b)
        
        if mlo-a < b-mhi {
            quickSort(data, a, mlo, maxDepth)
            a = mhi // i.e., quickSort(data, mhi, b)
        } else {
            quickSort(data, mhi, b, maxDepth)
            b = mlo // i.e., quickSort(data, a, mlo)
        }
    }
}

这里面有一个非常重要的函数 doPivot 。这个函数其实就是快速排序的核心了,它两件事情:

  1. 找到一个尽量合适的分割数;
  2. 将分割数两边的数据尽量按照正态分布存放;
  3. 处理有很多相同数据的情况;

个人觉得,整个 doPivot 可以说被 a、b、c 三个变量毁了。如果不是弄懂作者的意思,光看这三个变量,看得头大。简直就是优雅编程的反例。

我们一点点来讲。我们将整个 doPivot 函数分成三个部分:

  1. 第一部分:寻找分割数;
  2. 第二部分:将数据按照选好的分割数,按大小,做正态分布处理;
  3. 第三部分:处理大量相同元素的情况;

第一部分

寻找分割数(pivot)是一个比较麻烦的事情。理论上来说,当然是中位数最好,但实际情况是不可能在大量数据中去找出中位数的。所以我们尽量找一个接近中位数的方法。

在这里有两种情况,一是当数据大于 40 个时,会使用 Tukey’s Ninther 这种方法来寻找 pivot。如果小于40个,那么就直接在三个位置中寻找中位数。

Tukey's Ninther 这个方法,其实就是将整个数据集分为三段,然后每一段都在开始、中间、结束三个位置中,找到一个中位数。最后,在这三个中位数中间,再找一个中位数,就认为是最接近原始数据集的中位数的值了。

golang-sort

上面是一段无序的数组,我们使用 Tukey's Ninther 这种方法来查找,可以清晰的看到结果是 9。很接近中位数8。而如果不用这种方法的话,则结果是 13,差得比较远了。

❗40 这个魔法数字,目前并没有找到相关的解释,或者算法的论证。

medianOfThree 这个函数很简单,不展开了,就是给定三个位置,$m1,m0,m2$(注意m0在中间),然后排序,最后 $m0 \leq m1 \leq m2$ 。所以,最后中位数的位置就在 $m1$ 上。

我们来看代码:

func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
    m := int(uint(lo+hi) >> 1)  // 找到中间位置
    if hi-lo > 40 { // 当大于 40 个元素时,使用 Tukey's Ninther
        // 将整个数据集分为8段,然后在 前两段、中间两段、后两段
        // 中找到三个分割数。
        // 然后,在找出来的三个数中,找中位数。
        s := (hi - lo) / 8
        medianOfThree(data, lo, lo+s, lo+2*s)
        medianOfThree(data, m, m-s, m+s)
        medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
    }
    medianOfThree(data, lo, m, hi-1)
    ……
}

最后的结果就是: pivot 在放在 lo 这个位置上。

第二部分

接下来是第二部分,分割数据。其实就是将整个数据集,按照选定的 pivot,切成几段,方便后面处理。

源码中,注释说得比较清楚:

首先,lo 的位置上是 pivot。实际处理这个过程是定义了三个变量 abca/ b 两个变量从前向后递增处理,c 从后向前递减处理;

func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
    …………

    pivot := lo
    a, c := lo+1, hi-1 // a 从 lo+1 开始,c 从 hi-1 开始

    // 从前向后,找到一个 data[a] >= pivot 的位置
    for ; a < c && data.Less(a, pivot); a++ {
    }
    b := a // b 从这里开始找
    for { // 这里 b 和 c 要一直处理,直到两者相交
        // b 从前向后找,找到一个 pivot >= data[b] 的位置
        for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
        }
        // c 从后向前找,找到一个 pivot < data[c-1] 的位置
        for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
        }
        // 如果 b 和 c 相交,则结束
        if b >= c {
            break
        }
        // 交换 b 和 c-1 的值
        data.Swap(b, c-1)
        // 向前、向后走一步
        b++
        c--
    }
    …………
}

到上面这部分处理完。我们已经完成了数据集的一次处理,得到的结果是:

  1. $(lo,a)$ 所有元素 小于 pviot
  2. $[a,b)$ 所有元素 小等于 pviot
  3. $[c,h-1)$ 所有元素 大于 pviot

因为最后退出循环的标志是 b >= c ,那么就意味着,b 和 c 之间是没有元素的。至此,我们将所有的元素基本分为了几个部分。接下来的处理就是处理如果有多个和 pviot 重复的数据。

下面的代码主要1个变量值得关注 protect。它表示当前状态下,是否需要处理大量重复的 pviot

func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
    ………… // 省略上面讲过的逻辑

    protect := hi-c < 5 // 当比 pviot 大的值数量小于 5 时,表示有大量重复数据。
    // 如果 protect = false 而且 比 pviot 大的值数量小于总集合的 1/4
    // 那就 检查几个点看看是否有重复数据
    if !protect && hi-c < (hi-lo)/4 { 
        dups := 0
        
        // 如果 data[hi-1] 和 pviot 相等
        if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
            data.Swap(c, hi-1)       // 交换和 c 的位置
            c++     // 将 c 的位置 向后移动
            dups++  // 表示有重复数据
        }
        
        // 如果 data[b-1] = pivot 相等
        if !data.Less(b-1, pivot) { // data[b-1] = pivot
            b--    // b 向前移动位置
            dups++ // 表示有重复数据
        }

        // 这里对比位置 m 和 pivot 的数值
        // m 是数据集的中间位置,而只有大于12个元素的时候,才能到快排中
        // 所以 m-lo = (hi-lo)/2 > 6
        // 下面这句是表示 m 的位置是在 b-lo 里面的 也就是 <= pivot 这个区间
        // b-lo > (hi-lo)*3/4-1 > 8 
        // 如此,那么下面的条件如果成立,说明 data[m] = pivot
        if !data.Less(m, pivot) { // data[m] = pivot
            data.Swap(m, b-1)
            b--
            dups++
        }
        // if at least 2 points are equal to pivot, assume skewed distribution
        // 如果有两个以上的重复点,那么认为当前的数据分布是偏态,这样会导致快排的性能下降
        // 此时就必须处理重复数据
        protect = dups > 1
    }
    if protect { // 如果需要处理重复数据
        // Protect against a lot of duplicates
        // Add invariant:
        //    data[a <= i < b] unexamined
        //    data[b <= i < c] = pivot
        for {
            // b 向前查找,如果 data[b-1] == pivot 那么,位置向前移
            for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
            }
            // a 向前查找,如果 data[a] < pivot 那么a向前移动
            for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
            }
            // a b 相交 
            if a >= b {
                break
            }
            // 交换 a 和 b-1 的位置
            data.Swap(a, b-1)
            a++
            b--
        }
    }
    // 把 pivot 放到中间
    data.Swap(pivot, b-1)
    return b - 1, c // 返回 pivot 和 第一个比 pivot 大的元素的位置
}

上面的处理完之后,如果重复元素很多的话,那么整个数据集被分为:

  1. $(lo,a)$ 所有元素 小于 pviot这里 a 和 b 相交了
  2. $[b,c-1)$ 所有元素 等于 pviot
  3. $[c,h-1)$ 所有元素 大于 pviot

那么最后返回的是 相等元素的两端。相等元素则不再参与排序。

再回到快排这里,来看看后续的处理。整个快排是循环+递归的形式组成的。在一次循环里,只处理分割中较少的一半,这样可以减少递归调用的深度,控制栈的使用。

func quickSort(data Interface, a, b, maxDepth int) {
    for b-a > 12 { // Use ShellSort for slices <= 12 elements
        maxDepth--
        // 通过 doPivot 拿到了 两个值的位置
        mlo, mhi := doPivot(data, a, b)
        
        // 这里看一下前半段比较短还是后半段比较短,只处理短的部分。这样可以控制栈的深度
        // 长的一部分会在下一次的循环中被处理掉
        if mlo-a < b-mhi {
            quickSort(data, a, mlo, maxDepth)
            a = mhi // i.e., quickSort(data, mhi, b)
        } else {
            quickSort(data, mhi, b, maxDepth)
            b = mlo // i.e., quickSort(data, a, mlo)
        }
    }
}

到此,整个的快排就全部结束了。

尾声

Golang 中的排序算法用了很多种优化办法,的确是将曾经学过的排序算法中优化的地方全部优化到了。算是一个理论结合实践最紧密的一个标准库了。非常值得学习。


署名-非商业性使用-相同方式共享 4.0