1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package strings
6
7import (
8	"io"
9	"sync"
10)
11
12// Replacer replaces a list of strings with replacements.
13// It is safe for concurrent use by multiple goroutines.
14type Replacer struct {
15	once   sync.Once // guards buildOnce method
16	r      replacer
17	oldnew []string
18}
19
// replacer is the contract every concrete replacement algorithm in this
// file satisfies (generic trie, single string, byte, byte-to-string).
type replacer interface {
	// WriteString writes s to w with replacements applied.
	WriteString(w io.Writer, s string) (n int, err error)
	// Replace returns s with replacements applied.
	Replace(s string) string
}
25
26// NewReplacer returns a new [Replacer] from a list of old, new string
27// pairs. Replacements are performed in the order they appear in the
28// target string, without overlapping matches. The old string
29// comparisons are done in argument order.
30//
31// NewReplacer panics if given an odd number of arguments.
32func NewReplacer(oldnew ...string) *Replacer {
33	if len(oldnew)%2 == 1 {
34		panic("strings.NewReplacer: odd argument count")
35	}
36	return &Replacer{oldnew: append([]string(nil), oldnew...)}
37}
38
39func (r *Replacer) buildOnce() {
40	r.r = r.build()
41	r.oldnew = nil
42}
43
44func (b *Replacer) build() replacer {
45	oldnew := b.oldnew
46	if len(oldnew) == 2 && len(oldnew[0]) > 1 {
47		return makeSingleStringReplacer(oldnew[0], oldnew[1])
48	}
49
50	allNewBytes := true
51	for i := 0; i < len(oldnew); i += 2 {
52		if len(oldnew[i]) != 1 {
53			return makeGenericReplacer(oldnew)
54		}
55		if len(oldnew[i+1]) != 1 {
56			allNewBytes = false
57		}
58	}
59
60	if allNewBytes {
61		r := byteReplacer{}
62		for i := range r {
63			r[i] = byte(i)
64		}
65		// The first occurrence of old->new map takes precedence
66		// over the others with the same old string.
67		for i := len(oldnew) - 2; i >= 0; i -= 2 {
68			o := oldnew[i][0]
69			n := oldnew[i+1][0]
70			r[o] = n
71		}
72		return &r
73	}
74
75	r := byteStringReplacer{toReplace: make([]string, 0, len(oldnew)/2)}
76	// The first occurrence of old->new map takes precedence
77	// over the others with the same old string.
78	for i := len(oldnew) - 2; i >= 0; i -= 2 {
79		o := oldnew[i][0]
80		n := oldnew[i+1]
81		// To avoid counting repetitions multiple times.
82		if r.replacements[o] == nil {
83			// We need to use string([]byte{o}) instead of string(o),
84			// to avoid utf8 encoding of o.
85			// E. g. byte(150) produces string of length 2.
86			r.toReplace = append(r.toReplace, string([]byte{o}))
87		}
88		r.replacements[o] = []byte(n)
89
90	}
91	return &r
92}
93
94// Replace returns a copy of s with all replacements performed.
95func (r *Replacer) Replace(s string) string {
96	r.once.Do(r.buildOnce)
97	return r.r.Replace(s)
98}
99
100// WriteString writes s to w with all replacements performed.
101func (r *Replacer) WriteString(w io.Writer, s string) (n int, err error) {
102	r.once.Do(r.buildOnce)
103	return r.r.WriteString(w, s)
104}
105
// trieNode is one node of a lookup trie holding prioritized key/value pairs.
// Both keys and values may be empty. As an example, a trie for the keys
// "ax", "ay", "bcbc", "x" and "xy" can be built from eight nodes:
//
//	n0  -
//	n1  a-
//	n2  .x+
//	n3  .y+
//	n4  b-
//	n5  .cbc+
//	n6  x+
//	n7  .y+
//
// The root is n0, whose children are n1, n4 and n6. n1 has children n2 and
// n3, n4 has the single child n5, and n6 has the single child n7. A trailing
// "-" marks a node that is only a partial key (n0, n1, n4); a trailing "+"
// marks a node that completes a key (n2, n3, n5, n6, n7).
type trieNode struct {
	// value is the replacement associated with this node's key. It is empty
	// whenever the node does not complete a key.
	value string
	// priority ranks this node's key/value pair; a larger number wins, so
	// keys are not simply matched shortest- or longest-first. It is positive
	// for nodes that complete a key and zero for all others, which is what
	// the "+" and "-" markers above denote.
	priority int

	// Children are encoded in one of three ways:
	//  * prefix, next and table all zero: the node is a leaf.
	//  * prefix and next set: exactly one child, reached via next.
	//  * table set: table lists every child.
	//
	// A single child is normally stored via prefix/next, except at the root,
	// which always uses a table so that lookups start fast.

	// prefix holds the key bytes between this node and next. In the example,
	// n4 has prefix "cbc" and next n5; n5 is a leaf, so its prefix, next and
	// table are all zero.
	prefix string
	next   *trieNode

	// table maps the next key byte, after remapping it through
	// genericReplacer.mapping into a dense index, to the matching child. In
	// the example only 'a', 'b', 'c', 'x' and 'y' occur in keys; they remap
	// to 0, 1, 2, 3 and 4, every other byte remaps to 5, and
	// genericReplacer.tableSize is 5. n0's table is then
	// []*trieNode{ 0:n1, 1:n4, 3:n6 }, indexed by the remapped 'a', 'b' and 'x'.
	table []*trieNode
}
157
// add inserts the key/value pair into the subtrie rooted at t, splitting
// prefix nodes or creating lookup tables as needed. key is relative to t
// (the bytes already consumed to reach t are not included). priority ranks
// the pair against other keys; r supplies the byte remapping and table size
// used whenever a new table is created.
//
// If the node for key already has a nonzero priority, its existing value and
// priority are kept, so the pair that is added first takes precedence.
func (t *trieNode) add(key, val string, priority int, r *genericReplacer) {
	if key == "" {
		// t itself is the node for the full key; only the first insertion sticks.
		if t.priority == 0 {
			t.value = val
			t.priority = priority
		}
		return
	}

	if t.prefix != "" {
		// Need to split the prefix among multiple nodes.
		var n int // length of the longest common prefix
		for ; n < len(t.prefix) && n < len(key); n++ {
			if t.prefix[n] != key[n] {
				break
			}
		}
		if n == len(t.prefix) {
			// The whole prefix matches the start of key: continue below next.
			t.next.add(key[n:], val, priority, r)
		} else if n == 0 {
			// First byte differs, start a new lookup table here. Looking up
			// what is currently t.prefix[0] will lead to prefixNode, and
			// looking up key[0] will lead to keyNode.
			var prefixNode *trieNode
			if len(t.prefix) == 1 {
				prefixNode = t.next
			} else {
				prefixNode = &trieNode{
					prefix: t.prefix[1:],
					next:   t.next,
				}
			}
			keyNode := new(trieNode)
			t.table = make([]*trieNode, r.tableSize)
			// t.prefix[0] is read here, before t.prefix is cleared below.
			t.table[r.mapping[t.prefix[0]]] = prefixNode
			t.table[r.mapping[key[0]]] = keyNode
			t.prefix = ""
			t.next = nil
			keyNode.add(key[1:], val, priority, r)
		} else {
			// Insert new node after the common section of the prefix.
			// This also covers key being a proper prefix of t.prefix: key[n:]
			// is then "", so the new node becomes the key's terminal node.
			next := &trieNode{
				prefix: t.prefix[n:],
				next:   t.next,
			}
			t.prefix = t.prefix[:n]
			t.next = next
			next.add(key[n:], val, priority, r)
		}
	} else if t.table != nil {
		// Insert into existing table.
		m := r.mapping[key[0]]
		if t.table[m] == nil {
			t.table[m] = new(trieNode)
		}
		t.table[m].add(key[1:], val, priority, r)
	} else {
		// t has no children yet: store the whole remaining key as a prefix
		// leading to a fresh terminal node.
		t.prefix = key
		t.next = new(trieNode)
		t.next.add("", val, priority, r)
	}
}
220
221func (r *genericReplacer) lookup(s string, ignoreRoot bool) (val string, keylen int, found bool) {
222	// Iterate down the trie to the end, and grab the value and keylen with
223	// the highest priority.
224	bestPriority := 0
225	node := &r.root
226	n := 0
227	for node != nil {
228		if node.priority > bestPriority && !(ignoreRoot && node == &r.root) {
229			bestPriority = node.priority
230			val = node.value
231			keylen = n
232			found = true
233		}
234
235		if s == "" {
236			break
237		}
238		if node.table != nil {
239			index := r.mapping[s[0]]
240			if int(index) == r.tableSize {
241				break
242			}
243			node = node.table[index]
244			s = s[1:]
245			n++
246		} else if node.prefix != "" && HasPrefix(s, node.prefix) {
247			n += len(node.prefix)
248			s = s[len(node.prefix):]
249			node = node.next
250		} else {
251			break
252		}
253	}
254	return
255}
256
257// genericReplacer is the fully generic algorithm.
258// It's used as a fallback when nothing faster can be used.
259type genericReplacer struct {
260	root trieNode
261	// tableSize is the size of a trie node's lookup table. It is the number
262	// of unique key bytes.
263	tableSize int
264	// mapping maps from key bytes to a dense index for trieNode.table.
265	mapping [256]byte
266}
267
268func makeGenericReplacer(oldnew []string) *genericReplacer {
269	r := new(genericReplacer)
270	// Find each byte used, then assign them each an index.
271	for i := 0; i < len(oldnew); i += 2 {
272		key := oldnew[i]
273		for j := 0; j < len(key); j++ {
274			r.mapping[key[j]] = 1
275		}
276	}
277
278	for _, b := range r.mapping {
279		r.tableSize += int(b)
280	}
281
282	var index byte
283	for i, b := range r.mapping {
284		if b == 0 {
285			r.mapping[i] = byte(r.tableSize)
286		} else {
287			r.mapping[i] = index
288			index++
289		}
290	}
291	// Ensure root node uses a lookup table (for performance).
292	r.root.table = make([]*trieNode, r.tableSize)
293
294	for i := 0; i < len(oldnew); i += 2 {
295		r.root.add(oldnew[i], oldnew[i+1], len(oldnew)-i, r)
296	}
297	return r
298}
299
// appendSliceWriter is an in-memory writer that appends everything written
// to it onto the underlying byte slice.
type appendSliceWriter []byte

