xref: /aosp_15_r20/external/tink/go/hybrid/internal/hpke/context_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package hpke
18
19import (
20	"bytes"
21	"math/big"
22	"testing"
23
24	"github.com/google/tink/go/subtle"
25	pb "github.com/google/tink/go/proto/hpke_go_proto"
26)
27
28// TODO(b/201070904): Write tests using baseModeX25519HKDFSHA256Vectors.
29func TestContextSender(t *testing.T) {
30	id, vec := internetDraftVector(t)
31	kem, err := newKEM(id.kemID)
32	if err != nil {
33		t.Fatalf("newKEM(%d): err %q", id.kemID, err)
34	}
35	x25519KEMGeneratePrivateKey = func() ([]byte, error) {
36		return vec.senderPrivKey, nil
37	}
38	kdf, err := newKDF(id.kdfID)
39	if err != nil {
40		t.Fatalf("newKDF(%d): err %q", id.kdfID, err)
41	}
42	aead, err := newAEAD(id.aeadID)
43	if err != nil {
44		t.Fatalf("newAEAD(%d): err %q", id.aeadID, err)
45	}
46
47	recipientPubKey := &pb.HpkePublicKey{PublicKey: vec.recipientPubKey}
48	senderCtx, err := newSenderContext(recipientPubKey, kem, kdf, aead, vec.info)
49	if err != nil {
50		t.Fatalf("newSenderContext: err %q", err)
51	}
52
53	for _, enc := range vec.consecutiveEncryptions {
54		if got, want := senderCtx.sequenceNumber, enc.sequenceNumber; got.Cmp(want) != 0 {
55			t.Fatalf("sequence number: got %s, want %s", got.String(), want.String())
56		}
57		ct, err := senderCtx.seal(enc.plaintext, enc.associatedData)
58		if err != nil {
59			t.Fatal(err)
60		}
61		if !bytes.Equal(ct, enc.ciphertext) {
62			t.Errorf("ciphertext: got %x, want %x", ct, enc.ciphertext)
63		}
64	}
65
66	for _, enc := range vec.otherEncryptions {
67		senderCtx.sequenceNumber.Set(enc.sequenceNumber)
68		ct, err := senderCtx.seal(enc.plaintext, enc.associatedData)
69		if err != nil {
70			t.Fatal(err)
71		}
72		if !bytes.Equal(ct, enc.ciphertext) {
73			t.Errorf("ciphertext: got %x, want %x", ct, enc.ciphertext)
74		}
75	}
76
77	x25519KEMGeneratePrivateKey = subtle.GeneratePrivateKeyX25519
78}
79
80func TestContextRecipient(t *testing.T) {
81	id, vec := internetDraftVector(t)
82	kem, err := newKEM(id.kemID)
83	if err != nil {
84		t.Fatalf("newKEM(%d): err %q", id.kemID, err)
85	}
86	kdf, err := newKDF(id.kdfID)
87	if err != nil {
88		t.Fatalf("newKDF(%d): err %q", id.kdfID, err)
89	}
90	aead, err := newAEAD(id.aeadID)
91	if err != nil {
92		t.Fatalf("newAEAD(%d): err %q", id.aeadID, err)
93	}
94
95	recipientPrivKey := &pb.HpkePrivateKey{PrivateKey: vec.recipientPrivKey}
96	recipientCtx, err := newRecipientContext(vec.encapsulatedKey, recipientPrivKey, kem, kdf, aead, vec.info)
97	if err != nil {
98		t.Fatalf("newRecipientContext: err %q", err)
99	}
100
101	for _, enc := range vec.consecutiveEncryptions {
102		if got, want := recipientCtx.sequenceNumber, enc.sequenceNumber; got.Cmp(want) != 0 {
103			t.Fatalf("sequence number: got %s, want %s", got.String(), want.String())
104		}
105		pt, err := recipientCtx.open(enc.ciphertext, enc.associatedData)
106		if err != nil {
107			t.Fatal(err)
108		}
109		if !bytes.Equal(pt, enc.plaintext) {
110			t.Errorf("plaintext: got %x, want %x", pt, enc.plaintext)
111		}
112	}
113
114	for _, enc := range vec.otherEncryptions {
115		recipientCtx.sequenceNumber.Set(enc.sequenceNumber)
116		pt, err := recipientCtx.open(enc.ciphertext, enc.associatedData)
117		if err != nil {
118			t.Fatal(err)
119		}
120		if !bytes.Equal(pt, enc.plaintext) {
121			t.Errorf("plaintext: got %x, want %x", pt, enc.plaintext)
122		}
123	}
124}
125
126func TestContextMaxSequenceNumber(t *testing.T) {
127	got := maxSequenceNumber(12 /*=AESGCMIVSize*/)
128	want, ok := new(big.Int).SetString("79228162514264337593543950335", 10) // (1 << (8*12)) - 1
129	if !ok {
130		t.Fatalf("SetString(\"79228162514264337593543950335\", 10): got err, want success")
131	}
132	if got.Cmp(want) != 0 {
133		t.Errorf("maxSequenceNumber(12): got %s, want %s", got.String(), want.String())
134	}
135}
136
137func TestComputeNonce(t *testing.T) {
138	id, vec := internetDraftVector(t)
139	kem, err := newKEM(id.kemID)
140	if err != nil {
141		t.Fatalf("newKEM(%d): err %q", id.kemID, err)
142	}
143	kdf, err := newKDF(id.kdfID)
144	if err != nil {
145		t.Fatalf("newKDF(%d): err %q", id.kdfID, err)
146	}
147	aead, err := newAEAD(id.aeadID)
148	if err != nil {
149		t.Fatalf("newAEAD(%d): err %q", id.aeadID, err)
150	}
151
152	recipientPrivKey := &pb.HpkePrivateKey{PrivateKey: vec.recipientPrivKey}
153	ctx, err := newRecipientContext(vec.encapsulatedKey, recipientPrivKey, kem, kdf, aead, vec.info)
154	if err != nil {
155		t.Fatalf("newRecipientContext: err %q", err)
156	}
157
158	if !bytes.Equal(ctx.baseNonce, vec.baseNonce) {
159		t.Fatalf("base nonce: got %x, want %x", ctx.baseNonce, vec.baseNonce)
160	}
161
162	for _, enc := range vec.consecutiveEncryptions {
163		nonce, err := ctx.computeNonce()
164		if err != nil {
165			t.Fatal(err)
166		}
167		if !bytes.Equal(nonce, enc.nonce) {
168			t.Errorf("computeNonce: got %x, want %x", nonce, enc.nonce)
169		}
170		if err := ctx.incrementSequenceNumber(); err != nil {
171			t.Fatal(err)
172		}
173	}
174
175	for _, enc := range vec.otherEncryptions {
176		ctx.sequenceNumber.Set(enc.sequenceNumber)
177		nonce, err := ctx.computeNonce()
178		if err != nil {
179			t.Fatal(err)
180		}
181		if !bytes.Equal(nonce, enc.nonce) {
182			t.Errorf("computeNonce: got %x, want %x", nonce, enc.nonce)
183		}
184		if err := ctx.incrementSequenceNumber(); err != nil {
185			t.Fatal(err)
186		}
187	}
188}
189