1// Copyright 2020 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package subtle 18 19import ( 20 "bytes" 21 "crypto/elliptic" 22 "crypto/rand" 23 "errors" 24 "fmt" 25 "math/big" 26) 27 28// ECPublicKey represents a elliptic curve public key. 29type ECPublicKey struct { 30 elliptic.Curve 31 Point ECPoint 32} 33 34// ECPrivateKey represents a elliptic curve private key. 35type ECPrivateKey struct { 36 PublicKey ECPublicKey 37 D *big.Int 38} 39 40// GetECPrivateKey converts a stored private key to ECPrivateKey. 41func GetECPrivateKey(c elliptic.Curve, b []byte) *ECPrivateKey { 42 d := new(big.Int) 43 d.SetBytes(b) 44 45 x, y := c.Params().ScalarBaseMult(b) 46 pub := ECPublicKey{ 47 Curve: c, 48 Point: ECPoint{ 49 X: x, 50 Y: y, 51 }, 52 } 53 return &ECPrivateKey{ 54 PublicKey: pub, 55 D: d, 56 } 57 58} 59 60// ECPoint represents a point on the elliptic curve. 61type ECPoint struct { 62 X, Y *big.Int 63} 64 65func (p *ECPrivateKey) getParams() *elliptic.CurveParams { 66 return p.PublicKey.Curve.Params() 67} 68 69func getModulus(c elliptic.Curve) *big.Int { 70 return c.Params().P 71} 72 73func fieldSizeInBits(c elliptic.Curve) int { 74 t := big.NewInt(1) 75 r := t.Sub(getModulus(c), t) 76 return r.BitLen() 77} 78 79func fieldSizeInBytes(c elliptic.Curve) int { 80 return (fieldSizeInBits(c) + 7) / 8 81} 82 83func encodingSizeInBytes(c elliptic.Curve, p string) (int, error) { 84 cSize := fieldSizeInBytes(c) 85 switch p { 86 case "UNCOMPRESSED": 87 return 2*cSize + 1, nil 88 case "DO_NOT_USE_CRUNCHY_UNCOMPRESSED": 89 return 2 * cSize, nil 90 case "COMPRESSED": 91 return cSize + 1, nil 92 } 93 return 0, fmt.Errorf("invalid point format :%s", p) 94 95} 96 97// PointEncode encodes a point into the format specified. 98func PointEncode(c elliptic.Curve, pFormat string, pt ECPoint) ([]byte, error) { 99 if !c.IsOnCurve(pt.X, pt.Y) { 100 return nil, errors.New("curve check failed") 101 } 102 cSize := fieldSizeInBytes(c) 103 y := pt.Y.Bytes() 104 x := pt.X.Bytes() 105 switch pFormat { 106 case "UNCOMPRESSED": 107 encoded := make([]byte, 2*cSize+1) 108 copy(encoded[1+2*cSize-len(y):], y) 109 copy(encoded[1+cSize-len(x):], x) 110 encoded[0] = 4 111 return encoded, nil 112 case "DO_NOT_USE_CRUNCHY_UNCOMPRESSED": 113 encoded := make([]byte, 2*cSize) 114 if len(x) > cSize { 115 x = bytes.Replace(x, []byte("\x00"), []byte{}, -1) 116 } 117 if len(y) > cSize { 118 y = bytes.Replace(y, []byte("\x00"), []byte{}, -1) 119 } 120 copy(encoded[2*cSize-len(y):], y) 121 copy(encoded[cSize-len(x):], x) 122 return encoded, nil 123 case "COMPRESSED": 124 encoded := make([]byte, cSize+1) 125 copy(encoded[1+cSize-len(x):], x) 126 encoded[0] = 2 127 if pt.Y.Bit(0) > 0 { 128 encoded[0] = 3 129 } 130 return encoded, nil 131 } 132 return nil, errors.New("invalid point format") 133 134} 135 136// PointDecode decodes a encoded point to return an ECPoint 137func PointDecode(c elliptic.Curve, pFormat string, e []byte) (*ECPoint, error) { 138 cSize := fieldSizeInBytes(c) 139 x, y := new(big.Int), new(big.Int) 140 switch pFormat { 141 case "UNCOMPRESSED": 142 if len(e) != (2*cSize + 1) { 143 return nil, errors.New("invalid point size") 144 } 145 if e[0] != 4 { 146 return nil, errors.New("invalid point format") 147 } 148 x.SetBytes(e[1 : cSize+1]) 149 y.SetBytes(e[cSize+1:]) 150 if !c.IsOnCurve(x, y) { 151 return nil, errors.New("invalid point") 152 } 153 return &ECPoint{ 154 X: x, 155 Y: y, 156 }, nil 157 case "DO_NOT_USE_CRUNCHY_UNCOMPRESSED": 158 if len(e) != 2*cSize { 159 return nil, errors.New("invalid point size") 160 } 161 x.SetBytes(e[:cSize]) 162 y.SetBytes(e[cSize:]) 163 if !c.IsOnCurve(x, y) { 164 return nil, errors.New("invalid point") 165 } 166 return &ECPoint{ 167 X: x, 168 Y: y, 169 }, nil 170 case "COMPRESSED": 171 if len(e) != cSize+1 { 172 return nil, errors.New("compressed point has wrong length") 173 } 174 lsb := false 175 if e[0] == 2 { 176 lsb = false 177 } else if e[0] == 3 { 178 lsb = true 179 } else { 180 return nil, errors.New("invalid format") 181 } 182 x := new(big.Int) 183 x.SetBytes(e[1:]) 184 if (x.Sign() == -1) || (x.Cmp(c.Params().P) != -1) { 185 return nil, errors.New("x is out of range") 186 } 187 y := getY(x, lsb, c) 188 return &ECPoint{ 189 X: x, 190 Y: y, 191 }, nil 192 } 193 return nil, fmt.Errorf("invalid format: %s", pFormat) 194} 195 196func getY(x *big.Int, lsb bool, c elliptic.Curve) *big.Int { 197 // y² = x³ - 3x + b 198 x3 := new(big.Int).Mul(x, x) 199 x3.Mul(x3, x) 200 201 threeX := new(big.Int).Lsh(x, 1) 202 threeX.Add(threeX, x) 203 b := c.Params().B 204 p := c.Params().P 205 206 x3.Sub(x3, threeX) 207 x3.Add(x3, b) 208 x3.ModSqrt(x3, p) 209 e := uint(1) 210 if lsb { 211 e = 0 212 } 213 if e == x3.Bit(0) { 214 x3 := x3.Sub(p, x3) 215 x3.Mod(x3, p) 216 } 217 return x3 218} 219 220func validatePublicPoint(pub *ECPoint, priv *ECPrivateKey) error { 221 if priv.PublicKey.Curve.IsOnCurve(pub.X, pub.Y) { 222 return nil 223 } 224 return errors.New("invalid public key") 225} 226 227// ComputeSharedSecret is used to compute a shared secret using given private key and peer public key. 228func ComputeSharedSecret(pub *ECPoint, priv *ECPrivateKey) ([]byte, error) { 229 if err := validatePublicPoint(pub, priv); err != nil { 230 return nil, err 231 } 232 233 x, y := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes()) 234 235 if x == nil { 236 return nil, errors.New("shared key compute error") 237 } 238 // check if x,y are on the curve 239 if err := validatePublicPoint(&ECPoint{X: x, Y: y}, priv); err != nil { 240 return nil, errors.New("invalid shared key") 241 } 242 243 sharedSecret := make([]byte, maxSharedKeyLength(priv.PublicKey)) 244 return x.FillBytes(sharedSecret), nil 245} 246 247func maxSharedKeyLength(pub ECPublicKey) int { 248 return (pub.Curve.Params().BitSize + 7) / 8 249} 250 251// GenerateECDHKeyPair will create a new private key for a given curve. 252func GenerateECDHKeyPair(c elliptic.Curve) (*ECPrivateKey, error) { 253 p, x, y, err := elliptic.GenerateKey(c, rand.Reader) 254 if err != nil { 255 return nil, err 256 } 257 return &ECPrivateKey{ 258 PublicKey: ECPublicKey{ 259 Curve: c, 260 Point: ECPoint{ 261 X: x, 262 Y: y, 263 }, 264 }, 265 D: new(big.Int).SetBytes(p), 266 }, nil 267 268} 269 270// GetCurve returns the elliptic.Curve for a given standard curve name. 271func GetCurve(c string) (elliptic.Curve, error) { 272 switch c { 273 case "secp224r1", "NIST_P224", "P-224": 274 return elliptic.P224(), nil 275 case "secp256r1", "NIST_P256", "P-256", "EllipticCurveType_NIST_P256": 276 return elliptic.P256(), nil 277 case "secp384r1", "NIST_P384", "P-384", "EllipticCurveType_NIST_P384": 278 return elliptic.P384(), nil 279 case "secp521r1", "NIST_P521", "P-521", "EllipticCurveType_NIST_P521": 280 return elliptic.P521(), nil 281 default: 282 return nil, errors.New("unsupported curve") 283 } 284} 285