// Copyright 2020 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package classifier

import (
	"fmt"
	"hash/crc32"
	"math"
	"sort"

	"github.com/davecgh/go-spew/spew"
)

// searchSet is a set of q-grams that have hashes associated with them,
// making it fast to search for potential matches.
type searchSet struct {
	// Tokens is a tokenized list of the original input string.
	Tokens []indexedToken
	// Hashes is a map of checksums to a range of tokens.
	Hashes hash
	// Checksums is a list of checksums ordered from longest range to
	// shortest.
	Checksums []uint32
	// ChecksumRanges are the token ranges for the above checksums.
	ChecksumRanges tokenRanges
	origin         string // A debugging identifier to label what this searchset is associated with

	nodes []*node
	q     int // The length of q-grams in this searchset.
}

// node consists of a range of tokens along with the checksum for those tokens.
type node struct {
	checksum uint32
	tokens   *tokenRange
}

func (n *node) String() string {
	return fmt.Sprintf("[%d:%d]", n.tokens.Start, n.tokens.End)
}

// newSearchSet creates a new searchSet object. A searchset generates all
// possible q-grams of tokens. These q-grams of tokens can be correlated to
// determine where a section of text from one source may appear in another
// source.
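// For example, with q = 3 the token sequence "lorem ipsum dolor sit" yields
// the 3-grams "lorem ipsum dolor" and "ipsum dolor sit".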
func newSearchSet(s *indexedDocument, q int) *searchSet {
	// Start generating hash values for all q-grams within the text.
	h := make(hash)
	if len(s.Tokens) < q {
		// q can't be larger than the number of tokens in the document, so
		// clamp it to the document length.
		q = len(s.Tokens)
	}
	checksums, tokenRanges := generateHashes(h, q, s.Tokens, s.dict)
	sset := &searchSet{
		Tokens:         s.Tokens,
		Hashes:         h,
		Checksums:      checksums,
		ChecksumRanges: tokenRanges,
		q:              q,
	}
	sset.generateNodeList()
	return sset
}

// tokenRange indicates the range of tokens that map to a particular checksum.
type tokenRange struct {
	Start int
	End   int
}

func (t *tokenRange) String() string {
	return fmt.Sprintf("[%v, %v)", t.Start, t.End)
}

// tokenRanges is a slice of tokenRange pointers.
type tokenRanges []*tokenRange

// generateHashes computes a hash using CRC-32 for each q-gram encountered in
// the provided tokens.
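// Words within a q-gram are written to the hash separated by single spaces, so
// a given word sequence always produces the same checksum. A checksum can map
// to more than one token range when a q-gram repeats in the text or, rarely,
// when two different q-grams collide.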
func generateHashes(h hash, q int, toks []indexedToken, dict *dictionary) ([]uint32, tokenRanges) {
	if q == 0 {
		return nil, nil
	}
	var css []uint32
	var tr tokenRanges
	crc := crc32.NewIEEE()
	for offset := 0; offset+q <= len(toks); offset++ {
		crc.Reset()
		for i := 0; i < q; i++ {
			crc.Write([]byte(dict.getWord(toks[offset+i].ID)))
			crc.Write([]byte{' '})
		}
		cs := crc.Sum32()
		css = append(css, cs)
		tr = append(tr, &tokenRange{offset, offset + q})
		h.add(cs, offset, offset+q)
	}

	return css, tr
}

// generateNodeList creates a node list out of the search set.
func (s *searchSet) generateNodeList() {
	if len(s.Tokens) == 0 {
		return
	}

	for i := 0; i < len(s.Checksums); i++ {
		s.nodes = append(s.nodes, &node{
			checksum: s.Checksums[i],
			tokens:   s.ChecksumRanges[i],
		})
	}
}

// matchRange is the range within the source text that is a match to the range
// in the target text.
type matchRange struct {
	// Offsets into the source tokens.
	SrcStart, SrcEnd int
	// Offsets into the target tokens.
	TargetStart, TargetEnd int
	// TokensClaimed tracks the number of positively matched tokens in this
	// range. For an initially created matchRange, this equals the extent of
	// the range. However, as matchRanges get merged together and error is
	// potentially introduced, the claimed count may no longer equal the
	// extent of the range.
	TokensClaimed int
}

// in returns true if the receiver's target range is entirely contained within
// the other match range's target range.
func (m *matchRange) in(other *matchRange) bool {
	return m.TargetStart >= other.TargetStart && m.TargetEnd <= other.TargetEnd
}

func (m *matchRange) String() string {
	return fmt.Sprintf("S: [%v, %v)-> T: [%v, %v) => %v [%v]", m.SrcStart, m.SrcEnd, m.TargetStart, m.TargetEnd, m.TargetStart-m.SrcStart, m.TokensClaimed)
}

// matchRanges is a list of "matchRange"s. The ranges are monotonically
// increasing in value and indicate a single potential occurrence of the source
// text in the target text. They are sorted to support greedy matching with the
// longest runs of q-grams appearing first in the sort.
type matchRanges []*matchRange

func (m matchRanges) Len() int      { return len(m) }
func (m matchRanges) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
func (m matchRanges) Less(i, j int) bool {
	if m[i].TokensClaimed != m[j].TokensClaimed {
		return m[i].TokensClaimed > m[j].TokensClaimed
	}

	if m[i].TargetStart != m[j].TargetStart {
		return m[i].TargetStart < m[j].TargetStart
	}
	return m[i].SrcStart < m[j].SrcStart
}

// findPotentialMatches returns the ranges in the target (unknown) text that
// are best potential matches to the source (known) text.
func (c *Classifier) findPotentialMatches(src, target *searchSet, confidence float64) matchRanges {
	matchedRanges := c.getMatchedRanges(src, target, confidence, src.q)
	if c.tc.traceSearchset(src.origin) {
		c.tc.trace("matchedRanges = %s", spew.Sdump(matchedRanges))
	}
	if len(matchedRanges) == 0 {
		return nil
	}

	// After computing all potential matches, we only output ranges that contain
	// enough tokens to clear the confidence threshold. As noted, this number can
	// be too high, yielding false positives, but cannot yield false negatives.
	threshold := int(confidence * float64(len(src.Tokens)))

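	// matchedRanges is sorted with the most-claimed ranges first, so everything
	// from the first below-threshold entry onward can be dropped.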
	for i, m := range matchedRanges {
		if m.TokensClaimed < threshold {
			matchedRanges = matchedRanges[:i]
			break
		}
	}

	if c.tc.traceSearchset(src.origin) {
		c.tc.trace("finalized matchedRanges for %s: %d = %s", src.origin, len(src.Tokens), spew.Sdump(matchedRanges))
	}
	return matchedRanges
}

// fuseRanges analyzes the source matches, attempting to combine hits without
// errors into larger hits with tolerable amounts of error to produce matches
// that contain enough tokens to be considered for exact matching against a
// target document. This routine intentionally does not accurately track error
// contributions from merging runs, trading false positives (but not false
// negatives) for faster performance.
func (c *Classifier) fuseRanges(origin string, matched matchRanges, confidence float64, size int, runs []matchRange, targetSize int) matchRanges {
	var claimed matchRanges
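	// errorMargin is the number of source tokens that may be missing or
	// mismatched while still clearing the confidence threshold.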
	errorMargin := int(math.Round(float64(size) * (1.0 - confidence)))

	filter := make([]bool, targetSize)
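	// Mark the target positions covered by the detected runs. A candidate match
	// is kept below only if its source-to-target alignment offset falls on a
	// marked position; everything else is treated as a spurious hit.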
	for _, m := range runs {
		for i := m.SrcStart; i < m.SrcEnd; i++ {
			// Only consider these offsets if they fit into the target document, since
			// the target may be smaller than the source being analyzed
			if i < targetSize {
				filter[i] = true
			}
		}
	}

	filterDrops := 0
	filterPasses := 0

	// For each hit detected, compare it against all other previous hits to see
	// if it can be part of a match or represents a group that is eligible for
	// matching and having other hits contribute to it.
	for _, m := range matched {
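		// off is how far this match's target range is shifted relative to its
		// source range; hits that belong to the same occurrence of the source
		// text share approximately the same offset.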
		off := m.TargetStart - m.SrcStart

		// If the offset is negative but within error margins, we associate it
		// with the first index to see if it could contribute to a run. If the
		// absolute offset is larger than the error margin, it can't possibly
		// contribute and is dropped. This puts more potential error into the
		// zero index, but that just slightly increases the rate of false
		// positives. In practice, this would only be an issue if there are
		// major substrings of a source in a target that aren't part of a real
		// hit. We see many small references (the name of a license) but not
		// large chunks of the license.
		if off < 0 {
			if -off <= errorMargin {
				off = 0
			} else {
				continue
			}
		}

		// If the filter is false, there was not sufficient token density in that
		// part of the target document for a viable match, so this match is a
		// spurious hit and can be discarded.
		if !filter[off] {
			filterDrops++
			continue
		}

		filterPasses++
		unclaimed := true

		for _, c := range claimed {
			moff := m.TargetStart - m.SrcStart
			coff := c.TargetStart - c.SrcStart
			sampleError := int(math.Round(math.Abs(float64(moff - coff))))
			withinError := sampleError < errorMargin

			// The contribution needs to add more value than the accumulated
			// error. This protects against spurious matches of a reference to a
			// license incorrectly overextending the match range.
			if withinError && m.TokensClaimed > int(sampleError) {
				if m.in(c) {
					// This can cause too many tokens to be claimed, but that's
					// fine since we want to avoid undercounting and missing
					// content.
					c.TokensClaimed += m.TokensClaimed
					unclaimed = false
				} else {
					// See if the claims can be merged. If the error tolerances
					// allow for it, we merge the claims and the ranges. This
					// doesn't accumulate error accurately, so repeated merges can
					// introduce too much error to actually produce a match, but
					// the approach never yields false negatives. We only merge if
					// doing so extends the range of tokens being covered; if this
					// match is already handled by an existing (stronger by
					// definition) claim, we don't merge it but treat it as a new
					// claim. This allows a highly fragmented text to be matched
					// by a long series of low-score matches.
					if m.TargetStart < c.TargetStart && m.SrcStart < c.SrcStart {
						c.TargetStart = m.TargetStart
						c.SrcStart = m.SrcStart
						c.TokensClaimed += m.TokensClaimed
						unclaimed = false
					} else if m.TargetEnd > c.TargetEnd && m.SrcEnd > c.SrcEnd {
						c.TargetEnd = m.TargetEnd
						c.SrcEnd = m.SrcEnd
						c.TokensClaimed += m.TokensClaimed
						unclaimed = false
					}
					// This claim does not extend any existing block, and it may be claimed in its own
					// right.
				}
			}
			if !unclaimed {
				break
			}
		}
		// Only create a claim if this claim is likely relevant. If we had some higher quality hits,
		// it's likely this is spurious noise. If we haven't had any significantly better hits, we'll keep
		// this around.
		if unclaimed && m.TokensClaimed*10 > matched[0].TokensClaimed {
			claimed = append(claimed, m)
		}
	}
	sort.Sort(claimed)
	if c.tc.traceSearchset(origin) {
		c.tc.trace("filterPasses = %+v", filterPasses)
		c.tc.trace("filterDrops = %+v", filterDrops)
		c.tc.trace("claimed = %s", spew.Sdump(claimed))
	}
	return claimed
}

// getMatchedRanges finds the ranges in the target text that match the source
// text. The ranges returned are ordered from the entries with the most matched
// tokens to the least.
func (c *Classifier) getMatchedRanges(src, target *searchSet, confidence float64, q int) matchRanges {
	shouldTrace := c.tc.traceSearchset(src.origin)

	if shouldTrace {
		c.tc.trace("src.origin = %+v", src.origin)
	}
	// Assemble a list of all the matched q-grams without any consideration to
	// error tolerances.
	matched := targetMatchedRanges(src, target)
	if shouldTrace {
		c.tc.trace("matched = %s", spew.Sdump(matched))
	}
	if len(matched) == 0 {
		return nil
	}

	// Perform neighborhood matching to figure out which clusters of q-grams have
	// sufficient likelihood to be a potential match to the source. For an error
	// confidence threshold of X, we require that a sequence of N target tokens
	// must contain N*X (X <= 1.0) ordered source tokens in order to be a viable
	// match.
	//
	// It's much easier to compute potential ranges in the target if we disregard
	// proper ordering of source tokens initially, and see which source q-grams
	// appear in sufficient quantities to be a potential match. We can then
	// disregard any q-gram that falls outside of that range. This helps
	// significantly since processing token matches is an N^2 (or worse)
	// operation, so reducing N is a big win.

	runs := c.detectRuns(src.origin, matched, len(target.Tokens), len(src.Tokens), confidence, q)

	if shouldTrace {
		c.tc.trace("runs = %d: %s", len(runs), spew.Sdump(runs))
	}

	// If there are no target runs of source tokens, we're done.
	if len(runs) == 0 {
		return nil
	}

	// Using the runs as a filter to ignore irrelevant matches, fuse the source
	// match ranges into larger matches (with possible errors) to see if we can
	// produce large enough runs that pass the confidence threshold.

	fr := c.fuseRanges(src.origin, matched, confidence, len(src.Tokens), runs, len(target.Tokens))
	if shouldTrace {
		c.tc.trace("fr = %s", spew.Sdump(fr))
	}
	return fr
}

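// detectRuns slides a window the length of the source document across the
// matched target positions and records the windows that contain enough matched
// tokens to clear the confidence threshold, returning those regions as
// candidate ranges for further analysis.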
func (c *Classifier) detectRuns(origin string, matched matchRanges, targetLength, subsetLength int, threshold float64, q int) []matchRange {
	shouldTrace := c.tc.traceSearchset(origin)
	hits := make([]bool, targetLength)
	for _, m := range matched {
		for idx := m.TargetStart; idx < m.TargetEnd; idx++ {
			hits[idx] = true
		}
	}

	if len(hits) == 0 {
		return nil
	}
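	// out records the start index of every window position whose matched-token
	// count reaches the target.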
	var out []int

	total := 0
	target := int(float64(subsetLength) * threshold)
	if shouldTrace {
		c.tc.trace("target = %+v", target)
		c.tc.trace("targetLength = %+v", targetLength)
		c.tc.trace("subsetLength = %+v", subsetLength)
	}

	// If we don't have at least one subset (i.e. the target is shorter than
	// the source), just analyze what we have.
	if len(hits) < subsetLength {
		if shouldTrace {
			c.tc.trace("trimmed search length from %d to %d", subsetLength, len(hits))
		}
		subsetLength = len(hits)
	}
	// Initialize our sliding window value.
	for i := 0; i < subsetLength; i++ {
		if hits[i] {
			total++
		}
	}

	if total >= target {
		out = append(out, 0)
	}

	// Now slide the window forward, adjusting the total by removing the hit
	// that drops out of the window and adding the one that enters it.
	for i := 1; i < len(hits); i++ {
		if hits[i-1] {
			total--
		}
		end := i + subsetLength - 1
		if end < len(hits) && hits[end] {
			total++
		}
		if total >= target {
			out = append(out, i)
		}
	}
	if len(out) == 0 {
		return nil
	}

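	// Coalesce adjacent qualifying window positions into contiguous ranges; an
	// isolated position contributes a range of q tokens.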
	final := []matchRange{
		{
			SrcStart: out[0],
			SrcEnd:   out[0] + q,
		},
	}
	for i := 1; i < len(out); i++ {
		if out[i] != 1+out[i-1] {
			final = append(final, matchRange{
				SrcStart: out[i],
				SrcEnd:   out[i] + q,
			})
		} else {
			final[len(final)-1].SrcEnd = out[i] + q
		}
	}

	return final
}

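// targetMatchedRanges returns the raw q-gram matches between the source and
// target search sets, chaining consecutive q-gram hits that share the same
// source-to-target offset into longer ranges.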
func targetMatchedRanges(src, target *searchSet) matchRanges {
	offsetMappings := make(map[int][]*matchRange)

	var matched matchRanges
	for _, tgtNode := range target.nodes {
		sr, ok := src.Hashes[tgtNode.checksum]
		if !ok {
			continue
		}

		tv := tgtNode.tokens
		for _, sv := range sr {
			offset := tv.Start - sv.Start
			if om, ok := offsetMappings[offset]; ok {
				// See if this extends the most recent existing mapping
				lastIdx := len(om) - 1
				if om[lastIdx].TargetEnd == tv.End-1 {
					// This new value extends. Update the value in place
					om[lastIdx].SrcEnd = sv.End
					om[lastIdx].TargetEnd = tv.End
					continue
				}
			}
			offsetMappings[offset] = append(offsetMappings[offset], &matchRange{
				SrcStart:    sv.Start,
				SrcEnd:      sv.End,
				TargetStart: tv.Start,
				TargetEnd:   tv.End,
			})
		}
	}

	// Compute the number of tokens claimed in each run and flatten into a single slice.
	for _, mr := range offsetMappings {
		for _, m := range mr {
			m.TokensClaimed = m.TargetEnd - m.TargetStart
		}
		matched = append(matched, mr...)
	}
	sort.Sort(matched)
	return matched
}

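// hash maps the CRC-32 checksum of a q-gram to the token ranges that produced
// it. A checksum can map to multiple ranges when text repeats or checksums
// collide.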
type hash map[uint32]tokenRanges

func (h hash) add(checksum uint32, start, end int) {
	h[checksum] = append(h[checksum], &tokenRange{Start: start, End: end})
}
498