xref: /aosp_15_r20/external/tink/go/hybrid/internal/hpke/context.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package hpke
18
19import (
20	"errors"
21	"fmt"
22	"math/big"
23
24	pb "github.com/google/tink/go/proto/hpke_go_proto"
25)
26
27type context struct {
28	aead              aead
29	maxSequenceNumber *big.Int
30	sequenceNumber    *big.Int
31	key               []byte
32	baseNonce         []byte
33	encapsulatedKey   []byte
34}
35
// newSenderContext creates the HPKE sender context as per KeySchedule()
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1-10.
//
// It encapsulates a fresh shared secret to recipientPubKey using kem, then
// derives the AEAD key and base nonce from that secret and info via
// createContext. The encapsulated key produced by kem is stored in the
// returned context; it is the value a recipient passes to
// newRecipientContext.
func newSenderContext(recipientPubKey *pb.HpkePublicKey, kem kem, kdf kdf, aead aead, info []byte) (*context, error) {
	pubKey := recipientPubKey.GetPublicKey()
	// len == 0 rejects both an unset (nil) key and an explicitly empty one,
	// which is what the error message describes.
	if len(pubKey) == 0 {
		return nil, errors.New("HpkePublicKey has an empty PublicKey")
	}
	sharedSecret, encapsulatedKey, err := kem.encapsulate(pubKey)
	if err != nil {
		return nil, fmt.Errorf("encapsulate: %w", err)
	}
	return createContext(encapsulatedKey, sharedSecret, kem, kdf, aead, info)
}
48
// newRecipientContext creates the HPKE recipient context as per KeySchedule()
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1-10.
//
// It recovers the shared secret from encapsulatedKey using recipientPrivKey
// and kem, then derives the AEAD key and base nonce from that secret and info
// via createContext. When kem, kdf, aead and info match the sender's, this
// yields the same key and nonce schedule the sender uses.
func newRecipientContext(encapsulatedKey []byte, recipientPrivKey *pb.HpkePrivateKey, kem kem, kdf kdf, aead aead, info []byte) (*context, error) {
	privKey := recipientPrivKey.GetPrivateKey()
	// len == 0 rejects both an unset (nil) key and an explicitly empty one,
	// which is what the error message describes.
	if len(privKey) == 0 {
		return nil, errors.New("HpkePrivateKey has an empty PrivateKey")
	}
	sharedSecret, err := kem.decapsulate(encapsulatedKey, privKey)
	if err != nil {
		return nil, fmt.Errorf("decapsulate: %w", err)
	}
	return createContext(encapsulatedKey, sharedSecret, kem, kdf, aead, info)
}
61
// createContext runs the base-mode HPKE key schedule, KeySchedule() in
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1-10, and returns a
// context whose sequence number starts at 0.
//
// sharedSecret is the KEM output (from encapsulate or decapsulate) and info is
// the caller-supplied info string. encapsulatedKey takes no part in the
// derivation; it is only stored in the returned context.
func createContext(encapsulatedKey []byte, sharedSecret []byte, kem kem, kdf kdf, aead aead, info []byte) (*context, error) {
	suiteID := hpkeSuiteID(kem.id(), kdf.id(), aead.id())
	// In base mode, both the pre-shared key (default_psk) and pre-shared key ID
	// (default_psk_id) are empty strings, see
	// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1.1-4.
	pskIDHash := kdf.labeledExtract(emptySalt, emptyIKM /*= default PSK ID*/, "psk_id_hash", suiteID)
	infoHash := kdf.labeledExtract(emptySalt, info, "info_hash", suiteID)
	keyScheduleCtx := keyScheduleContext(baseMode, pskIDHash, infoHash)
	// The shared secret is used as the salt and the (empty) default PSK as
	// the input keying material.
	secret := kdf.labeledExtract(sharedSecret, emptyIKM /*= default PSK*/, "secret", suiteID)

	key, err := kdf.labeledExpand(secret, keyScheduleCtx, "key", suiteID, aead.keyLength())
	if err != nil {
		return nil, fmt.Errorf("labeledExpand of key: %w", err)
	}
	baseNonce, err := kdf.labeledExpand(secret, keyScheduleCtx, "base_nonce", suiteID, aead.nonceLength())
	if err != nil {
		return nil, fmt.Errorf("labeledExpand of base nonce: %w", err)
	}

	return &context{
		aead:              aead,
		maxSequenceNumber: maxSequenceNumber(aead.nonceLength()),
		sequenceNumber:    big.NewInt(0),
		key:               key,
		baseNonce:         baseNonce,
		encapsulatedKey:   encapsulatedKey,
	}, nil
}
90
// maxSequenceNumber returns 2^(8*nonceLength) - 1, the largest value a
// nonceLength-byte sequence number can hold. Reaching it means the message
// limit is reached, per
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-11.
func maxSequenceNumber(nonceLength int) *big.Int {
	limit := big.NewInt(1)
	limit.Lsh(limit, uint(8*nonceLength))
	return limit.Sub(limit, big.NewInt(1))
}
100
// incrementSequenceNumber adds one to c.sequenceNumber. It returns an error
// once the counter has moved beyond c.maxSequenceNumber. The counter stays
// advanced even when the error is returned.
func (c *context) incrementSequenceNumber() error {
	step := big.NewInt(1)
	c.sequenceNumber.Add(c.sequenceNumber, step)
	if c.maxSequenceNumber.Cmp(c.sequenceNumber) < 0 {
		return errors.New("message limit reached")
	}
	return nil
}
108
// computeNonce derives the nonce for the current message as per
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-12: the big-endian
// bytes of c.sequenceNumber, left-padded with zeros to len(c.baseNonce), XORed
// with c.baseNonce. It returns an error if the sequence number needs more
// bytes than the base nonce has.
func (c *context) computeNonce() ([]byte, error) {
	seq := c.sequenceNumber.Bytes()
	size := len(c.baseNonce)
	if len(seq) > size {
		return nil, fmt.Errorf("sequence number length (%d) is larger than nonce length (%d)", len(seq), size)
	}

	out := make([]byte, size)
	// Right-align the big-endian counter; the leading bytes stay zero.
	copy(out[size-len(seq):], seq)
	for i := range out {
		out[i] ^= c.baseNonce[i]
	}
	return out, nil
}
129
// seal allows the sender's context to encrypt plaintext with associatedData,
// defined as ContextS.Seal in
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-7.
//
// The nonce comes from the current sequence number. The sequence number is
// advanced only after encryption succeeds. If advancing it hits the message
// limit, the ciphertext is not returned and an error is reported instead.
func (c *context) seal(plaintext, associatedData []byte) ([]byte, error) {
	nonce, err := c.computeNonce()
	if err != nil {
		return nil, fmt.Errorf("computeNonce: %w", err)
	}
	ciphertext, err := c.aead.seal(c.key, nonce, plaintext, associatedData)
	if err != nil {
		return nil, fmt.Errorf("seal: %w", err)
	}
	if err := c.incrementSequenceNumber(); err != nil {
		return nil, err
	}
	return ciphertext, nil
}
147
// open allows the receiver's context to decrypt ciphertext with
// associatedData, defined as ContextR.Open in
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-9.
//
// The nonce comes from the current sequence number. The sequence number is
// advanced only when decryption succeeds, so a rejected ciphertext does not
// consume a sequence number. If advancing it hits the message limit, the
// plaintext is not returned and an error is reported instead.
func (c *context) open(ciphertext, associatedData []byte) ([]byte, error) {
	nonce, err := c.computeNonce()
	if err != nil {
		return nil, fmt.Errorf("computeNonce: %w", err)
	}
	plaintext, err := c.aead.open(c.key, nonce, ciphertext, associatedData)
	if err != nil {
		return nil, fmt.Errorf("open: %w", err)
	}
	if err := c.incrementSequenceNumber(); err != nil {
		return nil, err
	}
	return plaintext, nil
}
165