KMP 算法学习笔记

本文是个人对 KMP 算法关键点的总结,以备复习,会省略容易理解的内容,只记录个人不熟悉的知识点。

基本思路

字符串匹配问题就是查询字符串 a 中是否包含了字符串 b,其中 a 为目标字符串(简称目标串),b 为模式字符串(简称模式串)。

简单总结:KMP 算法高效的原因在于可以对暴力扫描法进行高效剪枝,而剪枝的策略是扫描模式串中是否有重复的子串,省略对重复子串的扫描。 详细解释如下:

  1. 构建递推情景:如果字符串 a 与 b 直到第 n 个字符才发生不匹配情况,那么前 n-1 个字符必定是匹配的
  2. 去重复预想:如果前 n-1 个字符中有 k 个重复字符,那么也许可以省略部分重复的扫描工作
  3. 简化去重复:能够简单、高效处理的情况是,前缀 k 个字符和后缀 k 个字符重复(最长公共前后缀),这样就直接从第 k+1 个字符开始重新扫描
  4. 推广简化状态:因为可以令长度 n 冲 0 一直累加迭代,则必定可以处理 b 中所有从头开始的子串的公共前后缀,即尽可能的省略所有重复扫描

更详细地分析 KMP 算法,能够看出算法的关键在于对于模式串的预处理,找到每一段从头开始的子串的最长公共前后缀,目标串反而不太重要。这样,在使用一个或少量固定的模式串匹配大量目标串时,预处理工作又能省略很多重复劳动。

伪码构建

不难看出,目标串的扫描操作非常简单,就是进行一次遍历,遇到不匹配的字符进行对应的下标调整操作。而每一个模式串的字符下标都肯发生调整操作,那么不妨构造一个和模式串等长的序列,记录每个模式串字符对应的子串的最长公共前后缀长度——一旦发生不匹配,就会省略这么长的前缀的扫描。 第一个字符的最长公共前后缀长度记为 -1,表示如果连第一个字符都不匹配,那么只好在目标串的下一个字符进行完全重启的扫描工作。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
fun kmpMatch(txt: String, pattern: String) {
    next: IntArray = nextArray(pattern)
    while (i < txt.length && j < pattern.length) {
        if (j >= 0 && txt[i] != pattern[j]) {
            j = next[j]
        } else {
            i++
            j++
        }
    }
    if (j == pattern.length)
        return i - j
    else
        return -1
}

那么,最关键的操作就是找出所有的最长公共前后缀,即所谓的NEXT 数组。构建 NEXT 数组公有两种办法,一种直观但低效,一种稍显抽象但高效。

直接按照定义,进行模拟操作,得到伪码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
fun nextArray(pattern: String) {
    next = IntArray(pattern.length)
    next[0] = -1
    for (i in [1..pattern.length]) {
        j = 0
        while (j < i && pattern[j] == pattern[i - j]) j++
        next[i] = j
    }
    return next
}

这样的办法在查找每一个字符对应的长度时,都会对它前面的字符进行遍历,实际上又发生了重复扫描,所以效率不高。 不过,它比较适合人工实行 KMP 算法的操作,主要用于笔试考试和构造简单测试数据。

实际上,构造 NEXT 数组时,又像是进行了字符串的匹配,只不过新的目标串是旧的模式串的某个子串的前缀,新的模式串是旧的模式串的后缀。那么,我们同样可以使用递推的方法拆解问题,分析它有何子状态。

假设,要查找下标 k 对应的最长公共前后缀长度,那么前 k-1 个长度已经求出,k-1 下标的最长公共前后缀长度是 next[k-1],记为 t,即:

a. k-1 下标对应的公共前后缀是 pattern[0..t-1] 和 pattern[k-1-t..k-2],它们完全相同。

那么,如果 pattern[k-1] 和 pattern[t] 相等,则很容易得出 next[k] 就是 t+1。

如果不相等呢?我们必定是要在小于 t 长度的前后缀中,找到更短的最长公共前后缀,不妨设它为 x,即:

b. pattern[0..x] 和 pattern[k-1-x..k-2] 相同(x < t)

此时,根据性质 a 我们可以得出:

c. pattern[k-1-x..k-2] 和 pattern[t-x..t-1] 相同

那么,也就一定有:

d. pattern[0..x] 和 pattern[t-x..t-1] 相同

所以,记 next[t] 为 r,一定有:

e. x <= r,当且仅当 pattern[k-1] 和 pattern[r]相等时 x 和 r 相等

这样,就可以得出一个迭代关系,只要 pattern[k-1] 和 pattern[t] 不相等,就可以使用 next[t] 来作新的 t 的值,直到满足相等关系或者 next[t] 到达初始值 -1,最后 next[k] 的值就是 t + 1,最终的伪码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
fun nextArray(pattern: String) {
    next = IntArray(pattern.length)
    next[0] = -1
    for (i in [1..pattern.length]) {
        t = next[i - 1]
        while (t >= 0 && pattern[i - 1] != pattern[t]) {
            t = next[t]
        }
        next[i] = t + 1
    }
    return next
}

上述的优化算法仍然有个小瑕疵:设 next[k] 为 x,k 处不匹配时我们会从下标为 x 处开始扫描。此时,如果恰好有 pattern[k] 和 pattern[x] 相等,则必定 x 处一定也不匹配,然后又要从 x 处找寻 next[x] 来确定下一次扫描的起点,实际上我们应该直接从 next[x] 处开始重新扫描。 特别是,有形如 aaaaaaaaa...b 这样的模式串的时候,上面的优化方法会导致匹配操作多次进行不必要的重定位,反而是模拟法构建的 NEXT 数组更准确(毕竟模拟法完全按照定义而来)。 进一步优化的办法也很简单,伪代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
fun nextArray(pattern: String) {
    next = IntArray(pattern.length)
    next[0] = -1
    for (i in [1..pattern.length]) {
        t = next[i - 1]
        while (t >= 0 && pattern[i - 1] != pattern[t]) {
            t = next[t]
        }
        next[i] = t + 1
        if (next[i] != -1 && pattern[i] == pattern[next[i]])  {
            next[i] = next[next[i]]
        }
    }
    return next

实际代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
fun kmp(text: String, pattern: String): Int {
    // 提前去除不必要的匹配
    if (text.length < pattern.length) return -1
    if (text.length == pattern.length) return if (text == pattern) 0 else -1

    val next = nextArray(pattern)

    var i = 0
    var j = 0

    while (i < text.length && j < pattern.length) {
        if (j == -1 || text[i] == pattern[j]) {
            i++
            j++
        } else
            j = next[j]
    }

    // 如果只需要判断是否包含模式串,则直接返回布尔类型 j == pattern.length 即可
    return if (j == pattern.length) i - j else -1
}

private fun nextArray(str: String): IntArray {
    if (str.isEmpty()) return IntArray(0)

    val next = IntArray(str.length)
    next[0] = -1
    for (i in 1..str.lastIndex) {
        var t = next[i - 1]
        while (t != -1 && str[i - 1] != str[t]) t = next[t]
        next[i] = t + 1
        // 对 aaaaaab 形式的模式串 NEXT 数组进行特别优化
        if (next[i] != -1 && str[next[i]] == str[i]) next[i] = next[next[i]]
    }

    return next
}