1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"encoding/pem"
9	"fmt"
10	"runtime"
11	"testing"
12	"time"
13)
14
15func TestCertCache(t *testing.T) {
16	cc := certCache{}
17	p, _ := pem.Decode([]byte(rsaCertPEM))
18	if p == nil {
19		t.Fatal("Failed to decode certificate")
20	}
21
22	certA, err := cc.newCert(p.Bytes)
23	if err != nil {
24		t.Fatalf("newCert failed: %s", err)
25	}
26	certB, err := cc.newCert(p.Bytes)
27	if err != nil {
28		t.Fatalf("newCert failed: %s", err)
29	}
30	if certA.cert != certB.cert {
31		t.Fatal("newCert returned a unique reference for a duplicate certificate")
32	}
33
34	if entry, ok := cc.Load(string(p.Bytes)); !ok {
35		t.Fatal("cache does not contain expected entry")
36	} else {
37		if refs := entry.(*cacheEntry).refs.Load(); refs != 2 {
38			t.Fatalf("unexpected number of references: got %d, want 2", refs)
39		}
40	}
41
42	timeoutRefCheck := func(t *testing.T, key string, count int64) {
43		t.Helper()
44		c := time.After(4 * time.Second)
45		for {
46			select {
47			case <-c:
48				t.Fatal("timed out waiting for expected ref count")
49			default:
50				e, ok := cc.Load(key)
51				if !ok && count != 0 {
52					t.Fatal("cache does not contain expected key")
53				} else if count == 0 && !ok {
54					return
55				}
56
57				if e.(*cacheEntry).refs.Load() == count {
58					return
59				}
60			}
61		}
62	}
63
64	// Keep certA alive until at least now, so that we can
65	// purposefully nil it and force the finalizer to be
66	// called.
67	runtime.KeepAlive(certA)
68	certA = nil
69	runtime.GC()
70
71	timeoutRefCheck(t, string(p.Bytes), 1)
72
73	// Keep certB alive until at least now, so that we can
74	// purposefully nil it and force the finalizer to be
75	// called.
76	runtime.KeepAlive(certB)
77	certB = nil
78	runtime.GC()
79
80	timeoutRefCheck(t, string(p.Bytes), 0)
81}
82
83func BenchmarkCertCache(b *testing.B) {
84	p, _ := pem.Decode([]byte(rsaCertPEM))
85	if p == nil {
86		b.Fatal("Failed to decode certificate")
87	}
88
89	cc := certCache{}
90	b.ReportAllocs()
91	b.ResetTimer()
92	// We expect that calling newCert additional times after
93	// the initial call should not cause additional allocations.
94	for extra := 0; extra < 4; extra++ {
95		b.Run(fmt.Sprint(extra), func(b *testing.B) {
96			actives := make([]*activeCert, extra+1)
97			b.ResetTimer()
98			for i := 0; i < b.N; i++ {
99				var err error
100				actives[0], err = cc.newCert(p.Bytes)
101				if err != nil {
102					b.Fatal(err)
103				}
104				for j := 0; j < extra; j++ {
105					actives[j+1], err = cc.newCert(p.Bytes)
106					if err != nil {
107						b.Fatal(err)
108					}
109				}
110				for j := 0; j < extra+1; j++ {
111					actives[j] = nil
112				}
113				runtime.GC()
114			}
115		})
116	}
117}
118