回溯法中的 DP 陷阱

Published 11/1/2020
Views 8

今天在做 leetcode 每日一题 140. 单词拆分 II,遇到一个很有意思的问题。使用同样的转移方程,自顶向下可以通过,自底向上却超时。自顶向下的过程中实际会产生更多的剪枝,从而减少计算量。

单词拆分 II:

给定一个非空字符串 s 和一个包含非空单词列表的字典 wordDict,在字符串中增加空格来构建一个句子,使得句子中所有的单词都在词典中。返回所有这些可能的句子。

分隔时可以重复使用字典中的单词。你可以假设字典中没有重复的单词。

输入:

s = "catsanddog"
wordDict = ["cat", "cats", "and", "sand", "dog"]

输出:

[
"cats and dog",
"cat sand dog"
]

题解思路

这道题其实挺简单的,就是一个动态规划的思路,子状态 的解可以从之前的解 中得到,只需要判断子串 是否是 words 中的单词即可。对于 a{n} 这样的构造,复杂度非常炸裂,可以达到

我刚开始使用的是动态规划(自底向上)写法,利用 solutions 数组缓存每个状态 的解。

class Solution:
    def wordBreak(self, s: str, wordDict: List[str]) -> List[str]:
        n = len(s)
        words = set(wordDict)
        # solutions[i] = [0, i) 解法
        solutions = [
            [] for _ in range(n+1)
        ]
        solutions[0] = [[]]

        for i in range(1, n+1):
            for start in range(0, i):
                possible_word = s[start:i]
                if possible_word in words:
                    for prev_solution in solutions[start]:
                        solutions[i].append(prev_solution + [possible_word])
        return [' '.join(word) for word in solutions[n]]

遇到的问题

这种写法对这样一个输入超时了:

s = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
wordDict = ["a", "aa", "aaa", "aaaa", "aaaaa", "aaaaaa", "aaaaaaa", "aaaaaaaa", "aaaaaaaaa", "aaaaaaaaaa"]

注意到,s 当中有一个 b 字符,因此这个输入的解是空集。

官方的写法没使用 DP,而是使用 DFS(自顶向下):

class Solution:
    def wordBreak(self, s: str, wordDict: List[str]) -> List[str]:
        @lru_cache(None)
        def backtrack(index: int) -> List[List[str]]:
            # [index, n) 解法
            if index == len(s):
                return [[]]
            ans = list()
            for i in range(index + 1, len(s) + 1):
                word = s[index:i]
                if word in wordSet:
                    nextWordBreaks = backtrack(i)
                    for nextWordBreak in nextWordBreaks:
                        ans.append(nextWordBreak.copy() + [word])
            print(f'计算 [{index},n) = {len(ans)}')
            return ans

        wordSet = set(wordDict)
        breakList = backtrack(0)
        return [" ".join(words[::-1]) for words in breakList]

乍一眼看上去,这两种方法完全一样,使用 DP 还能节省一点栈空间,为什么 DP 会炸呢?

问题分析

原因就在于这两种方法一个是 自顶向下,一个是 自底向上,在处理 aaaaabaaaaaaa 这样的数据中,我们看看这两种解法会如何处理:

自顶向下(DFS)

稍微修改一下,在返回之前打印一点信息:

计算 [75,n) = 0
计算 [74,n) = 0
计算 [73,n) = 0
计算 [72,n) = 0
...

注意到 正是字符 b 所在的位置,因此在对所有可能的字符都判定完不存在单词列表中之后,这个解法会直接返回空集,使得之前的所有解法都快速返回空集,对这种数据的复杂度为

自底向上(DP)

同样在计算完成之后打印信息:

计算 [0, 1) = 1
计算 [0, 2) = 2     
计算 [0, 3) = 4     
计算 [0, 4) = 8     
计算 [0, 5) = 16    
计算 [0, 6) = 32    
计算 [0, 7) = 64    
计算 [0, 8) = 128   
计算 [0, 9) = 256   
计算 [0, 10) = 512  
计算 [0, 11) = 1023 
计算 [0, 12) = 2045 
...

由于我试用的是自顶向上的写法,因此会优先对前面的子串(a{n}部分)进行求解,而这种计算是 的,因此会直接超时。

这里,两种方法的区别在于,某些情况下,计算 并不需要 。因此,采用自顶向下的递归时,会发生剪枝,从而减少大量计算量。而注意到对于特殊的输入(a{n}类型),减少的计算量是 级别的。

问题解决

将算法修改为自顶向下即可,

class Solution:
    def wordBreak(self, s: str, wordDict: List[str]) -> List[str]:
        n = len(s)
        words = set(wordDict)
        # cached, solutions[i] = [0, i) 解法
        solutions = [
            None for _ in range(n+1)
        ]
        solutions[0] = [[]]

        def dfs(i):  # i in [1, n], 判断 [0, i) 子串
            if solutions[i] is not None:
                return solutions[i]
            ans = []

            for start in range(0, i):  # 判断 [start, i) 是否是单词
                possible_word = s[start:i]
                if possible_word in words:
                    prev_solutions = dfs(start)
                    for prev_solution in prev_solutions:
                        ans.append(prev_solution + [possible_word])

            solutions[i] = ans
            return ans

        dfs(n)
        return [' '.join(words) for words in solutions[n]]
0 comments