xref: /aosp_15_r20/external/tink/go/hybrid/internal/hpke/primitive_factory_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package hpke
18
19import (
20	"testing"
21
22	pb "github.com/google/tink/go/proto/hpke_go_proto"
23)
24
25func TestNewKEM(t *testing.T) {
26	kemID, err := kemIDFromProto(pb.HpkeKem_DHKEM_X25519_HKDF_SHA256)
27	if err != nil {
28		t.Fatal(err)
29	}
30	if kemID != x25519HKDFSHA256 {
31		t.Errorf("kemID: got %d, want %d", kemID, x25519HKDFSHA256)
32	}
33
34	kem, err := newKEM(kemID)
35	if err != nil {
36		t.Fatal(err)
37	}
38	if kem.id() != x25519HKDFSHA256 {
39		t.Errorf("id: got %d, want %d", kem.id(), x25519HKDFSHA256)
40	}
41}
42
43func TestNewKEMUnsupportedID(t *testing.T) {
44	if _, err := newKEM(0x0010 /*= DHKEM(P-256, HKDF-SHA256)*/); err == nil {
45		t.Fatal("newKEM(unsupported ID): got success, want err")
46	}
47}
48
49func TestKEMIDFromProtoUnsupportedID(t *testing.T) {
50	if _, err := kemIDFromProto(pb.HpkeKem_KEM_UNKNOWN); err == nil {
51		t.Fatal("kemIDFromProto(unsupported ID): got success, want err")
52	}
53}
54
55func TestNewKDF(t *testing.T) {
56	kdfID, err := kdfIDFromProto(pb.HpkeKdf_HKDF_SHA256)
57	if err != nil {
58		t.Fatal(err)
59	}
60	if kdfID != hkdfSHA256 {
61		t.Errorf("kdfID: got %d, want %d", kdfID, hkdfSHA256)
62	}
63
64	kdf, err := newKDF(kdfID)
65	if err != nil {
66		t.Fatal(err)
67	}
68	if kdf.id() != hkdfSHA256 {
69		t.Errorf("id: got %d, want %d", kdf.id(), hkdfSHA256)
70	}
71}
72
73func TestNewKDFUnsupportedID(t *testing.T) {
74	if _, err := newKDF(0x0002 /*= HKDF-SHA384*/); err == nil {
75		t.Fatal("newKDF(unsupported ID): got success, want err")
76	}
77}
78
79func TestKDFIDFromProtoUnsupportedID(t *testing.T) {
80	if _, err := kdfIDFromProto(pb.HpkeKdf_KDF_UNKNOWN); err == nil {
81		t.Fatal("kdfIDFromProto(unsupported ID): got success, want err")
82	}
83}
84
85var aeads = []struct {
86	name  string
87	proto pb.HpkeAead
88	id    uint16
89}{
90	{"AES-128-GCM", pb.HpkeAead_AES_128_GCM, aes128GCM},
91	{"AES-256-GCM", pb.HpkeAead_AES_256_GCM, aes256GCM},
92	{"ChaCha20Poly1305", pb.HpkeAead_CHACHA20_POLY1305, chaCha20Poly1305},
93}
94
95func TestNewAEAD(t *testing.T) {
96	for _, a := range aeads {
97		t.Run(a.name, func(t *testing.T) {
98			aeadID, err := aeadIDFromProto(a.proto)
99			if err != nil {
100				t.Fatal(err)
101			}
102			if aeadID != a.id {
103				t.Errorf("aeadID: got %d, want %d", aeadID, a.id)
104			}
105
106			aead, err := newAEAD(aeadID)
107			if err != nil {
108				t.Fatal(err)
109			}
110			if aead.id() != a.id {
111				t.Errorf("id: got %d, want %d", aead.id(), a.id)
112			}
113		})
114	}
115}
116
117func TestNewAEADUnsupportedID(t *testing.T) {
118	if _, err := newAEAD(0xFFFF /*= Export-only*/); err == nil {
119		t.Fatal("newAEAD(unsupported ID): got success, want err")
120	}
121}
122
123func TestAEADIDFromProtoUnsupportedID(t *testing.T) {
124	if _, err := aeadIDFromProto(pb.HpkeAead_AEAD_UNKNOWN); err == nil {
125		t.Fatal("aeadIDFromProto(unsupported ID): got success, want err")
126	}
127}
128
129func TestNewPrimitivesFromProto(t *testing.T) {
130	for _, a := range aeads {
131		t.Run("", func(t *testing.T) {
132			params := &pb.HpkeParams{
133				Kem:  pb.HpkeKem_DHKEM_X25519_HKDF_SHA256,
134				Kdf:  pb.HpkeKdf_HKDF_SHA256,
135				Aead: a.proto,
136			}
137			kem, kdf, aead, err := newPrimitivesFromProto(params)
138			if err != nil {
139				t.Fatalf("newPrimitivesFromProto: %v", err)
140			}
141
142			if kem.id() != x25519HKDFSHA256 {
143				t.Errorf("kem.id: got %d, want %d", kem.id(), x25519HKDFSHA256)
144			}
145			if kdf.id() != hkdfSHA256 {
146				t.Errorf("kdf.id: got %d, want %d", kdf.id(), hkdfSHA256)
147			}
148			if aead.id() != a.id {
149				t.Errorf("aead.id: got %d, want %d", aead.id(), a.id)
150			}
151		})
152	}
153}
154
155func TestNewPrimitivesFromProtoUnsupportedID(t *testing.T) {
156	tests := []struct {
157		name   string
158		params *pb.HpkeParams
159	}{
160		{
161			"KEM",
162			&pb.HpkeParams{
163				Kem:  pb.HpkeKem_KEM_UNKNOWN,
164				Kdf:  pb.HpkeKdf_HKDF_SHA256,
165				Aead: pb.HpkeAead_AES_256_GCM,
166			},
167		},
168		{"KDF",
169			&pb.HpkeParams{
170				Kem:  pb.HpkeKem_DHKEM_X25519_HKDF_SHA256,
171				Kdf:  pb.HpkeKdf_KDF_UNKNOWN,
172				Aead: pb.HpkeAead_AES_256_GCM,
173			},
174		},
175		{"AEAD",
176			&pb.HpkeParams{
177				Kem:  pb.HpkeKem_DHKEM_X25519_HKDF_SHA256,
178				Kdf:  pb.HpkeKdf_HKDF_SHA256,
179				Aead: pb.HpkeAead_AEAD_UNKNOWN,
180			},
181		},
182	}
183
184	for _, test := range tests {
185		t.Run(test.name, func(t *testing.T) {
186			if _, _, _, err := newPrimitivesFromProto(test.params); err == nil {
187				t.Error("newPrimitivesFromProto: got success, want err")
188			}
189		})
190	}
191}
192