// Write appends p to the buffer to satisfy [io.Writer]; it never fails.
func (w *appendSliceWriter) Write(p []byte) (int, error) {
	grown := append(*w, p...)
	*w = grown
	return len(p), nil
}

// WriteString appends s directly, avoiding string->[]byte->string allocations.
func (w *appendSliceWriter) WriteString(s string) (int, error) {
	grown := append(*w, s...)
	*w = grown
	return len(s), nil
}
313
// stringWriter adapts a plain io.Writer to io.StringWriter by converting
// each string to a byte slice before writing it.
type stringWriter struct {
	w io.Writer
}

// WriteString writes s to the wrapped writer.
func (w stringWriter) WriteString(s string) (int, error) {
	b := []byte(s)
	return w.w.Write(b)
}

// getStringWriter returns w itself when it already implements
// io.StringWriter; otherwise it wraps w in a stringWriter.
func getStringWriter(w io.Writer) io.StringWriter {
	if sw, ok := w.(io.StringWriter); ok {
		return sw
	}
	return stringWriter{w}
}
329
330func (r *genericReplacer) Replace(s string) string {
331	buf := make(appendSliceWriter, 0, len(s))
332	r.WriteString(&buf, s)
333	return string(buf)
334}
335
// WriteString writes s to w with every match replaced and returns the number
// of bytes written along with the first write error, if any.
//
// At each position i it asks lookup for the highest-priority key that is a
// prefix of s[i:]. On a match, the pending unmatched run s[last:i] and the
// key's value are written, and i skips past the key. Otherwise i advances by
// one byte and the byte stays in the pending run. The loop runs through
// i == len(s) so that an empty key can also match at the very end of s.
func (r *genericReplacer) WriteString(w io.Writer, s string) (n int, err error) {
	sw := getStringWriter(w)
	var last, wn int // last: start of the not-yet-written run of s; wn: bytes from the latest write
	var prevMatchEmpty bool
	for i := 0; i <= len(s); {
		// Fast path: s[i] is not a prefix of any pattern.
		// Only valid when the root has no value, i.e. no empty old string
		// was registered; otherwise the empty key could match anywhere.
		if i != len(s) && r.root.priority == 0 {
			index := int(r.mapping[s[i]])
			if index == r.tableSize || r.root.table[index] == nil {
				i++
				continue
			}
		}

		// Ignore the empty match iff the previous loop found the empty match.
		// An empty match has keylen 0 and leaves i unchanged; re-visiting i
		// with the root ignored can only find a non-empty key or fall through
		// to i++, which prevents an infinite loop.
		val, keylen, match := r.lookup(s[i:], prevMatchEmpty)
		prevMatchEmpty = match && keylen == 0
		if match {
			wn, err = sw.WriteString(s[last:i])
			n += wn
			if err != nil {
				return
			}
			wn, err = sw.WriteString(val)
			n += wn
			if err != nil {
				return
			}
			i += keylen
			last = i
			continue
		}
		i++
	}
	// Flush whatever trails the last match.
	if last != len(s) {
		wn, err = sw.WriteString(s[last:])
		n += wn
	}
	return
}
376
377// singleStringReplacer is the implementation that's used when there is only
378// one string to replace (and that string has more than one byte).
379type singleStringReplacer struct {
380	finder *stringFinder
381	// value is the new string that replaces that pattern when it's found.
382	value string
383}
384
385func makeSingleStringReplacer(pattern string, value string) *singleStringReplacer {
386	return &singleStringReplacer{finder: makeStringFinder(pattern), value: value}
387}
388
389func (r *singleStringReplacer) Replace(s string) string {
390	var buf Builder
391	i, matched := 0, false
392	for {
393		match := r.finder.next(s[i:])
394		if match == -1 {
395			break
396		}
397		matched = true
398		buf.Grow(match + len(r.value))
399		buf.WriteString(s[i : i+match])
400		buf.WriteString(r.value)
401		i += match + len(r.finder.pattern)
402	}
403	if !matched {
404		return s
405	}
406	buf.WriteString(s[i:])
407	return buf.String()
408}
409
410func (r *singleStringReplacer) WriteString(w io.Writer, s string) (n int, err error) {
411	sw := getStringWriter(w)
412	var i, wn int
413	for {
414		match := r.finder.next(s[i:])
415		if match == -1 {
416			break
417		}
418		wn, err = sw.WriteString(s[i : i+match])
419		n += wn
420		if err != nil {
421			return
422		}
423		wn, err = sw.WriteString(r.value)
424		n += wn
425		if err != nil {
426			return
427		}
428		i += match + len(r.finder.pattern)
429	}
430	wn, err = sw.WriteString(s[i:])
431	n += wn
432	return
433}
434
// byteReplacer is the implementation that's used when all the "old" and
// "new" values are single bytes. They need not be ASCII: any byte value
// 0x00-0xFF qualifies.
// The array holds the replacement byte indexed by old byte; an entry equal
// to its own index means that byte is left unchanged.
type byteReplacer [256]byte

