1// Copyright 2019 Google Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package hcvault 18 19import ( 20 "encoding/base64" 21 "errors" 22 "net/url" 23 "strings" 24 25 "github.com/google/tink/go/tink" 26 "github.com/hashicorp/vault/api" 27) 28 29// vaultAEAD represents a HashiCorp Vault service to a particular URI. 30type vaultAEAD struct { 31 encKeyPath string 32 decKeyPath string 33 client *api.Logical 34} 35 36var _ tink.AEAD = (*vaultAEAD)(nil) 37 38const ( 39 encryptSegment = "encrypt" 40 decryptSegment = "decrypt" 41) 42 43// newHCVaultAEAD returns a new HashiCorp Vault service. 44func newHCVaultAEAD(keyURI string, client *api.Logical) (tink.AEAD, error) { 45 encKeyPath, decKeyPath, err := getEndpointPaths(keyURI) 46 if err != nil { 47 return nil, err 48 } 49 return &vaultAEAD{ 50 encKeyPath: encKeyPath, 51 decKeyPath: decKeyPath, 52 client: client, 53 }, nil 54} 55 56// Encrypt encrypts the plaintext data using a key stored in HashiCorp Vault. 57// associatedData parameter is used as a context for key derivation, more 58// information available https://www.vaultproject.io/docs/secrets/transit/index.html. 59func (a *vaultAEAD) Encrypt(plaintext, associatedData []byte) ([]byte, error) { 60 // Create an encryption request map according to Vault REST API: 61 // https://www.vaultproject.io/api/secret/transit/index.html#encrypt-data. 62 req := map[string]interface{}{ 63 "plaintext": base64.StdEncoding.EncodeToString(plaintext), 64 "context": base64.StdEncoding.EncodeToString(associatedData), 65 } 66 secret, err := a.client.Write(a.encKeyPath, req) 67 if err != nil { 68 return nil, err 69 } 70 ciphertext := secret.Data["ciphertext"].(string) 71 return []byte(ciphertext), nil 72} 73 74// Decrypt decrypts the ciphertext using a key stored in HashiCorp Vault. 75// associatedData parameter is used as a context for key derivation, more 76// information available https://www.vaultproject.io/docs/secrets/transit/index.html. 77func (a *vaultAEAD) Decrypt(ciphertext, associatedData []byte) ([]byte, error) { 78 // Create a decryption request map according to Vault REST API: 79 // https://www.vaultproject.io/api/secret/transit/index.html#decrypt-data. 80 req := map[string]interface{}{ 81 "ciphertext": string(ciphertext), 82 "context": base64.StdEncoding.EncodeToString(associatedData), 83 } 84 secret, err := a.client.Write(a.decKeyPath, req) 85 if err != nil { 86 return nil, err 87 } 88 plaintext64 := secret.Data["plaintext"].(string) 89 plaintext, err := base64.StdEncoding.DecodeString(plaintext64) 90 if err != nil { 91 return nil, err 92 } 93 return plaintext, nil 94} 95 96// getEndpointPaths transforms keyURL into the Vault transit encrypt and decrypt 97// paths. The keyURL is expected to end in "/{mount}/keys/{keyName}". For 98// example, the keyURL "hcvault:///transit/keys/key-foo" will be transformed to 99// "transit/encrypt/key-foo" and "transit/decrypt/key-foo", and 100// "hcvault://my-vault.example.com/teams/billing/service/cipher/keys/key-bar" 101// will be transformed into 102// "hcvault://my-vault.example.com/teams/billing/service/cipher/encrypt/key-bar" 103// and 104// "hcvault://my-vault.example.com/teams/billing/service/cipher/decrypt/key-bar". 105func getEndpointPaths(keyURL string) (encryptPath, decryptPath string, err error) { 106 u, err := url.Parse(keyURL) 107 if err != nil || u.Scheme != "hcvault" { 108 return "", "", errors.New("malformed keyURL") 109 } 110 111 parts := strings.Split(u.EscapedPath(), "/") 112 length := len(parts) 113 if length < 4 || parts[length-2] != "keys" { 114 return "", "", errors.New("malformed keyURL") 115 } 116 117 parts[length-2] = encryptSegment 118 encryptPath = strings.Join(parts[1:], "/") 119 parts[length-2] = decryptSegment 120 decryptPath = strings.Join(parts[1:], "/") 121 return encryptPath, decryptPath, nil 122} 123