xref: /aosp_15_r20/external/tink/go/internal/aead/aes_gcm_insecure_iv_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package aead_test
18
19import (
20	"bytes"
21	"fmt"
22	"math/rand"
23	"testing"
24
25	"github.com/google/tink/go/internal/aead"
26	"github.com/google/tink/go/subtle/random"
27	"github.com/google/tink/go/testutil"
28)
29
// aesKeySizes lists the AES key lengths, in bytes, that the tests in this file
// iterate over.
var aesKeySizes = []uint32{
	16, // AES-128
	32, // AES-256
}
34
35func TestAESGCMInsecureIVCiphertextSize(t *testing.T) {
36	for _, keySize := range aesKeySizes {
37		for _, prependIV := range []bool{true, false} {
38			t.Run(fmt.Sprintf("keySize-%d/prependIV-%t", keySize, prependIV), func(t *testing.T) {
39				key := random.GetRandomBytes(uint32(keySize))
40				a, err := aead.NewAESGCMInsecureIV(key, prependIV)
41				if err != nil {
42					t.Fatalf("NewAESGCMInsecureIV: got err %q, want success", err)
43				}
44				iv := random.GetRandomBytes(aead.AESGCMIVSize)
45				pt := random.GetRandomBytes(32)
46				ad := random.GetRandomBytes(32)
47
48				ct, err := a.Encrypt(iv, pt, ad)
49				if err != nil {
50					t.Fatalf("Encrypt: got err %q, want success", err)
51				}
52
53				wantSize := len(pt) + aead.AESGCMTagSize
54				if prependIV {
55					wantSize += aead.AESGCMIVSize
56				}
57				if len(ct) != wantSize {
58					t.Errorf("unexpected ciphertext length: got %d, want %d", len(ct), wantSize)
59				}
60			})
61		}
62	}
63}
64
65func TestAESGCMInsecureIVKeySize(t *testing.T) {
66	for _, keySize := range aesKeySizes {
67		for _, prependIV := range []bool{true, false} {
68			t.Run(fmt.Sprintf("keySize-%d/prependIV-%t", keySize, prependIV), func(t *testing.T) {
69				if _, err := aead.NewAESGCMInsecureIV(make([]byte, keySize), prependIV); err != nil {
70					t.Errorf("NewAESGCMInsecureIV: got err %q, want success", err)
71				}
72				if _, err := aead.NewAESGCMInsecureIV(make([]byte, keySize+1), prependIV); err == nil {
73					t.Error("NewAESGCMInsecureIV: got success, want err")
74				}
75				if _, err := aead.NewAESGCMInsecureIV(make([]byte, keySize-1), prependIV); err == nil {
76					t.Error("NewAESGCMInsecureIV: got success, want err")
77				}
78			})
79		}
80	}
81}
82
83func TestAESGCMInsecureIVMismatchedIV(t *testing.T) {
84	for _, keySize := range aesKeySizes {
85		t.Run(fmt.Sprintf("keySize-%d", keySize), func(t *testing.T) {
86			key := random.GetRandomBytes(uint32(keySize))
87			a, err := aead.NewAESGCMInsecureIV(key, true /*=prependIV*/)
88			if err != nil {
89				t.Fatalf("NewAESGCMInsecureIV: got err %q, want success", err)
90			}
91			iv := random.GetRandomBytes(aead.AESGCMIVSize)
92			pt := random.GetRandomBytes(32)
93			ad := random.GetRandomBytes(32)
94
95			ct, err := a.Encrypt(iv, pt, ad)
96			if err != nil {
97				t.Fatalf("Encrypt: got err %q, want success", err)
98			}
99
100			newIV := iv
101			randByte, randBit := rand.Intn(aead.AESGCMIVSize), rand.Intn(8)
102			newIV[randByte] ^= (1 << uint8(randBit))
103
104			if _, err := a.Decrypt(newIV, ct, ad); err == nil {
105				t.Error("Decrypt with wrong iv argument: want err, got success")
106			}
107			ctPrefixedWithNewIV := append(newIV, ct[aead.AESGCMIVSize:]...)
108			if _, err := a.Decrypt(iv, ctPrefixedWithNewIV, ad); err == nil {
109				t.Error("Decrypt with ct prefixed with wrong IV: want err, got success")
110			}
111		})
112	}
113}
114
115func TestAESGCMInsecureIV(t *testing.T) {
116	for _, keySize := range aesKeySizes {
117		for _, prependIV := range []bool{true, false} {
118			for ptSize := 0; ptSize < 75; ptSize++ {
119				t.Run(fmt.Sprintf("keySize-%d/prependIV-%t/ptSize-%d", keySize, prependIV, ptSize), func(t *testing.T) {
120					key := random.GetRandomBytes(uint32(keySize))
121					a, err := aead.NewAESGCMInsecureIV(key, prependIV)
122					if err != nil {
123						t.Fatalf("NewAESGCMInsecureIV: got err %q, want success", err)
124					}
125					iv := random.GetRandomBytes(aead.AESGCMIVSize)
126					pt := random.GetRandomBytes(uint32(ptSize))
127					ad := random.GetRandomBytes(uint32(5))
128
129					ct, err := a.Encrypt(iv, pt, ad)
130					if err != nil {
131						t.Fatalf("Encrypt: got err %q, want success", err)
132					}
133
134					got, err := a.Decrypt(iv, ct, ad)
135					if err != nil {
136						t.Fatalf("Decrypt: got err %q, want success", err)
137					}
138					if !bytes.Equal(got, pt) {
139						t.Errorf("Decrypt: got %x, want %x", got, pt)
140					}
141				})
142			}
143		}
144	}
145}
146
147func TestAESGCMInsecureIVLongPlaintext(t *testing.T) {
148	for _, keySize := range aesKeySizes {
149		for _, prependIV := range []bool{true, false} {
150			ptSize := 16
151			for ptSize <= 1<<24 {
152				t.Run(fmt.Sprintf("keySize-%d/prependIV-%t/ptSize-%d", keySize, prependIV, ptSize), func(t *testing.T) {
153					key := random.GetRandomBytes(uint32(keySize))
154					a, err := aead.NewAESGCMInsecureIV(key, prependIV)
155					if err != nil {
156						t.Fatalf("NewAESGCMInsecureIV: got err %q, want success", err)
157					}
158					iv := random.GetRandomBytes(aead.AESGCMIVSize)
159					pt := random.GetRandomBytes(uint32(ptSize))
160					ad := random.GetRandomBytes(uint32(ptSize / 3))
161
162					ct, err := a.Encrypt(iv, pt, ad)
163					if err != nil {
164						t.Fatalf("Encrypt: got err %q, want success", err)
165					}
166
167					got, err := a.Decrypt(iv, ct, ad)
168					if err != nil {
169						t.Fatalf("Decrypt: got err %q, want success", err)
170					}
171					if !bytes.Equal(got, pt) {
172						t.Errorf("Decrypt: got %x, want %x", got, pt)
173					}
174				})
175				ptSize += 5 * ptSize / 11
176			}
177		}
178	}
179}
180
181func TestAESGCMInsecureIVModifyCiphertext(t *testing.T) {
182	key := random.GetRandomBytes(16)
183	for _, prependIV := range []bool{true, false} {
184		t.Run(fmt.Sprintf("prependIV-%t", prependIV), func(t *testing.T) {
185			a, err := aead.NewAESGCMInsecureIV(key, prependIV)
186			if err != nil {
187				t.Fatalf("NewAESGCMInsecureIV: got err %q, want success", err)
188			}
189			iv := random.GetRandomBytes(aead.AESGCMIVSize)
190			pt := random.GetRandomBytes(32)
191			ad := random.GetRandomBytes(33)
192			ct, err := a.Encrypt(iv, pt, ad)
193			if err != nil {
194				t.Fatalf("Encrypt: got err %q, want success", err)
195			}
196
197			// Flip bits.
198			for i := 0; i < len(ct); i++ {
199				tmpCT := ct[i]
200				for j := 0; j < 8; j++ {
201					ct[i] ^= 1 << uint8(j)
202					tmpIV := iv
203					if prependIV {
204						tmpIV = ct[:aead.AESGCMIVSize]
205					}
206					if _, err := a.Decrypt(tmpIV, ct, ad); err == nil {
207						t.Errorf("ciphertext with flipped byte %d, bit %d: expected err, got success", i, j)
208					}
209					ct[i] = tmpCT
210				}
211			}
212
213			// Truncate ciphertext.
214			for i := 1; i < len(ct); i++ {
215				if _, err := a.Decrypt(iv, ct[:i], ad); err == nil {
216					t.Errorf("ciphertext truncated to byte %d: expected err, got success", i)
217				}
218			}
219
220			// Modify associated data.
221			for i := 0; i < len(ad); i++ {
222				tmp := ad[i]
223				for j := 0; j < 8; j++ {
224					ad[i] ^= 1 << uint8(j)
225					if _, err := a.Decrypt(iv, ct, ad); err == nil {
226						t.Errorf("associated data with flipped byte %d, bit %d: expected err, got success", i, j)
227					}
228					ad[i] = tmp
229				}
230			}
231		})
232	}
233}
234
235func TestAESGCMInsecureIVWycheproofVectors(t *testing.T) {
236	testutil.SkipTestIfTestSrcDirIsNotSet(t)
237
238	suite := new(AEADSuite)
239	if err := testutil.PopulateSuite(suite, "aes_gcm_test.json"); err != nil {
240		t.Fatalf("failed to populate suite: %s", err)
241	}
242	for _, group := range suite.TestGroups {
243		if err := aead.ValidateAESKeySize(group.KeySize / 8); err != nil {
244			continue
245		}
246		if group.IVSize != aead.AESGCMIVSize*8 {
247			continue
248		}
249		for _, tc := range group.Tests {
250			name := fmt.Sprintf("%s-%s(%d,%d):Case-%d", suite.Algorithm, group.Type, group.KeySize, group.TagSize, tc.CaseID)
251			t.Run(name, func(t *testing.T) {
252				a, err := aead.NewAESGCMInsecureIV(tc.Key, false /*=prependIV*/)
253				if err != nil {
254					t.Fatalf("NewAESGCMInsecureIV: got err %q, want success", err)
255				}
256
257				var combinedCT []byte
258				combinedCT = append(combinedCT, tc.CT...)
259				combinedCT = append(combinedCT, tc.Tag...)
260
261				got, err := a.Decrypt(tc.IV, combinedCT, tc.AD)
262				if err != nil {
263					if tc.Result == "valid" {
264						t.Errorf("Decrypt: got err %q, want success", err)
265					}
266				} else {
267					if tc.Result == "invalid" {
268						t.Error("Decrypt: got success, want error")
269					}
270					if !bytes.Equal(got, tc.Message) {
271						t.Errorf("Decrypt: got %x, want %x", got, tc.Message)
272					}
273				}
274			})
275		}
276	}
277}
278
279func TestPreallocatedCiphertextMemoryIsExact(t *testing.T) {
280	key := random.GetRandomBytes(16)
281	a, err := aead.NewAESGCMInsecureIV(key, true /*=prependIV*/)
282	if err != nil {
283		t.Fatalf("aead.NewAESGCMInsecureIV() err = %v, want nil", err)
284	}
285	iv := random.GetRandomBytes(aead.AESGCMIVSize)
286	plaintext := random.GetRandomBytes(13)
287	associatedData := random.GetRandomBytes(17)
288
289	ciphertext, err := a.Encrypt(iv, plaintext, associatedData)
290	if err != nil {
291		t.Fatalf("a.Encrypt() err = %v, want nil", err)
292	}
293  // Encrypt() uses cipher.Overhead() to pre-allocate the memory needed store the ciphertext.
294	// For AES GCM, the size of the allocated memory should always be exact. If this check fails, the
295	// pre-allocated memory was too large or too small. If it was too small, the system had to
296	// re-allocate more memory, which is expensive and should be avoided.
297	if len(ciphertext) != cap(ciphertext) {
298		t.Errorf("want len(ciphertext) == cap(ciphertext), got %d != %d", len(ciphertext), cap(ciphertext))
299	}
300}
301