1// Copyright 2021 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package encodecounter
6
7import (
8	"bufio"
9	"encoding/binary"
10	"fmt"
11	"internal/coverage"
12	"internal/coverage/slicewriter"
13	"internal/coverage/stringtab"
14	"internal/coverage/uleb128"
15	"io"
16	"os"
17	"slices"
18)
19
// This package contains APIs and helpers for encoding initial portions
// of the counter data files emitted at runtime when coverage instrumentation
// is enabled.  Counter data files may contain multiple segments; the file
// header and first segment are written via the "Write" method below, and
// additional segments can then be added using "AppendSegment".
25
// CoverageDataWriter emits coverage counter data files: a file header
// (via Write) followed by one or more segments, each of which is
// followed by a footer.
type CoverageDataWriter struct {
	// String table for the segment being written; recreated at the
	// start of each segment and set to nil once the segment completes.
	stab    *stringtab.Writer
	// Buffered wrapper around the destination writer.
	w       *bufio.Writer
	// Header for the segment being written; its size/count fields are
	// filled in while the segment is produced and patched in at the end.
	csh     coverage.CounterSegmentHeader
	// Scratch buffer reused for ULEB128 encoding.
	tmp     []byte
	// Encoding used for counter values (CtrRaw or CtrULeb128).
	cflavor coverage.CounterFlavor
	// Number of segments written so far by this writer; recorded in
	// each footer.
	segs    uint32
	// When true, debugging output is written to stderr.
	debug   bool
}
35
36func NewCoverageDataWriter(w io.Writer, flav coverage.CounterFlavor) *CoverageDataWriter {
37	r := &CoverageDataWriter{
38		stab: &stringtab.Writer{},
39		w:    bufio.NewWriter(w),
40
41		tmp:     make([]byte, 64),
42		cflavor: flav,
43	}
44	r.stab.InitWriter()
45	r.stab.Lookup("")
46	return r
47}
48
// CounterVisitor describes a helper object used during counter file
// writing; when writing counter data files, clients pass a
// CounterVisitor to the write/emit routines (Write, AppendSegment),
// which call VisitFuncs. VisitFuncs is then expected to invoke the
// callback "f" with the data for each function to emit to the file.
// Any error returned by VisitFuncs is passed back to the caller of
// the write routine.
type CounterVisitor interface {
	VisitFuncs(f CounterVisitorFn) error
}
57
// CounterVisitorFn describes a callback function invoked when writing
// coverage counter data. The writer supplies the function; each call
// emits one function entry consisting of len(counters), pkid, funcid,
// and then every value in counters, all encoded in the writer's
// counter flavor. pkid and funcid are written verbatim (presumably
// package and function indices — confirm against the meta-data format).
// It returns a non-nil error if writing the entry fails.
type CounterVisitorFn func(pkid uint32, funcid uint32, counters []uint32) error
61
62// Write writes the contents of the count-data file to the writer
63// previously supplied to NewCoverageDataWriter. Returns an error
64// if something went wrong somewhere with the write.
65func (cfw *CoverageDataWriter) Write(metaFileHash [16]byte, args map[string]string, visitor CounterVisitor) error {
66	if err := cfw.writeHeader(metaFileHash); err != nil {
67		return err
68	}
69	return cfw.AppendSegment(args, visitor)
70}
71
72func padToFourByteBoundary(ws *slicewriter.WriteSeeker) error {
73	sz := len(ws.BytesWritten())
74	zeros := []byte{0, 0, 0, 0}
75	rem := uint32(sz) % 4
76	if rem != 0 {
77		pad := zeros[:(4 - rem)]
78		if nw, err := ws.Write(pad); err != nil {
79			return err
80		} else if nw != len(pad) {
81			return fmt.Errorf("error: short write")
82		}
83	}
84	return nil
85}
86
87func (cfw *CoverageDataWriter) patchSegmentHeader(ws *slicewriter.WriteSeeker) error {
88	// record position
89	off, err := ws.Seek(0, io.SeekCurrent)
90	if err != nil {
91		return fmt.Errorf("error seeking in patchSegmentHeader: %v", err)
92	}
93	// seek back to start so that we can update the segment header
94	if _, err := ws.Seek(0, io.SeekStart); err != nil {
95		return fmt.Errorf("error seeking in patchSegmentHeader: %v", err)
96	}
97	if cfw.debug {
98		fmt.Fprintf(os.Stderr, "=-= writing counter segment header: %+v", cfw.csh)
99	}
100	if err := binary.Write(ws, binary.LittleEndian, cfw.csh); err != nil {
101		return err
102	}
103	// ... and finally return to the original offset.
104	if _, err := ws.Seek(off, io.SeekStart); err != nil {
105		return fmt.Errorf("error seeking in patchSegmentHeader: %v", err)
106	}
107	return nil
108}
109
110func (cfw *CoverageDataWriter) writeSegmentPreamble(args map[string]string, ws *slicewriter.WriteSeeker) error {
111	if err := binary.Write(ws, binary.LittleEndian, cfw.csh); err != nil {
112		return err
113	}
114	hdrsz := uint32(len(ws.BytesWritten()))
115
116	// Write string table and args to a byte slice (since we need
117	// to capture offsets at various points), then emit the slice
118	// once we are done.
119	cfw.stab.Freeze()
120	if err := cfw.stab.Write(ws); err != nil {
121		return err
122	}
123	cfw.csh.StrTabLen = uint32(len(ws.BytesWritten())) - hdrsz
124
125	akeys := make([]string, 0, len(args))
126	for k := range args {
127		akeys = append(akeys, k)
128	}
129	slices.Sort(akeys)
130
131	wrULEB128 := func(v uint) error {
132		cfw.tmp = cfw.tmp[:0]
133		cfw.tmp = uleb128.AppendUleb128(cfw.tmp, v)
134		if _, err := ws.Write(cfw.tmp); err != nil {
135			return err
136		}
137		return nil
138	}
139
140	// Count of arg pairs.
141	if err := wrULEB128(uint(len(args))); err != nil {
142		return err
143	}
144	// Arg pairs themselves.
145	for _, k := range akeys {
146		ki := uint(cfw.stab.Lookup(k))
147		if err := wrULEB128(ki); err != nil {
148			return err
149		}
150		v := args[k]
151		vi := uint(cfw.stab.Lookup(v))
152		if err := wrULEB128(vi); err != nil {
153			return err
154		}
155	}
156	if err := padToFourByteBoundary(ws); err != nil {
157		return err
158	}
159	cfw.csh.ArgsLen = uint32(len(ws.BytesWritten())) - (cfw.csh.StrTabLen + hdrsz)
160
161	return nil
162}
163
164// AppendSegment appends a new segment to a counter data, with a new
165// args section followed by a payload of counter data clauses.
166func (cfw *CoverageDataWriter) AppendSegment(args map[string]string, visitor CounterVisitor) error {
167	cfw.stab = &stringtab.Writer{}
168	cfw.stab.InitWriter()
169	cfw.stab.Lookup("")
170
171	var err error
172	for k, v := range args {
173		cfw.stab.Lookup(k)
174		cfw.stab.Lookup(v)
175	}
176
177	ws := &slicewriter.WriteSeeker{}
178	if err = cfw.writeSegmentPreamble(args, ws); err != nil {
179		return err
180	}
181	if err = cfw.writeCounters(visitor, ws); err != nil {
182		return err
183	}
184	if err = cfw.patchSegmentHeader(ws); err != nil {
185		return err
186	}
187	if err := cfw.writeBytes(ws.BytesWritten()); err != nil {
188		return err
189	}
190	if err = cfw.writeFooter(); err != nil {
191		return err
192	}
193	if err := cfw.w.Flush(); err != nil {
194		return fmt.Errorf("write error: %v", err)
195	}
196	cfw.stab = nil
197	return nil
198}
199
200func (cfw *CoverageDataWriter) writeHeader(metaFileHash [16]byte) error {
201	// Emit file header.
202	ch := coverage.CounterFileHeader{
203		Magic:     coverage.CovCounterMagic,
204		Version:   coverage.CounterFileVersion,
205		MetaHash:  metaFileHash,
206		CFlavor:   cfw.cflavor,
207		BigEndian: false,
208	}
209	if err := binary.Write(cfw.w, binary.LittleEndian, ch); err != nil {
210		return err
211	}
212	return nil
213}
214
215func (cfw *CoverageDataWriter) writeBytes(b []byte) error {
216	if len(b) == 0 {
217		return nil
218	}
219	nw, err := cfw.w.Write(b)
220	if err != nil {
221		return fmt.Errorf("error writing counter data: %v", err)
222	}
223	if len(b) != nw {
224		return fmt.Errorf("error writing counter data: short write")
225	}
226	return nil
227}
228
229func (cfw *CoverageDataWriter) writeCounters(visitor CounterVisitor, ws *slicewriter.WriteSeeker) error {
230	// Notes:
231	// - this version writes everything little-endian, which means
232	//   a call is needed to encode every value (expensive)
233	// - we may want to move to a model in which we just blast out
234	//   all counters, or possibly mmap the file and do the write
235	//   implicitly.
236	ctrb := make([]byte, 4)
237	wrval := func(val uint32) error {
238		var buf []byte
239		var towr int
240		if cfw.cflavor == coverage.CtrRaw {
241			binary.LittleEndian.PutUint32(ctrb, val)
242			buf = ctrb
243			towr = 4
244		} else if cfw.cflavor == coverage.CtrULeb128 {
245			cfw.tmp = cfw.tmp[:0]
246			cfw.tmp = uleb128.AppendUleb128(cfw.tmp, uint(val))
247			buf = cfw.tmp
248			towr = len(buf)
249		} else {
250			panic("internal error: bad counter flavor")
251		}
252		if sz, err := ws.Write(buf); err != nil {
253			return err
254		} else if sz != towr {
255			return fmt.Errorf("writing counters: short write")
256		}
257		return nil
258	}
259
260	// Write out entries for each live function.
261	emitter := func(pkid uint32, funcid uint32, counters []uint32) error {
262		cfw.csh.FcnEntries++
263		if err := wrval(uint32(len(counters))); err != nil {
264			return err
265		}
266
267		if err := wrval(pkid); err != nil {
268			return err
269		}
270
271		if err := wrval(funcid); err != nil {
272			return err
273		}
274		for _, val := range counters {
275			if err := wrval(val); err != nil {
276				return err
277			}
278		}
279		return nil
280	}
281	if err := visitor.VisitFuncs(emitter); err != nil {
282		return err
283	}
284	return nil
285}
286
287func (cfw *CoverageDataWriter) writeFooter() error {
288	cfw.segs++
289	cf := coverage.CounterFileFooter{
290		Magic:       coverage.CovCounterMagic,
291		NumSegments: cfw.segs,
292	}
293	if err := binary.Write(cfw.w, binary.LittleEndian, cf); err != nil {
294		return err
295	}
296	return nil
297}
298