回溯法中的 DP 陷阱
今天在做 leetcode 每日一题 140. 单词拆分 II,遇到一个很有意思的问题。使用同样的转移方程,自顶向下可以通过,自底向上却超时。自顶向下的过程中实际会产生更多的剪枝,从而减少计算量。
单词拆分 II:
给定一个非空字符串 s 和一个包含非空单词列表的字典
wordDict
,在字符串中增加空格来构建一个句子,使得句子中所有的单词都在词典中。返回所有这些可能的句子。分隔时可以重复使用字典中的单词。你可以假设字典中没有重复的单词。
输入:
s = "catsanddog" wordDict = ["cat", "cats", "and", "sand", "dog"]
输出:
[ "cats and dog", "cat sand dog" ]
题解思路
这道题其实挺简单的,就是一个动态规划的思路,子状态 的解可以从之前的解 中得到,只需要判断子串 是否是 words
中的单词即可。对于 a{n}
这样的构造,复杂度非常炸裂,可以达到 。
我刚开始使用的是动态规划(自底向上)写法,利用 solutions
数组缓存每个状态 的解。
class Solution:
def wordBreak(self, s: str, wordDict: List[str]) -> List[str]:
n = len(s)
words = set(wordDict)
# solutions[i] = [0, i) 解法
solutions = [
[] for _ in range(n+1)
]
solutions[0] = [[]]
for i in range(1, n+1):
for start in range(0, i):
possible_word = s[start:i]
if possible_word in words:
for prev_solution in solutions[start]:
solutions[i].append(prev_solution + [possible_word])
return [' '.join(word) for word in solutions[n]]
遇到的问题
这种写法对这样一个输入超时了:
s = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
wordDict = ["a", "aa", "aaa", "aaaa", "aaaaa", "aaaaaa", "aaaaaaa", "aaaaaaaa", "aaaaaaaaa", "aaaaaaaaaa"]
注意到,s
当中有一个 b
字符,因此这个输入的解是空集。
官方的写法没使用 DP,而是使用 DFS(自顶向下):
class Solution:
def wordBreak(self, s: str, wordDict: List[str]) -> List[str]:
@lru_cache(None)
def backtrack(index: int) -> List[List[str]]:
# [index, n) 解法
if index == len(s):
return [[]]
ans = list()
for i in range(index + 1, len(s) + 1):
word = s[index:i]
if word in wordSet:
nextWordBreaks = backtrack(i)
for nextWordBreak in nextWordBreaks:
ans.append(nextWordBreak.copy() + [word])
print(f'计算 [{index},n) = {len(ans)}')
return ans
wordSet = set(wordDict)
breakList = backtrack(0)
return [" ".join(words[::-1]) for words in breakList]
乍一眼看上去,这两种方法完全一样,使用 DP 还能节省一点栈空间,为什么 DP 会炸呢?
问题分析
原因就在于这两种方法一个是 自顶向下,一个是 自底向上,在处理 aaaaabaaaaaaa
这样的数据中,我们看看这两种解法会如何处理:
自顶向下(DFS)
稍微修改一下,在返回之前打印一点信息:
计算 [75,n) = 0
计算 [74,n) = 0
计算 [73,n) = 0
计算 [72,n) = 0
...
注意到 正是字符 b
所在的位置,因此在对所有可能的字符都判定完不存在单词列表中之后,这个解法会直接返回空集,使得之前的所有解法都快速返回空集,对这种数据的复杂度为 。
自底向上(DP)
同样在计算完成之后打印信息:
计算 [0, 1) = 1
计算 [0, 2) = 2
计算 [0, 3) = 4
计算 [0, 4) = 8
计算 [0, 5) = 16
计算 [0, 6) = 32
计算 [0, 7) = 64
计算 [0, 8) = 128
计算 [0, 9) = 256
计算 [0, 10) = 512
计算 [0, 11) = 1023
计算 [0, 12) = 2045
...
由于我试用的是自顶向上的写法,因此会优先对前面的子串(a{n}
部分)进行求解,而这种计算是 的,因此会直接超时。
这里,两种方法的区别在于,某些情况下,计算 并不需要 。因此,采用自顶向下的递归时,会发生剪枝,从而减少大量计算量。而注意到对于特殊的输入(a{n}
类型),减少的计算量是 级别的。
问题解决
将算法修改为自顶向下即可,
class Solution:
def wordBreak(self, s: str, wordDict: List[str]) -> List[str]:
n = len(s)
words = set(wordDict)
# cached, solutions[i] = [0, i) 解法
solutions = [
None for _ in range(n+1)
]
solutions[0] = [[]]
def dfs(i): # i in [1, n], 判断 [0, i) 子串
if solutions[i] is not None:
return solutions[i]
ans = []
for start in range(0, i): # 判断 [start, i) 是否是单词
possible_word = s[start:i]
if possible_word in words:
prev_solutions = dfs(start)
for prev_solution in prev_solutions:
ans.append(prev_solution + [possible_word])
solutions[i] = ans
return ans
dfs(n)
return [' '.join(words) for words in solutions[n]]