From f1102fb2621d26efe722b58ffc12959086f0e275 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sun, 23 Jan 2022 23:37:02 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20optimize=20string=20search=20with=20Ah?= =?UTF-8?q?o=E2=80=93Corasick=20algorithm=20(#1476)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: optimize string search with Aho–Corasick algorithm * chore: optimize keywords replacer * fix: replacer bugs * chore: reorder members --- core/stringx/node.go | 77 +++++++++++++++++++-- core/stringx/node_test.go | 25 +++++++ core/stringx/replacer.go | 85 ++++++++++++++--------- core/stringx/replacer_fuzz_test.go | 42 ++++++++++++ core/stringx/replacer_test.go | 104 +++++++++++++++++++++++++++++ core/stringx/trie.go | 48 ++----------- core/stringx/trie_test.go | 34 ++++------ core/syncx/singleflight.go | 10 --- 8 files changed, 316 insertions(+), 109 deletions(-) create mode 100644 core/stringx/node_test.go create mode 100644 core/stringx/replacer_fuzz_test.go diff --git a/core/stringx/node.go b/core/stringx/node.go index 39d7a6d2..e11eb686 100644 --- a/core/stringx/node.go +++ b/core/stringx/node.go @@ -2,6 +2,8 @@ package stringx type node struct { children map[rune]*node + fail *node + depth int end bool } @@ -12,17 +14,19 @@ func (n *node) add(word string) { } nd := n - for _, char := range chars { + var depth int + for i, char := range chars { if nd.children == nil { child := new(node) - nd.children = map[rune]*node{ - char: child, - } + child.depth = i + 1 + nd.children = map[rune]*node{char: child} nd = child } else if child, ok := nd.children[char]; ok { nd = child + depth++ } else { child := new(node) + child.depth = i + 1 nd.children[char] = child nd = child } @@ -30,3 +34,68 @@ func (n *node) add(word string) { nd.end = true } + +func (n *node) build() { + n.fail = n + for _, child := range n.children { + child.fail = n + n.buildNode(child) + } +} + +func (n *node) buildNode(nd *node) { + if nd.children == nil { + return + } + + var fifo []*node + for key, child := range nd.children { + fifo = append(fifo, child) + + if fail, ok := nd.fail.children[key]; ok { + child.fail = fail + } else { + child.fail = n + } + } + + for _, val := range fifo { + n.buildNode(val) + } +} + +func (n *node) find(chars []rune) []scope { + var scopes []scope + size := len(chars) + cur := n + + for i := 0; i < size; i++ { + child, ok := cur.children[chars[i]] + if ok { + cur = child + } else if cur == n { + continue + } else { + cur = cur.fail + if child, ok = cur.children[chars[i]]; !ok { + continue + } + cur = child + } + + if child.end { + scopes = append(scopes, scope{ + start: i + 1 - child.depth, + stop: i + 1, + }) + } + if child.fail != n && child.fail.end { + scopes = append(scopes, scope{ + start: i + 1 - child.fail.depth, + stop: i + 1, + }) + } + } + + return scopes +} diff --git a/core/stringx/node_test.go b/core/stringx/node_test.go new file mode 100644 index 00000000..c3dcba9b --- /dev/null +++ b/core/stringx/node_test.go @@ -0,0 +1,25 @@ +package stringx + +import "testing" + +func BenchmarkNodeFind(b *testing.B) { + b.ReportAllocs() + + keywords := []string{ + "A", + "AV", + "AV演员", + "无名氏", + "AV演员色情", + "日本AV女优", + } + trie := new(node) + for _, keyword := range keywords { + trie.add(keyword) + } + trie.build() + + for i := 0; i < b.N; i++ { + trie.find([]rune("日本AV演员兼电视、电影演员。无名氏AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演")) + } +} diff --git a/core/stringx/replacer.go b/core/stringx/replacer.go index 00c93195..947ae69a 100644 --- a/core/stringx/replacer.go +++ b/core/stringx/replacer.go @@ -9,7 +9,7 @@ type ( } replacer struct { - node + *node mapping map[string]string } ) @@ -17,58 +17,81 @@ type ( // NewReplacer returns a Replacer. func NewReplacer(mapping map[string]string) Replacer { rep := &replacer{ + node: new(node), mapping: mapping, } for k := range mapping { rep.add(k) } + rep.build() return rep } +// Replace replaces text with given substitutes. func (r *replacer) Replace(text string) string { var builder strings.Builder + var start int chars := []rune(text) size := len(chars) - start := -1 - for i := 0; i < size; i++ { - child, ok := r.children[chars[i]] - if !ok { - builder.WriteRune(chars[i]) - continue + for start < size { + cur := r.node + + if start > 0 { + builder.WriteString(string(chars[:start])) } - if start < 0 { - start = i - } - end := -1 - if child.end { - end = i + 1 - } + for i := start; i < size; i++ { + child, ok := cur.children[chars[i]] + if ok { + cur = child + } else if cur == r.node { + builder.WriteRune(chars[i]) + // cur already points to root, set start only + start = i + 1 + continue + } else { + curDepth := cur.depth + cur = cur.fail + child, ok = cur.children[chars[i]] + if !ok { + // write this path + builder.WriteString(string(chars[i-curDepth : i+1])) + // go to root + cur = r.node + start = i + 1 + continue + } - j := i + 1 - for ; j < size; j++ { - grandchild, ok := child.children[chars[j]] - if !ok { + failDepth := cur.depth + // write path before jump + builder.WriteString(string(chars[start : start+curDepth-failDepth])) + start += curDepth - failDepth + cur = child + } + + if cur.end { + val := string(chars[i+1-cur.depth : i+1]) + builder.WriteString(r.mapping[val]) + builder.WriteString(string(chars[i+1:])) + // only matching this path, all previous paths are done + if start >= i+1-cur.depth && i+1 >= size { + return builder.String() + } + + chars = []rune(builder.String()) + size = len(chars) + builder.Reset() break } - - child = grandchild - if child.end { - end = j + 1 - i = j - } } - if end > 0 { - i = j - 1 - builder.WriteString(r.mapping[string(chars[start:end])]) - } else { - builder.WriteRune(chars[i]) + if !cur.end { + builder.WriteString(string(chars[start:])) + return builder.String() } - start = -1 } - return builder.String() + return string(chars) } diff --git a/core/stringx/replacer_fuzz_test.go b/core/stringx/replacer_fuzz_test.go new file mode 100644 index 00000000..2e9facbe --- /dev/null +++ b/core/stringx/replacer_fuzz_test.go @@ -0,0 +1,42 @@ +//go:build go1.18 +// +build go1.18 + +package stringx + +import ( + "fmt" + "math/rand" + "strings" + "testing" +) + +func FuzzReplacerReplace(f *testing.F) { + keywords := make(map[string]string) + for i := 0; i < 20; i++ { + keywords[Randn(rand.Intn(10)+5)] = Randn(rand.Intn(5) + 1) + } + rep := NewReplacer(keywords) + printableKeywords := func() string { + var buf strings.Builder + for k, v := range keywords { + fmt.Fprintf(&buf, "%q: %q,\n", k, v) + } + return buf.String() + } + + f.Add(50) + f.Fuzz(func(t *testing.T, n int) { + text := Randn(rand.Intn(n%50+50) + 1) + defer func() { + if r := recover(); r != nil { + t.Errorf("mapping: %s\ntext: %s", printableKeywords(), text) + } + }() + val := rep.Replace(text) + keys := rep.(*replacer).node.find([]rune(val)) + if len(keys) > 0 { + t.Errorf("mapping: %s\ntext: %s\nresult: %s\nmatch: %v", + printableKeywords(), text, val, keys) + } + }) +} diff --git a/core/stringx/replacer_test.go b/core/stringx/replacer_test.go index 89e3a1ec..08c2661f 100644 --- a/core/stringx/replacer_test.go +++ b/core/stringx/replacer_test.go @@ -15,6 +15,14 @@ func TestReplacer_Replace(t *testing.T) { assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五")) } +func TestReplacer_ReplaceOverlap(t *testing.T) { + mapping := map[string]string{ + "3d": "34", + "bc": "23", + } + assert.Equal(t, "a234e", NewReplacer(mapping).Replace("abcde")) +} + func TestReplacer_ReplaceSingleChar(t *testing.T) { mapping := map[string]string{ "二": "2", @@ -42,3 +50,99 @@ func TestReplacer_ReplaceMultiMatches(t *testing.T) { } assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五")) } + +func TestReplacer_ReplaceJumpToFail(t *testing.T) { + mapping := map[string]string{ + "bcdf": "1235", + "cde": "234", + } + assert.Equal(t, "ab234fg", NewReplacer(mapping).Replace("abcdefg")) +} + +func TestReplacer_ReplaceJumpToFailDup(t *testing.T) { + mapping := map[string]string{ + "bcdf": "1235", + "ccde": "2234", + } + assert.Equal(t, "ab2234fg", NewReplacer(mapping).Replace("abccdefg")) +} + +func TestReplacer_ReplaceJumpToFailEnding(t *testing.T) { + mapping := map[string]string{ + "bcdf": "1235", + "cdef": "2345", + } + assert.Equal(t, "ab2345", NewReplacer(mapping).Replace("abcdef")) +} + +func TestReplacer_ReplaceEmpty(t *testing.T) { + mapping := map[string]string{ + "bcdf": "1235", + "cdef": "2345", + } + assert.Equal(t, "", NewReplacer(mapping).Replace("")) +} + +func TestFuzzCase1(t *testing.T) { + keywords := map[string]string{ + "yQyJykiqoh": "xw", + "tgN70z": "Q2P", + "tXKhEn": "w1G8", + "5nfOW1XZO": "GN", + "f4Ov9i9nHD": "cT", + "1ov9Q": "Y", + "7IrC9n": "400i", + "JQLxonpHkOjv": "XI", + "DyHQ3c7": "Ygxux", + "ffyqJi": "u", + "UHuvXrbD8pni": "dN", + "LIDzNbUlTX": "g", + "yN9WZh2rkc8Q": "3U", + "Vhk11rz8CObceC": "jf", + "R0Rt4H2qChUQf": "7U5M", + "MGQzzPCVKjV9": "yYz", + "B5jUUl0u1XOY": "l4PZ", + "pdvp2qfLgG8X": "BM562", + "ZKl9qdApXJ2": "T", + "37jnugkSevU66": "aOHFX", + } + rep := NewReplacer(keywords) + text := "yjF8fyqJiiqrczOCVyoYbLvrMpnkj" + val := rep.Replace(text) + keys := rep.(*replacer).node.find([]rune(val)) + if len(keys) > 0 { + t.Errorf("result: %s, match: %v", val, keys) + } +} + +func TestFuzzCase2(t *testing.T) { + keywords := map[string]string{ + "dmv2SGZvq9Yz": "TE", + "rCL5DRI9uFP8": "hvsc8", + "7pSA2jaomgg": "v", + "kWSQvjVOIAxR": "Oje", + "hgU5bYYkD3r6": "qCXu", + "0eh6uI": "MMlt", + "3USZSl85EKeMzw": "Pc", + "JONmQSuXa": "dX", + "EO1WIF": "G", + "uUmFJGVmacjF": "1N", + "DHpw7": "M", + "NYB2bm": "CPya", + "9FiNvBAHHNku5": "7FlDE", + "tJi3I4WxcY": "q5", + "sNJ8Z1ToBV0O": "tl", + "0iOg72QcPo": "RP", + "pSEqeL": "5KZ", + "GOyYqTgmvQ": "9", + "Qv4qCsj": "nl52E", + "wNQ5tOutYu5s8": "6iGa", + } + rep := NewReplacer(keywords) + text := "AoRxrdKWsGhFpXwVqMLWRL74OukwjBuBh0g7pSrk" + val := rep.Replace(text) + keys := rep.(*replacer).node.find([]rune(val)) + if len(keys) > 0 { + t.Errorf("result: %s, match: %v", val, keys) + } +} diff --git a/core/stringx/trie.go b/core/stringx/trie.go index 115b4c30..c1bf5446 100644 --- a/core/stringx/trie.go +++ b/core/stringx/trie.go @@ -39,6 +39,8 @@ func NewTrie(words []string, opts ...TrieOption) Trie { n.add(word) } + n.build() + return n } @@ -48,7 +50,7 @@ func (n *trieNode) Filter(text string) (sentence string, keywords []string, foun return text, nil, false } - scopes := n.findKeywordScopes(chars) + scopes := n.find(chars) keywords = n.collectKeywords(chars, scopes) for _, match := range scopes { @@ -65,7 +67,7 @@ func (n *trieNode) FindKeywords(text string) []string { return nil } - scopes := n.findKeywordScopes(chars) + scopes := n.find(chars) return n.collectKeywords(chars, scopes) } @@ -85,48 +87,6 @@ func (n *trieNode) collectKeywords(chars []rune, scopes []scope) []string { return keywords } -func (n *trieNode) findKeywordScopes(chars []rune) []scope { - var scopes []scope - size := len(chars) - start := -1 - - for i := 0; i < size; i++ { - child, ok := n.children[chars[i]] - if !ok { - continue - } - - if start < 0 { - start = i - } - if child.end { - scopes = append(scopes, scope{ - start: start, - stop: i + 1, - }) - } - - for j := i + 1; j < size; j++ { - grandchild, ok := child.children[chars[j]] - if !ok { - break - } - - child = grandchild - if child.end { - scopes = append(scopes, scope{ - start: start, - stop: j + 1, - }) - } - } - - start = -1 - } - - return scopes -} - func (n *trieNode) replaceWithAsterisk(chars []rune, start, stop int) { for i := start; i < stop; i++ { chars[i] = n.mask diff --git a/core/stringx/trie_test.go b/core/stringx/trie_test.go index e74a822c..4d961ccf 100644 --- a/core/stringx/trie_test.go +++ b/core/stringx/trie_test.go @@ -6,6 +6,17 @@ import ( "github.com/stretchr/testify/assert" ) +func TestTrieSimple(t *testing.T) { + trie := NewTrie([]string{ + "bc", + "cd", + }) + output, keywords, found := trie.Filter("abcd") + assert.True(t, found) + assert.Equal(t, "a***", output) + assert.ElementsMatch(t, []string{"bc", "cd"}, keywords) +} + func TestTrie(t *testing.T) { tests := []struct { input string @@ -14,11 +25,11 @@ func TestTrie(t *testing.T) { found bool }{ { - input: "日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演", + input: "日本AV演员兼电视、电影演员。无名氏AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演", output: "日本****兼电视、电影演员。*****女优是xx出道, ******们最精彩的表演是******表演", keywords: []string{ "AV演员", - "苍井空", + "无名氏", "AV", "日本AV女优", "AV演员色情", @@ -89,7 +100,7 @@ func TestTrie(t *testing.T) { "一不", "AV", "AV演员", - "苍井空", + "无名氏", "AV演员色情", "日本AV女优", }) @@ -145,20 +156,3 @@ func TestTrieNested(t *testing.T) { assert.True(t, ok) assert.Equal(t, "零########九十", output) } - -func BenchmarkTrie(b *testing.B) { - b.ReportAllocs() - - trie := NewTrie([]string{ - "A", - "AV", - "AV演员", - "苍井空", - "AV演员色情", - "日本AV女优", - }) - - for i := 0; i < b.N; i++ { - trie.Filter("日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演") - } -} diff --git a/core/syncx/singleflight.go b/core/syncx/singleflight.go index bcdb450e..92ae3ada 100644 --- a/core/syncx/singleflight.go +++ b/core/syncx/singleflight.go @@ -3,10 +3,6 @@ package syncx import "sync" type ( - // SharedCalls is an alias of SingleFlight. - // Deprecated: use SingleFlight. - SharedCalls = SingleFlight - // SingleFlight lets the concurrent calls with the same key to share the call result. // For example, A called F, before it's done, B called F. Then B would not execute F, // and shared the result returned by F which called by A. @@ -37,12 +33,6 @@ func NewSingleFlight() SingleFlight { } } -// NewSharedCalls returns a SingleFlight. -// Deprecated: use NewSingleFlight. -func NewSharedCalls() SingleFlight { - return NewSingleFlight() -} - func (g *flightGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) { c, done := g.createCall(key) if done {