1// Copyright 2021 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package hpke 18 19import ( 20 "bytes" 21 "encoding/hex" 22 "fmt" 23 "math" 24 "testing" 25 26 "github.com/google/tink/go/subtle" 27) 28 29// TODO(b/201070904): Write tests using baseModeX25519HKDFSHA256Vectors. 30func TestHKDFKDFLabeledExtract(t *testing.T) { 31 kdf, err := newKDF(hkdfSHA256) 32 if err != nil { 33 t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err) 34 } 35 id, v := internetDraftVector(t) 36 suiteID := hpkeSuiteID(id.kemID, id.kdfID, id.aeadID) 37 38 // Base mode uses a default empty value for the pre-shared key (PSK), see 39 // https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1.1-4. 40 pskIDHash := kdf.labeledExtract(emptySalt, emptyIKM /*= default PSK ID*/, "psk_id_hash", suiteID) 41 infoHash := kdf.labeledExtract(emptySalt, v.info, "info_hash", suiteID) 42 keyScheduleCtx := keyScheduleContext(id.mode, pskIDHash, infoHash) 43 if !bytes.Equal(keyScheduleCtx, v.keyScheduleCtx) { 44 t.Errorf("labeledExtract: got %x, want %x", keyScheduleCtx, v.keyScheduleCtx) 45 } 46 47 secret := kdf.labeledExtract(v.sharedSecret, emptyIKM /*= default PSK*/, "secret", suiteID) 48 if !bytes.Equal(secret, v.secret) { 49 t.Errorf("labeledExtract: got %x, want %x", secret, v.secret) 50 } 51} 52 53func TestHKDFKDFLabeledExpand(t *testing.T) { 54 kdf, err := newKDF(hkdfSHA256) 55 if err != nil { 56 t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err) 57 } 58 id, v := internetDraftVector(t) 59 suiteID := hpkeSuiteID(id.kemID, id.kdfID, id.aeadID) 60 61 tests := []struct { 62 infoLabel string 63 length int 64 want []byte 65 wantErr bool 66 }{ 67 {"key", 16, v.key, false}, 68 {"base_nonce", 12, v.baseNonce, false}, 69 {"large_length", int(math.Pow(2, 16)), []byte{}, true}, 70 } 71 72 for _, test := range tests { 73 t.Run(test.infoLabel, func(t *testing.T) { 74 got, err := kdf.labeledExpand(v.secret, v.keyScheduleCtx, test.infoLabel, suiteID, test.length) 75 if test.wantErr { 76 if err == nil { 77 t.Error("labeledExpand: got success, want err") 78 } 79 return 80 } 81 82 if err != nil { 83 t.Errorf("labeledExpand: got err %q, want success", err) 84 } 85 if !bytes.Equal(got, test.want) { 86 t.Errorf("labeledExpand: got %x, want %x", got, test.want) 87 } 88 }) 89 } 90} 91 92func TestHKDFKDFLabeledExpandRFCVectors(t *testing.T) { 93 kdf, err := newKDF(hkdfSHA256) 94 if err != nil { 95 t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err) 96 } 97 suiteID := hpkeSuiteID(x25519HKDFSHA256, hkdfSHA256, aes128GCM) 98 99 // Vectors are defined at 100 // https://datatracker.ietf.org/doc/html/rfc5869#appendix-A. 101 var tests = []struct { 102 name string 103 info string 104 prk string 105 length int 106 want string // Generated manually. 107 }{ 108 { 109 name: "basic", 110 info: "f0f1f2f3f4f5f6f7f8f9", 111 prk: "077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5", 112 length: 42, 113 want: "2f1a8eb86971cd1850d04a1b98f9a63d52d56c5a4d5fcb68103e57c7a85a1df2c9be1346ae041007712d", 114 }, 115 { 116 name: "longer inputs", 117 info: "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff", 118 prk: "06a6b88c5853361a06104c9ceb35b45cef760014904671014a193f40c15fc244", 119 length: 82, 120 want: "3961afd1985cb4d811e261b3568c44b88ae7e5d5909d33a5419e954eb245fe03fd3635769d88cec8adb709e900fa399e1a68bdb9d5c879e385845eeb99034fd232e30d1acc58f7fa37791fe0c433221b1fec", 121 }, 122 { 123 name: "zero-length info", 124 info: "", 125 prk: "19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04", 126 length: 42, 127 want: "bdb2761a4f8504177b10ecc354f41153a3964435b9072d1f349c2993afbaa77a05ed426c384e195dba76", 128 }, 129 } 130 131 for _, test := range tests { 132 t.Run(test.name, func(t *testing.T) { 133 info, err := hex.DecodeString(test.info) 134 if err != nil { 135 t.Fatal("hex.DecodeString(info) failed") 136 } 137 prk, err := hex.DecodeString(test.prk) 138 if err != nil { 139 t.Fatal("hex.DecodeString(prk) failed") 140 } 141 want, err := hex.DecodeString(test.want) 142 if err != nil { 143 t.Fatal("hex.DecodeString(want) failed") 144 } 145 got, err := kdf.labeledExpand(prk, info, "info_label", suiteID, test.length) 146 if err != nil { 147 t.Errorf("labeledExpand: got err %q, want success", err) 148 } 149 if !bytes.Equal(got, want) { 150 t.Errorf("labeledExpand: got %x, want %x", got, want) 151 } 152 }) 153 } 154} 155 156func TestHKDFKDFExtractAndExpand(t *testing.T) { 157 kdf, err := newKDF(hkdfSHA256) 158 if err != nil { 159 t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err) 160 } 161 _, v := internetDraftVector(t) 162 163 dhSharedSecret, err := subtle.ComputeSharedSecretX25519(v.senderPrivKey, v.recipientPubKey) 164 if err != nil { 165 t.Fatalf("ComputeSharedSecretX25519: got err %q, want success", err) 166 } 167 kemCtx := []byte{} 168 kemCtx = append(kemCtx, v.senderPubKey...) 169 kemCtx = append(kemCtx, v.recipientPubKey...) 170 171 var tests = []struct { 172 length int 173 want []byte 174 wantErr bool 175 }{ 176 {32, v.sharedSecret, false}, 177 {int(math.Pow(2, 16)), nil, true}, 178 } 179 180 for _, test := range tests { 181 t.Run(fmt.Sprintf("%d", test.length), func(t *testing.T) { 182 sharedSecret, err := kdf.extractAndExpand( 183 emptySalt, 184 dhSharedSecret, 185 "eae_prk", 186 kemCtx, 187 "shared_secret", 188 kemSuiteID(x25519HKDFSHA256), 189 test.length) 190 if test.wantErr { 191 if err == nil { 192 t.Error("extractAndExpand: got success, want err") 193 } 194 return 195 } 196 197 if err != nil { 198 t.Errorf("extractAndExpand: got err %q, want success", err) 199 } 200 if !bytes.Equal(sharedSecret, v.sharedSecret) { 201 t.Errorf("extractAndExpand: got %x, want %x", sharedSecret, v.sharedSecret) 202 } 203 }) 204 } 205} 206