1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package streamingprf_test
18
19import (
20	"strings"
21	"testing"
22
23	"google.golang.org/protobuf/proto"
24	"github.com/google/tink/go/keyderivation/internal/streamingprf"
25	"github.com/google/tink/go/subtle/random"
26	aesgcmpb "github.com/google/tink/go/proto/aes_gcm_go_proto"
27	commonpb "github.com/google/tink/go/proto/common_go_proto"
28	hkdfpb "github.com/google/tink/go/proto/hkdf_prf_go_proto"
29)
30
31func TestHKDFStreamingPRFKeyManagerPrimitive(t *testing.T) {
32	km := streamingprf.HKDFStreamingPRFKeyManager{}
33	for _, test := range []struct {
34		name string
35		hash commonpb.HashType
36		salt []byte
37	}{
38		{
39			name: "SHA256_nil_salt",
40			hash: commonpb.HashType_SHA256,
41		},
42		{
43			name: "SHA256_random_salt",
44			hash: commonpb.HashType_SHA256,
45			salt: random.GetRandomBytes(16),
46		},
47		{
48			name: "SHA512_nil_salt",
49			hash: commonpb.HashType_SHA512,
50		},
51		{
52			name: "SHA512_random_salt",
53			hash: commonpb.HashType_SHA512,
54			salt: random.GetRandomBytes(16),
55		},
56	} {
57		t.Run(test.name, func(t *testing.T) {
58			key := &hkdfpb.HkdfPrfKey{
59				Version: 0,
60				Params: &hkdfpb.HkdfPrfParams{
61					Hash: test.hash,
62					Salt: test.salt,
63				},
64				KeyValue: random.GetRandomBytes(32),
65			}
66			serializedKey, err := proto.Marshal(key)
67			if err != nil {
68				t.Fatalf("proto.Marshal(%v) err = %v, want nil", key, err)
69			}
70			p, err := km.Primitive(serializedKey)
71			if err != nil {
72				t.Fatalf("Primitive() err = %v, want nil", err)
73			}
74			prf, ok := p.(streamingprf.StreamingPRF)
75			if !ok {
76				t.Fatal("primitive is not StreamingPRF")
77			}
78			r, err := prf.Compute(random.GetRandomBytes(32))
79			if err != nil {
80				t.Fatalf("Compute() err = %v, want nil", err)
81			}
82			limit := limitFromHash(t, test.hash)
83			out := make([]byte, limit)
84			n, err := r.Read(out)
85			if n != limit || err != nil {
86				t.Errorf("Read() not enough bytes: %d, %v", n, err)
87			}
88		})
89	}
90}
91
92func TestHKDFStreamingPRFKeyManagerPrimitiveRejectsIncorrectKeys(t *testing.T) {
93	km := streamingprf.HKDFStreamingPRFKeyManager{}
94	missingParamsKey := &hkdfpb.HkdfPrfKey{
95		Version:  0,
96		KeyValue: random.GetRandomBytes(32),
97	}
98	serializedMissingParamsKey, err := proto.Marshal(missingParamsKey)
99	if err != nil {
100		t.Fatalf("proto.Marshal(%v) err = %v, want nil", serializedMissingParamsKey, err)
101	}
102	aesGCMKey := &aesgcmpb.AesGcmKey{Version: 0}
103	serializedAESGCMKey, err := proto.Marshal(aesGCMKey)
104	if err != nil {
105		t.Fatalf("proto.Marshal(%v) err = %v, want nil", aesGCMKey, err)
106	}
107	for _, test := range []struct {
108		name          string
109		serializedKey []byte
110	}{
111		{
112			name: "nil key",
113		},
114		{
115			name:          "zero-length key",
116			serializedKey: []byte{},
117		},
118		{
119			name:          "missing params",
120			serializedKey: serializedMissingParamsKey,
121		},
122		{
123			name:          "wrong key type",
124			serializedKey: serializedAESGCMKey,
125		},
126	} {
127		t.Run(test.name, func(t *testing.T) {
128			if _, err := km.Primitive(test.serializedKey); err == nil {
129				t.Error("Primitive() err = nil, want non-nil")
130			}
131		})
132	}
133}
134
135func TestHKDFStreamingPRFKeyManagerPrimitiveRejectsInvalidKeys(t *testing.T) {
136	km := streamingprf.HKDFStreamingPRFKeyManager{}
137
138	validKey := &hkdfpb.HkdfPrfKey{
139		Version: 0,
140		Params: &hkdfpb.HkdfPrfParams{
141			Hash: commonpb.HashType_SHA256,
142			Salt: random.GetRandomBytes(16),
143		},
144		KeyValue: random.GetRandomBytes(32),
145	}
146	serializedValidKey, err := proto.Marshal(validKey)
147	if err != nil {
148		t.Fatalf("proto.Marshal(%v) err = %v, want nil", validKey, err)
149	}
150	if _, err := km.Primitive(serializedValidKey); err != nil {
151		t.Errorf("Primitive() err = %v, want nil", err)
152	}
153
154	for _, test := range []struct {
155		name     string
156		version  uint32
157		hash     commonpb.HashType
158		keyValue []byte
159	}{
160		{
161			"invalid version",
162			100,
163			validKey.GetParams().GetHash(),
164			validKey.GetKeyValue(),
165		},
166		{
167			"invalid hash",
168			validKey.GetVersion(),
169			commonpb.HashType_SHA1,
170			validKey.GetKeyValue(),
171		},
172		{
173			"invalid key size",
174			validKey.GetVersion(),
175			validKey.GetParams().GetHash(),
176			random.GetRandomBytes(12),
177		},
178	} {
179		t.Run(test.name, func(t *testing.T) {
180			key := &hkdfpb.HkdfPrfKey{
181				Version: test.version,
182				Params: &hkdfpb.HkdfPrfParams{
183					Hash: test.hash,
184					// There is no concept of an invalid salt, as it can either be nil or
185					// have a value.
186					Salt: validKey.GetParams().GetSalt(),
187				},
188				KeyValue: test.keyValue,
189			}
190			serializedKey, err := proto.Marshal(key)
191			if err != nil {
192				t.Fatalf("proto.Marshal(%v) err = %v, want nil", key, err)
193			}
194			if _, err := km.Primitive(serializedKey); err == nil {
195				t.Error("Primitive() err = nil, want non-nil")
196			}
197		})
198	}
199}
200
201func TestHKDFStreamingPRFKeyManagerNewKeyAndNewKeyData(t *testing.T) {
202	km := streamingprf.HKDFStreamingPRFKeyManager{}
203	notImplemented := "not implemented"
204	if _, err := km.NewKey(random.GetRandomBytes(16)); !strings.Contains(err.Error(), notImplemented) {
205		t.Errorf("NewKey() err = %v, want containing %q", err, notImplemented)
206	}
207	if _, err := km.NewKeyData(random.GetRandomBytes(16)); !strings.Contains(err.Error(), notImplemented) {
208		t.Errorf("NewKey() err = %v, want containing %q", err, notImplemented)
209	}
210}
211
212func TestHKDFStreamingPRFKeyManagerDoesSupport(t *testing.T) {
213	km := streamingprf.HKDFStreamingPRFKeyManager{}
214	if !km.DoesSupport(hkdfPRFTypeURL) {
215		t.Errorf("DoesSupport(%q) = false, want true", hkdfPRFTypeURL)
216	}
217	if unsupported := "unsupported.key.type"; km.DoesSupport(unsupported) {
218		t.Errorf("DoesSupport(%q) = true, want false", unsupported)
219	}
220}
221
222func TestHKDFStreamingPRFKeyManagerTypeURL(t *testing.T) {
223	km := streamingprf.HKDFStreamingPRFKeyManager{}
224	if km.TypeURL() != hkdfPRFTypeURL {
225		t.Errorf("TypeURL() = %q, want %q", km.TypeURL(), hkdfPRFTypeURL)
226	}
227}
228