xref: /aosp_15_r20/external/tink/go/integration/hcvault/hcvault_aead.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2019 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package hcvault
18
19import (
20	"encoding/base64"
21	"errors"
22	"net/url"
23	"strings"
24
25	"github.com/google/tink/go/tink"
26	"github.com/hashicorp/vault/api"
27)
28
29// vaultAEAD represents a HashiCorp Vault service to a particular URI.
30type vaultAEAD struct {
31	encKeyPath string
32	decKeyPath string
33	client     *api.Logical
34}
35
36var _ tink.AEAD = (*vaultAEAD)(nil)
37
38const (
39	encryptSegment = "encrypt"
40	decryptSegment = "decrypt"
41)
42
43// newHCVaultAEAD returns a new HashiCorp Vault service.
44func newHCVaultAEAD(keyURI string, client *api.Logical) (tink.AEAD, error) {
45	encKeyPath, decKeyPath, err := getEndpointPaths(keyURI)
46	if err != nil {
47		return nil, err
48	}
49	return &vaultAEAD{
50		encKeyPath: encKeyPath,
51		decKeyPath: decKeyPath,
52		client:     client,
53	}, nil
54}
55
56// Encrypt encrypts the plaintext data using a key stored in HashiCorp Vault.
57// associatedData parameter is used as a context for key derivation, more
58// information available https://www.vaultproject.io/docs/secrets/transit/index.html.
59func (a *vaultAEAD) Encrypt(plaintext, associatedData []byte) ([]byte, error) {
60	// Create an encryption request map according to Vault REST API:
61	// https://www.vaultproject.io/api/secret/transit/index.html#encrypt-data.
62	req := map[string]interface{}{
63		"plaintext": base64.StdEncoding.EncodeToString(plaintext),
64		"context":   base64.StdEncoding.EncodeToString(associatedData),
65	}
66	secret, err := a.client.Write(a.encKeyPath, req)
67	if err != nil {
68		return nil, err
69	}
70	ciphertext := secret.Data["ciphertext"].(string)
71	return []byte(ciphertext), nil
72}
73
74// Decrypt decrypts the ciphertext using a key stored in HashiCorp Vault.
75// associatedData parameter is used as a context for key derivation, more
76// information available https://www.vaultproject.io/docs/secrets/transit/index.html.
77func (a *vaultAEAD) Decrypt(ciphertext, associatedData []byte) ([]byte, error) {
78	// Create a decryption request map according to Vault REST API:
79	// https://www.vaultproject.io/api/secret/transit/index.html#decrypt-data.
80	req := map[string]interface{}{
81		"ciphertext": string(ciphertext),
82		"context":    base64.StdEncoding.EncodeToString(associatedData),
83	}
84	secret, err := a.client.Write(a.decKeyPath, req)
85	if err != nil {
86		return nil, err
87	}
88	plaintext64 := secret.Data["plaintext"].(string)
89	plaintext, err := base64.StdEncoding.DecodeString(plaintext64)
90	if err != nil {
91		return nil, err
92	}
93	return plaintext, nil
94}
95
96// getEndpointPaths transforms keyURL into the Vault transit encrypt and decrypt
97// paths. The keyURL is expected to end in "/{mount}/keys/{keyName}". For
98// example, the keyURL "hcvault:///transit/keys/key-foo" will be transformed to
99// "transit/encrypt/key-foo" and "transit/decrypt/key-foo", and
100// "hcvault://my-vault.example.com/teams/billing/service/cipher/keys/key-bar"
101// will be transformed into
102// "hcvault://my-vault.example.com/teams/billing/service/cipher/encrypt/key-bar"
103// and
104// "hcvault://my-vault.example.com/teams/billing/service/cipher/decrypt/key-bar".
105func getEndpointPaths(keyURL string) (encryptPath, decryptPath string, err error) {
106	u, err := url.Parse(keyURL)
107	if err != nil || u.Scheme != "hcvault" {
108		return "", "", errors.New("malformed keyURL")
109	}
110
111	parts := strings.Split(u.EscapedPath(), "/")
112	length := len(parts)
113	if length < 4 || parts[length-2] != "keys" {
114		return "", "", errors.New("malformed keyURL")
115	}
116
117	parts[length-2] = encryptSegment
118	encryptPath = strings.Join(parts[1:], "/")
119	parts[length-2] = decryptSegment
120	decryptPath = strings.Join(parts[1:], "/")
121	return encryptPath, decryptPath, nil
122}
123