xref: /aosp_15_r20/external/tink/go/hybrid/internal/hpke/hkdf_kdf_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2021 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package hpke
18
19import (
20	"bytes"
21	"encoding/hex"
22	"fmt"
23	"math"
24	"testing"
25
26	"github.com/google/tink/go/subtle"
27)
28
29// TODO(b/201070904): Write tests using baseModeX25519HKDFSHA256Vectors.
30func TestHKDFKDFLabeledExtract(t *testing.T) {
31	kdf, err := newKDF(hkdfSHA256)
32	if err != nil {
33		t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err)
34	}
35	id, v := internetDraftVector(t)
36	suiteID := hpkeSuiteID(id.kemID, id.kdfID, id.aeadID)
37
38	// Base mode uses a default empty value for the pre-shared key (PSK), see
39	// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1.1-4.
40	pskIDHash := kdf.labeledExtract(emptySalt, emptyIKM /*= default PSK ID*/, "psk_id_hash", suiteID)
41	infoHash := kdf.labeledExtract(emptySalt, v.info, "info_hash", suiteID)
42	keyScheduleCtx := keyScheduleContext(id.mode, pskIDHash, infoHash)
43	if !bytes.Equal(keyScheduleCtx, v.keyScheduleCtx) {
44		t.Errorf("labeledExtract: got %x, want %x", keyScheduleCtx, v.keyScheduleCtx)
45	}
46
47	secret := kdf.labeledExtract(v.sharedSecret, emptyIKM /*= default PSK*/, "secret", suiteID)
48	if !bytes.Equal(secret, v.secret) {
49		t.Errorf("labeledExtract: got %x, want %x", secret, v.secret)
50	}
51}
52
53func TestHKDFKDFLabeledExpand(t *testing.T) {
54	kdf, err := newKDF(hkdfSHA256)
55	if err != nil {
56		t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err)
57	}
58	id, v := internetDraftVector(t)
59	suiteID := hpkeSuiteID(id.kemID, id.kdfID, id.aeadID)
60
61	tests := []struct {
62		infoLabel string
63		length    int
64		want      []byte
65		wantErr   bool
66	}{
67		{"key", 16, v.key, false},
68		{"base_nonce", 12, v.baseNonce, false},
69		{"large_length", int(math.Pow(2, 16)), []byte{}, true},
70	}
71
72	for _, test := range tests {
73		t.Run(test.infoLabel, func(t *testing.T) {
74			got, err := kdf.labeledExpand(v.secret, v.keyScheduleCtx, test.infoLabel, suiteID, test.length)
75			if test.wantErr {
76				if err == nil {
77					t.Error("labeledExpand: got success, want err")
78				}
79				return
80			}
81
82			if err != nil {
83				t.Errorf("labeledExpand: got err %q, want success", err)
84			}
85			if !bytes.Equal(got, test.want) {
86				t.Errorf("labeledExpand: got %x, want %x", got, test.want)
87			}
88		})
89	}
90}
91
92func TestHKDFKDFLabeledExpandRFCVectors(t *testing.T) {
93	kdf, err := newKDF(hkdfSHA256)
94	if err != nil {
95		t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err)
96	}
97	suiteID := hpkeSuiteID(x25519HKDFSHA256, hkdfSHA256, aes128GCM)
98
99	// Vectors are defined at
100	// https://datatracker.ietf.org/doc/html/rfc5869#appendix-A.
101	var tests = []struct {
102		name   string
103		info   string
104		prk    string
105		length int
106		want   string // Generated manually.
107	}{
108		{
109			name:   "basic",
110			info:   "f0f1f2f3f4f5f6f7f8f9",
111			prk:    "077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5",
112			length: 42,
113			want:   "2f1a8eb86971cd1850d04a1b98f9a63d52d56c5a4d5fcb68103e57c7a85a1df2c9be1346ae041007712d",
114		},
115		{
116			name:   "longer inputs",
117			info:   "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
118			prk:    "06a6b88c5853361a06104c9ceb35b45cef760014904671014a193f40c15fc244",
119			length: 82,
120			want:   "3961afd1985cb4d811e261b3568c44b88ae7e5d5909d33a5419e954eb245fe03fd3635769d88cec8adb709e900fa399e1a68bdb9d5c879e385845eeb99034fd232e30d1acc58f7fa37791fe0c433221b1fec",
121		},
122		{
123			name:   "zero-length info",
124			info:   "",
125			prk:    "19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04",
126			length: 42,
127			want:   "bdb2761a4f8504177b10ecc354f41153a3964435b9072d1f349c2993afbaa77a05ed426c384e195dba76",
128		},
129	}
130
131	for _, test := range tests {
132		t.Run(test.name, func(t *testing.T) {
133			info, err := hex.DecodeString(test.info)
134			if err != nil {
135				t.Fatal("hex.DecodeString(info) failed")
136			}
137			prk, err := hex.DecodeString(test.prk)
138			if err != nil {
139				t.Fatal("hex.DecodeString(prk) failed")
140			}
141			want, err := hex.DecodeString(test.want)
142			if err != nil {
143				t.Fatal("hex.DecodeString(want) failed")
144			}
145			got, err := kdf.labeledExpand(prk, info, "info_label", suiteID, test.length)
146			if err != nil {
147				t.Errorf("labeledExpand: got err %q, want success", err)
148			}
149			if !bytes.Equal(got, want) {
150				t.Errorf("labeledExpand: got %x, want %x", got, want)
151			}
152		})
153	}
154}
155
156func TestHKDFKDFExtractAndExpand(t *testing.T) {
157	kdf, err := newKDF(hkdfSHA256)
158	if err != nil {
159		t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err)
160	}
161	_, v := internetDraftVector(t)
162
163	dhSharedSecret, err := subtle.ComputeSharedSecretX25519(v.senderPrivKey, v.recipientPubKey)
164	if err != nil {
165		t.Fatalf("ComputeSharedSecretX25519: got err %q, want success", err)
166	}
167	kemCtx := []byte{}
168	kemCtx = append(kemCtx, v.senderPubKey...)
169	kemCtx = append(kemCtx, v.recipientPubKey...)
170
171	var tests = []struct {
172		length  int
173		want    []byte
174		wantErr bool
175	}{
176		{32, v.sharedSecret, false},
177		{int(math.Pow(2, 16)), nil, true},
178	}
179
180	for _, test := range tests {
181		t.Run(fmt.Sprintf("%d", test.length), func(t *testing.T) {
182			sharedSecret, err := kdf.extractAndExpand(
183				emptySalt,
184				dhSharedSecret,
185				"eae_prk",
186				kemCtx,
187				"shared_secret",
188				kemSuiteID(x25519HKDFSHA256),
189				test.length)
190			if test.wantErr {
191				if err == nil {
192					t.Error("extractAndExpand: got success, want err")
193				}
194				return
195			}
196
197			if err != nil {
198				t.Errorf("extractAndExpand: got err %q, want success", err)
199			}
200			if !bytes.Equal(sharedSecret, v.sharedSecret) {
201				t.Errorf("extractAndExpand: got %x, want %x", sharedSecret, v.sharedSecret)
202			}
203		})
204	}
205}
206