// Replace returns s with every byte b replaced by r[b]. If no byte changes,
// s itself is returned and no buffer is allocated.
func (r *byteReplacer) Replace(s string) string {
	var buf []byte // allocated on the first byte that actually changes
	for i := 0; i < len(s); i++ {
		b := s[i]
		if r[b] != b {
			if buf == nil {
				buf = []byte(s)
			}
			buf[i] = r[b]
		}
	}
	if buf == nil {
		return s
	}
	return string(buf)
}
456
457func (r *byteReplacer) WriteString(w io.Writer, s string) (n int, err error) {
458	sw := getStringWriter(w)
459	last := 0
460	for i := 0; i < len(s); i++ {
461		b := s[i]
462		if r[b] == b {
463			continue
464		}
465		if last != i {
466			wn, err := sw.WriteString(s[last:i])
467			n += wn
468			if err != nil {
469				return n, err
470			}
471		}
472		last = i + 1
473		nw, err := w.Write(r[b : int(b)+1])
474		n += nw
475		if err != nil {
476			return n, err
477		}
478	}
479	if last != len(s) {
480		nw, err := sw.WriteString(s[last:])
481		n += nw
482		if err != nil {
483			return n, err
484		}
485	}
486	return n, nil
487}
488
// byteStringReplacer is the implementation that's used when all the "old"
// values are single bytes (of any value, not only ASCII) but the "new"
// values vary in size.
type byteStringReplacer struct {
	// replacements contains replacement byte slices indexed by old byte.
	// A nil []byte means that the old byte should not be replaced; a
	// non-nil empty slice means the old byte is deleted.
	replacements [256][]byte
	// toReplace lists each distinct old byte once, stored as a one-byte
	// string because that is the form Count takes. Depending on its length
	// relative to the target string, Replace either calls Count per entry
	// or scans the target byte by byte (see countCutOff).
	toReplace []string
}
500
// countCutOff controls the ratio of a string length to a number of replacements
// at which (*byteStringReplacer).Replace switches algorithms.
// For strings whose ratio of length to replacements is at least that value,
// we call Count for each replacement from toReplace.
// For strings with a lower ratio we use a simple loop, because of Count overhead.
// countCutOff is an empirically determined overhead multiplier.
// TODO(tocarip) revisit once we have register-based abi/mid-stack inlining.
const countCutOff = 8
509
510func (r *byteStringReplacer) Replace(s string) string {
511	newSize := len(s)
512	anyChanges := false
513	// Is it faster to use Count?
514	if len(r.toReplace)*countCutOff <= len(s) {
515		for _, x := range r.toReplace {
516			if c := Count(s, x); c != 0 {
517				// The -1 is because we are replacing 1 byte with len(replacements[b]) bytes.
518				newSize += c * (len(r.replacements[x[0]]) - 1)
519				anyChanges = true
520			}
521
522		}
523	} else {
524		for i := 0; i < len(s); i++ {
525			b := s[i]
526			if r.replacements[b] != nil {
527				// See above for explanation of -1
528				newSize += len(r.replacements[b]) - 1
529				anyChanges = true
530			}
531		}
532	}
533	if !anyChanges {
534		return s
535	}
536	buf := make([]byte, newSize)
537	j := 0
538	for i := 0; i < len(s); i++ {
539		b := s[i]
540		if r.replacements[b] != nil {
541			j += copy(buf[j:], r.replacements[b])
542		} else {
543			buf[j] = b
544			j++
545		}
546	}
547	return string(buf)
548}
549
550func (r *byteStringReplacer) WriteString(w io.Writer, s string) (n int, err error) {
551	sw := getStringWriter(w)
552	last := 0
553	for i := 0; i < len(s); i++ {
554		b := s[i]
555		if r.replacements[b] == nil {
556			continue
557		}
558		if last != i {
559			nw, err := sw.WriteString(s[last:i])
560			n += nw
561			if err != nil {
562				return n, err
563			}
564		}
565		last = i + 1
566		nw, err := w.Write(r.replacements[b])
567		n += nw
568		if err != nil {
569			return n, err
570		}
571	}
572	if last != len(s) {
573		var nw int
574		nw, err = sw.WriteString(s[last:])
575		n += nw
576	}
577	return
578}
579