xref: /aosp_15_r20/external/tink/go/keyderivation/internal/streamingprf/hkdf_streaming_prf_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package streamingprf
18
19import (
20	"bytes"
21	"encoding/hex"
22	"fmt"
23	"io"
24	"strings"
25	"testing"
26
27	"github.com/google/tink/go/subtle/random"
28	"github.com/google/tink/go/testutil"
29)
30
31func TestNewHKDFStreamingPRF(t *testing.T) {
32	for _, test := range []struct {
33		name string
34		hash string
35		salt []byte
36	}{
37		{
38			name: "SHA256_nil_salt",
39			hash: "SHA256",
40		},
41		{
42			name: "SHA256_random_salt",
43			hash: "SHA256",
44			salt: random.GetRandomBytes(16),
45		},
46		{
47			name: "SHA512_nil_salt",
48			hash: "SHA512",
49		},
50		{
51			name: "SHA512_random_salt",
52			hash: "SHA512",
53			salt: random.GetRandomBytes(16),
54		},
55	} {
56		t.Run(test.name, func(t *testing.T) {
57			key := random.GetRandomBytes(32)
58			h, err := newHKDFStreamingPRF(test.hash, key, test.salt)
59			if err != nil {
60				t.Fatalf("newHKDFStreamingPRF() err = %v, want nil", err)
61			}
62			if !bytes.Equal(h.key, key) {
63				t.Errorf("key = %v, want %v", h.key, key)
64			}
65			if !bytes.Equal(h.salt, test.salt) {
66				t.Errorf("salt = %v, want %v", h.salt, test.salt)
67			}
68		})
69	}
70}
71
72func TestNewHKDFStreamingPRFFails(t *testing.T) {
73	for _, test := range []struct {
74		hash    string
75		keySize uint32
76	}{
77		{
78			hash:    "SHA256",
79			keySize: 16,
80		},
81		{
82			hash:    "SHA512",
83			keySize: 16},
84		{
85			hash:    "SHA1",
86			keySize: 20,
87		},
88	} {
89		t.Run(test.hash, func(t *testing.T) {
90			if _, err := newHKDFStreamingPRF(test.hash, random.GetRandomBytes(test.keySize), nil); err == nil {
91				t.Error("newHKDFStreamingPRF() err = nil, want non-nil")
92			}
93		})
94	}
95}
96
97func TestHKDFStreamingPRFWithRFCVector(t *testing.T) {
98	// This is the only vector that uses an accepted hash function and has key
99	// size >= minHKDFStreamingPRFKeySize.
100	// https://www.rfc-editor.org/rfc/rfc5869#appendix-A.2
101	vec := struct {
102		hash   string
103		key    string
104		salt   string
105		info   string
106		outLen int
107		okm    string
108	}{
109		hash:   "SHA256",
110		key:    "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f",
111		salt:   "606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
112		info:   "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
113		outLen: 82,
114		okm:    "b11e398dc80327a1c8e7f78c596a49344f012eda2d4efad8a050cc4c19afa97c59045a99cac7827271cb41c65e590e09da3275600c2f09b8367793a9aca3db71cc30c58179ec3e87c14c01d5c1f3434f1d87",
115	}
116	key, err := hex.DecodeString(vec.key)
117	if err != nil {
118		t.Fatalf("hex.DecodeString err = %v, want nil", err)
119	}
120	salt, err := hex.DecodeString(vec.salt)
121	if err != nil {
122		t.Fatalf("hex.DecodeString err = %v, want nil", err)
123	}
124	info, err := hex.DecodeString(vec.info)
125	if err != nil {
126		t.Fatalf("hex.DecodeString err = %v, want nil", err)
127	}
128
129	h, err := newHKDFStreamingPRF(vec.hash, key, salt)
130	if err != nil {
131		t.Fatalf("newHKDFStreamingPRF() err = %v, want nil", err)
132	}
133	r, err := h.Compute(info)
134	if err != nil {
135		t.Fatalf("Compute() err = %v, want nil", err)
136	}
137	out := make([]byte, vec.outLen)
138	if _, err := io.ReadAtLeast(r, out, len(out)); err != nil {
139		t.Fatalf("io.ReadAtLeast err = %v, want nil", err)
140	}
141	if hex.EncodeToString(out) != vec.okm {
142		t.Errorf("Compute() = %v, want %v", hex.EncodeToString(out), vec.okm)
143	}
144}
145
146func TestHKDFStreamingPRFWithWycheproof(t *testing.T) {
147	testutil.SkipTestIfTestSrcDirIsNotSet(t)
148
149	type hkdfCase struct {
150		testutil.WycheproofCase
151		IKM  testutil.HexBytes `json:"ikm"`
152		Salt testutil.HexBytes `json:"salt"`
153		Info testutil.HexBytes `json:"info"`
154		Size uint32            `json:"size"`
155		OKM  testutil.HexBytes `json:"okm"`
156	}
157	type hkdfGroup struct {
158		testutil.WycheproofGroup
159		KeySize uint32      `json:"keySize"`
160		Type    string      `json:"type"`
161		Tests   []*hkdfCase `json:"tests"`
162	}
163	type hkdfSuite struct {
164		testutil.WycheproofSuite
165		TestGroups []*hkdfGroup `json:"testGroups"`
166	}
167
168	count := 0
169	for _, hash := range []string{"SHA256", "SHA512"} {
170		filename := fmt.Sprintf("hkdf_%s_test.json", strings.ToLower(hash))
171		suite := new(hkdfSuite)
172		if err := testutil.PopulateSuite(suite, filename); err != nil {
173			t.Fatalf("testutil.PopulateSuite(%v, %s): %v", suite, filename, err)
174		}
175		for _, group := range suite.TestGroups {
176			for _, test := range group.Tests {
177				caseName := fmt.Sprintf("%s(%d):Case-%d", hash, group.KeySize, test.CaseID)
178				t.Run(caseName, func(t *testing.T) {
179					if got, want := len(test.IKM), int(group.KeySize/8); got != want {
180						t.Fatalf("invalid key length = %d, want %d", got, want)
181					}
182					count++
183
184					h, err := newHKDFStreamingPRF(hash, test.IKM, test.Salt)
185					switch test.Result {
186					case "valid":
187						if len(test.IKM) < minHKDFStreamingPRFKeySize {
188							if err == nil {
189								t.Error("newHKDFStreamingPRF err = nil, want non-nil")
190							}
191							return
192						}
193						if err != nil {
194							t.Fatalf("newHKDFStreamingPRF err = %v, want nil", err)
195						}
196						r, err := h.Compute(test.Info)
197						if err != nil {
198							t.Fatalf("Compute() err = %v, want nil", err)
199						}
200						out := make([]byte, test.Size)
201						if _, err := io.ReadAtLeast(r, out, len(out)); err != nil {
202							t.Fatalf("io.ReadAtLeast err = %v, want nil", err)
203						}
204						if !bytes.Equal(out, test.OKM) {
205							t.Errorf("Compute() = %v, want %v", out, test.OKM)
206						}
207
208					case "invalid":
209						if err != nil {
210							return
211						}
212						r, err := h.Compute(test.Info)
213						if err != nil {
214							t.Fatalf("Compute() err = %v, want nil", err)
215						}
216						out := make([]byte, test.Size)
217						if _, err := io.ReadAtLeast(r, out, len(out)); err == nil {
218							t.Error("io.ReadAtLeast err = nil, want non-nil")
219						}
220
221					default:
222						t.Errorf("unsupported test result: %s", test.Result)
223					}
224				})
225			}
226		}
227	}
228	if count < 200 {
229		t.Errorf("number of test cases = %d, want > 200", count)
230	}
231}
232