1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package testtrace
6
7import (
8	"bufio"
9	"bytes"
10	"fmt"
11	"regexp"
12	"strconv"
13	"strings"
14)
15
// Expectation represents the expected result of some operation.
//
// The zero value expects success. An Expectation that expects failure
// may carry a regular expression that the resulting error's message must
// match.
type Expectation struct {
	// failure reports whether the operation is expected to fail.
	failure bool
	// errorMatcher is the pattern an error's message must match when
	// failure is true. A nil matcher accepts any error.
	errorMatcher *regexp.Regexp
}

// ExpectSuccess returns an Expectation that trivially expects success.
func ExpectSuccess() *Expectation {
	return new(Expectation)
}

// Check validates whether err conforms to the expectation. It returns a
// non-nil error describing the mismatch if it does not conform.
//
// Conformance means:
//   - if failure is false, err must be nil;
//   - if failure is true, err must be non-nil and, when errorMatcher is
//     non-nil, err's message must match errorMatcher.
//
// When the mismatch involves a non-nil err, err is wrapped (%w) so callers
// can still inspect it with errors.Is or errors.As.
func (e *Expectation) Check(err error) error {
	switch {
	case !e.failure && err != nil:
		return fmt.Errorf("unexpected error while reading the trace: %w", err)
	case e.failure && err == nil:
		return fmt.Errorf("expected error while reading the trace: want something matching %q, got none", e.errorMatcher)
	// Reaching this case with e.failure set implies err != nil (the case
	// above handled err == nil). A nil matcher accepts any error instead
	// of dereferencing nil.
	case e.failure && e.errorMatcher != nil && !e.errorMatcher.MatchString(err.Error()):
		return fmt.Errorf("unexpected error while reading the trace: want something matching %q, got %w", e.errorMatcher, err)
	}
	return nil
}
44
45// ParseExpectation parses the serialized form of an Expectation.
46func ParseExpectation(data []byte) (*Expectation, error) {
47	exp := new(Expectation)
48	s := bufio.NewScanner(bytes.NewReader(data))
49	if s.Scan() {
50		c := strings.SplitN(s.Text(), " ", 2)
51		switch c[0] {
52		case "SUCCESS":
53		case "FAILURE":
54			exp.failure = true
55			if len(c) != 2 {
56				return exp, fmt.Errorf("bad header line for FAILURE: %q", s.Text())
57			}
58			matcher, err := parseMatcher(c[1])
59			if err != nil {
60				return exp, err
61			}
62			exp.errorMatcher = matcher
63		default:
64			return exp, fmt.Errorf("bad header line: %q", s.Text())
65		}
66		return exp, nil
67	}
68	return exp, s.Err()
69}
70
// parseMatcher unquotes quoted as a Go string literal (per strconv.Unquote)
// and compiles the result as a regular expression.
//
// Failures from unquoting or compiling are wrapped with %w, so the
// underlying cause stays available to errors.Is / errors.As.
func parseMatcher(quoted string) (*regexp.Regexp, error) {
	pattern, err := strconv.Unquote(quoted)
	if err != nil {
		return nil, fmt.Errorf("malformed pattern: not correctly quoted: %s: %w", quoted, err)
	}
	matcher, err := regexp.Compile(pattern)
	if err != nil {
		return nil, fmt.Errorf("malformed pattern: not a valid regexp: %s: %w", pattern, err)
	}
	return matcher, nil
}
82