// Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //////////////////////////////////////////////////////////////////////////////// package hpke import ( "errors" "fmt" "math/big" pb "github.com/google/tink/go/proto/hpke_go_proto" ) type context struct { aead aead maxSequenceNumber *big.Int sequenceNumber *big.Int key []byte baseNonce []byte encapsulatedKey []byte } // newSenderContext creates the HPKE sender context as per KeySchedule() // https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1-10. func newSenderContext(recipientPubKey *pb.HpkePublicKey, kem kem, kdf kdf, aead aead, info []byte) (*context, error) { if recipientPubKey.GetPublicKey() == nil { return nil, errors.New("HpkePublicKey has an empty PublicKey") } sharedSecret, encapsulatedKey, err := kem.encapsulate(recipientPubKey.GetPublicKey()) if err != nil { return nil, fmt.Errorf("encapsulate: %v", err) } return createContext(encapsulatedKey, sharedSecret, kem, kdf, aead, info) } // newRecipientContext creates the HPKE recipient context as per KeySchedule() // https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1-10. func newRecipientContext(encapsulatedKey []byte, recipientPrivKey *pb.HpkePrivateKey, kem kem, kdf kdf, aead aead, info []byte) (*context, error) { if recipientPrivKey.GetPrivateKey() == nil { return nil, errors.New("HpkePrivateKey has an empty PrivateKey") } sharedSecret, err := kem.decapsulate(encapsulatedKey, recipientPrivKey.GetPrivateKey()) if err != nil { return nil, fmt.Errorf("decapsulate: %v", err) } return createContext(encapsulatedKey, sharedSecret, kem, kdf, aead, info) } func createContext(encapsulatedKey []byte, sharedSecret []byte, kem kem, kdf kdf, aead aead, info []byte) (*context, error) { suiteID := hpkeSuiteID(kem.id(), kdf.id(), aead.id()) // In base mode, both the pre-shared key (default_psk) and pre-shared key ID // (default_psk_id) are empty strings, see // https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1.1-4. pskIDHash := kdf.labeledExtract(emptySalt, emptyIKM /*= default PSK ID*/, "psk_id_hash", suiteID) infoHash := kdf.labeledExtract(emptySalt, info, "info_hash", suiteID) keyScheduleCtx := keyScheduleContext(baseMode, pskIDHash, infoHash) secret := kdf.labeledExtract(sharedSecret, emptyIKM /*= default PSK*/, "secret", suiteID) key, err := kdf.labeledExpand(secret, keyScheduleCtx, "key", suiteID, aead.keyLength()) if err != nil { return nil, fmt.Errorf("labeledExpand of key: %v", err) } baseNonce, err := kdf.labeledExpand(secret, keyScheduleCtx, "base_nonce", suiteID, aead.nonceLength()) if err != nil { return nil, fmt.Errorf("labeledExpand of base nonce: %v", err) } return &context{ aead: aead, maxSequenceNumber: maxSequenceNumber(aead.nonceLength()), sequenceNumber: big.NewInt(0), key: key, baseNonce: baseNonce, encapsulatedKey: encapsulatedKey, }, nil } // maxSequenceNumber returns the maximum sequence number indicating that the // message limit is reached, calculated as per // https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-11. func maxSequenceNumber(nonceLength int) *big.Int { res := new(big.Int) one := big.NewInt(1) res.Lsh(one, uint(8*nonceLength)).Sub(res, one) return res } func (c *context) incrementSequenceNumber() error { c.sequenceNumber.Add(c.sequenceNumber, big.NewInt(1)) if c.sequenceNumber.Cmp(c.maxSequenceNumber) > 0 { return errors.New("message limit reached") } return nil } // computeNonce computes the nonce as per // https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-12. func (c *context) computeNonce() ([]byte, error) { nonce := make([]byte, len(c.baseNonce)) // Write the big-endian c.sequenceNumber value at the end of nonce. sequenceNumber := c.sequenceNumber.Bytes() index := len(nonce) - len(sequenceNumber) if index < 0 { return nil, fmt.Errorf("sequence number length (%d) is larger than nonce length (%d)", len(sequenceNumber), len(nonce)) } copy(nonce[index:], sequenceNumber) // nonce XOR c.baseNonce. for i, b := range c.baseNonce { nonce[i] ^= b } return nonce, nil } // seal allows the sender's context to encrypt plaintext with associatedData, // defined as ContextS.Seal in // https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-7. func (c *context) seal(plaintext, associatedData []byte) ([]byte, error) { nonce, err := c.computeNonce() if err != nil { return nil, fmt.Errorf("computeNonce: %v", err) } ciphertext, err := c.aead.seal(c.key, nonce, plaintext, associatedData) if err != nil { return nil, fmt.Errorf("seal: %v", err) } if err := c.incrementSequenceNumber(); err != nil { return nil, err } return ciphertext, nil } // open allows the receiver's context to decrypt ciphertext with // associatedData, defined as ContextR.Open in // https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-9. func (c *context) open(ciphertext, associatedData []byte) ([]byte, error) { nonce, err := c.computeNonce() if err != nil { return nil, fmt.Errorf("computeNonce: %v", err) } plaintext, err := c.aead.open(c.key, nonce, ciphertext, associatedData) if err != nil { return nil, fmt.Errorf("open: %v", err) } if err := c.incrementSequenceNumber(); err != nil { return nil, err } return plaintext, nil }