1// Copyright 2022 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package streamingprf 18 19import ( 20 "bytes" 21 "encoding/hex" 22 "fmt" 23 "io" 24 "strings" 25 "testing" 26 27 "github.com/google/tink/go/subtle/random" 28 "github.com/google/tink/go/testutil" 29) 30 31func TestNewHKDFStreamingPRF(t *testing.T) { 32 for _, test := range []struct { 33 name string 34 hash string 35 salt []byte 36 }{ 37 { 38 name: "SHA256_nil_salt", 39 hash: "SHA256", 40 }, 41 { 42 name: "SHA256_random_salt", 43 hash: "SHA256", 44 salt: random.GetRandomBytes(16), 45 }, 46 { 47 name: "SHA512_nil_salt", 48 hash: "SHA512", 49 }, 50 { 51 name: "SHA512_random_salt", 52 hash: "SHA512", 53 salt: random.GetRandomBytes(16), 54 }, 55 } { 56 t.Run(test.name, func(t *testing.T) { 57 key := random.GetRandomBytes(32) 58 h, err := newHKDFStreamingPRF(test.hash, key, test.salt) 59 if err != nil { 60 t.Fatalf("newHKDFStreamingPRF() err = %v, want nil", err) 61 } 62 if !bytes.Equal(h.key, key) { 63 t.Errorf("key = %v, want %v", h.key, key) 64 } 65 if !bytes.Equal(h.salt, test.salt) { 66 t.Errorf("salt = %v, want %v", h.salt, test.salt) 67 } 68 }) 69 } 70} 71 72func TestNewHKDFStreamingPRFFails(t *testing.T) { 73 for _, test := range []struct { 74 hash string 75 keySize uint32 76 }{ 77 { 78 hash: "SHA256", 79 keySize: 16, 80 }, 81 { 82 hash: "SHA512", 83 keySize: 16}, 84 { 85 hash: "SHA1", 86 keySize: 20, 87 }, 88 } { 89 t.Run(test.hash, func(t *testing.T) { 90 if _, err := newHKDFStreamingPRF(test.hash, random.GetRandomBytes(test.keySize), nil); err == nil { 91 t.Error("newHKDFStreamingPRF() err = nil, want non-nil") 92 } 93 }) 94 } 95} 96 97func TestHKDFStreamingPRFWithRFCVector(t *testing.T) { 98 // This is the only vector that uses an accepted hash function and has key 99 // size >= minHKDFStreamingPRFKeySize. 100 // https://www.rfc-editor.org/rfc/rfc5869#appendix-A.2 101 vec := struct { 102 hash string 103 key string 104 salt string 105 info string 106 outLen int 107 okm string 108 }{ 109 hash: "SHA256", 110 key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f", 111 salt: "606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeaf", 112 info: "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff", 113 outLen: 82, 114 okm: "b11e398dc80327a1c8e7f78c596a49344f012eda2d4efad8a050cc4c19afa97c59045a99cac7827271cb41c65e590e09da3275600c2f09b8367793a9aca3db71cc30c58179ec3e87c14c01d5c1f3434f1d87", 115 } 116 key, err := hex.DecodeString(vec.key) 117 if err != nil { 118 t.Fatalf("hex.DecodeString err = %v, want nil", err) 119 } 120 salt, err := hex.DecodeString(vec.salt) 121 if err != nil { 122 t.Fatalf("hex.DecodeString err = %v, want nil", err) 123 } 124 info, err := hex.DecodeString(vec.info) 125 if err != nil { 126 t.Fatalf("hex.DecodeString err = %v, want nil", err) 127 } 128 129 h, err := newHKDFStreamingPRF(vec.hash, key, salt) 130 if err != nil { 131 t.Fatalf("newHKDFStreamingPRF() err = %v, want nil", err) 132 } 133 r, err := h.Compute(info) 134 if err != nil { 135 t.Fatalf("Compute() err = %v, want nil", err) 136 } 137 out := make([]byte, vec.outLen) 138 if _, err := io.ReadAtLeast(r, out, len(out)); err != nil { 139 t.Fatalf("io.ReadAtLeast err = %v, want nil", err) 140 } 141 if hex.EncodeToString(out) != vec.okm { 142 t.Errorf("Compute() = %v, want %v", hex.EncodeToString(out), vec.okm) 143 } 144} 145 146func TestHKDFStreamingPRFWithWycheproof(t *testing.T) { 147 testutil.SkipTestIfTestSrcDirIsNotSet(t) 148 149 type hkdfCase struct { 150 testutil.WycheproofCase 151 IKM testutil.HexBytes `json:"ikm"` 152 Salt testutil.HexBytes `json:"salt"` 153 Info testutil.HexBytes `json:"info"` 154 Size uint32 `json:"size"` 155 OKM testutil.HexBytes `json:"okm"` 156 } 157 type hkdfGroup struct { 158 testutil.WycheproofGroup 159 KeySize uint32 `json:"keySize"` 160 Type string `json:"type"` 161 Tests []*hkdfCase `json:"tests"` 162 } 163 type hkdfSuite struct { 164 testutil.WycheproofSuite 165 TestGroups []*hkdfGroup `json:"testGroups"` 166 } 167 168 count := 0 169 for _, hash := range []string{"SHA256", "SHA512"} { 170 filename := fmt.Sprintf("hkdf_%s_test.json", strings.ToLower(hash)) 171 suite := new(hkdfSuite) 172 if err := testutil.PopulateSuite(suite, filename); err != nil { 173 t.Fatalf("testutil.PopulateSuite(%v, %s): %v", suite, filename, err) 174 } 175 for _, group := range suite.TestGroups { 176 for _, test := range group.Tests { 177 caseName := fmt.Sprintf("%s(%d):Case-%d", hash, group.KeySize, test.CaseID) 178 t.Run(caseName, func(t *testing.T) { 179 if got, want := len(test.IKM), int(group.KeySize/8); got != want { 180 t.Fatalf("invalid key length = %d, want %d", got, want) 181 } 182 count++ 183 184 h, err := newHKDFStreamingPRF(hash, test.IKM, test.Salt) 185 switch test.Result { 186 case "valid": 187 if len(test.IKM) < minHKDFStreamingPRFKeySize { 188 if err == nil { 189 t.Error("newHKDFStreamingPRF err = nil, want non-nil") 190 } 191 return 192 } 193 if err != nil { 194 t.Fatalf("newHKDFStreamingPRF err = %v, want nil", err) 195 } 196 r, err := h.Compute(test.Info) 197 if err != nil { 198 t.Fatalf("Compute() err = %v, want nil", err) 199 } 200 out := make([]byte, test.Size) 201 if _, err := io.ReadAtLeast(r, out, len(out)); err != nil { 202 t.Fatalf("io.ReadAtLeast err = %v, want nil", err) 203 } 204 if !bytes.Equal(out, test.OKM) { 205 t.Errorf("Compute() = %v, want %v", out, test.OKM) 206 } 207 208 case "invalid": 209 if err != nil { 210 return 211 } 212 r, err := h.Compute(test.Info) 213 if err != nil { 214 t.Fatalf("Compute() err = %v, want nil", err) 215 } 216 out := make([]byte, test.Size) 217 if _, err := io.ReadAtLeast(r, out, len(out)); err == nil { 218 t.Error("io.ReadAtLeast err = nil, want non-nil") 219 } 220 221 default: 222 t.Errorf("unsupported test result: %s", test.Result) 223 } 224 }) 225 } 226 } 227 } 228 if count < 200 { 229 t.Errorf("number of test cases = %d, want > 200", count) 230 } 231} 232