1// Copyright 2019 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package protocmp provides protobuf specific options for the 6// "github.com/google/go-cmp/cmp" package. 7// 8// The primary feature is the Transform option, which transform proto.Message 9// types into a Message map that is suitable for cmp to introspect upon. 10// All other options in this package must be used in conjunction with Transform. 11package protocmp 12 13import ( 14 "reflect" 15 "strconv" 16 17 "github.com/google/go-cmp/cmp" 18 19 "google.golang.org/protobuf/encoding/protowire" 20 "google.golang.org/protobuf/internal/genid" 21 "google.golang.org/protobuf/internal/msgfmt" 22 "google.golang.org/protobuf/proto" 23 "google.golang.org/protobuf/reflect/protoreflect" 24 "google.golang.org/protobuf/reflect/protoregistry" 25 "google.golang.org/protobuf/runtime/protoiface" 26 "google.golang.org/protobuf/runtime/protoimpl" 27) 28 29var ( 30 enumV2Type = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem() 31 messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem() 32 messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem() 33) 34 35// Enum is a dynamic representation of a protocol buffer enum that is 36// suitable for cmp.Equal and cmp.Diff to compare upon. 37type Enum struct { 38 num protoreflect.EnumNumber 39 ed protoreflect.EnumDescriptor 40} 41 42// Descriptor returns the enum descriptor. 43// It returns nil for a zero Enum value. 44func (e Enum) Descriptor() protoreflect.EnumDescriptor { 45 return e.ed 46} 47 48// Number returns the enum value as an integer. 49func (e Enum) Number() protoreflect.EnumNumber { 50 return e.num 51} 52 53// Equal reports whether e1 and e2 represent the same enum value. 54func (e1 Enum) Equal(e2 Enum) bool { 55 if e1.ed.FullName() != e2.ed.FullName() { 56 return false 57 } 58 return e1.num == e2.num 59} 60 61// String returns the name of the enum value if known (e.g., "ENUM_VALUE"), 62// otherwise it returns the formatted decimal enum number (e.g., "14"). 63func (e Enum) String() string { 64 if ev := e.ed.Values().ByNumber(e.num); ev != nil { 65 return string(ev.Name()) 66 } 67 return strconv.Itoa(int(e.num)) 68} 69 70const ( 71 // messageTypeKey indicates the protobuf message type. 72 // The value type is always messageMeta. 73 // From the public API, it presents itself as only the type, but the 74 // underlying data structure holds arbitrary metadata about the message. 75 messageTypeKey = "@type" 76 77 // messageInvalidKey indicates that the message is invalid. 78 // The value is always the boolean "true". 79 messageInvalidKey = "@invalid" 80) 81 82type messageMeta struct { 83 m proto.Message 84 md protoreflect.MessageDescriptor 85 xds map[string]protoreflect.ExtensionDescriptor 86} 87 88func (t messageMeta) String() string { 89 return string(t.md.FullName()) 90} 91 92func (t1 messageMeta) Equal(t2 messageMeta) bool { 93 return t1.md.FullName() == t2.md.FullName() 94} 95 96// Message is a dynamic representation of a protocol buffer message that is 97// suitable for cmp.Equal and cmp.Diff to directly operate upon. 98// 99// Every populated known field (excluding extension fields) is stored in the map 100// with the key being the short name of the field (e.g., "field_name") and 101// the value determined by the kind and cardinality of the field. 102// 103// Singular scalars are represented by the same Go type as protoreflect.Value, 104// singular messages are represented by the Message type, 105// singular enums are represented by the Enum type, 106// list fields are represented as a Go slice, and 107// map fields are represented as a Go map. 108// 109// Every populated extension field is stored in the map with the key being the 110// full name of the field surrounded by brackets (e.g., "[extension.full.name]") 111// and the value determined according to the same rules as known fields. 112// 113// Every unknown field is stored in the map with the key being the field number 114// encoded as a decimal string (e.g., "132") and the value being the raw bytes 115// of the encoded field (as the protoreflect.RawFields type). 116// 117// Message values must not be created by or mutated by users. 118type Message map[string]interface{} 119 120// Unwrap returns the original message value. 121// It returns nil if this Message was not constructed from another message. 122func (m Message) Unwrap() proto.Message { 123 mm, _ := m[messageTypeKey].(messageMeta) 124 return mm.m 125} 126 127// Descriptor return the message descriptor. 128// It returns nil for a zero Message value. 129func (m Message) Descriptor() protoreflect.MessageDescriptor { 130 mm, _ := m[messageTypeKey].(messageMeta) 131 return mm.md 132} 133 134// ProtoReflect returns a reflective view of m. 135// It only implements the read-only operations of protoreflect.Message. 136// Calling any mutating operations on m panics. 137func (m Message) ProtoReflect() protoreflect.Message { 138 return (reflectMessage)(m) 139} 140 141// ProtoMessage is a marker method from the legacy message interface. 142func (m Message) ProtoMessage() {} 143 144// Reset is the required Reset method from the legacy message interface. 145func (m Message) Reset() { 146 panic("invalid mutation of a read-only message") 147} 148 149// String returns a formatted string for the message. 150// It is intended for human debugging and has no guarantees about its 151// exact format or the stability of its output. 152func (m Message) String() string { 153 switch { 154 case m == nil: 155 return "<nil>" 156 case !m.ProtoReflect().IsValid(): 157 return "<invalid>" 158 default: 159 return msgfmt.Format(m) 160 } 161} 162 163type option struct{} 164 165// Transform returns a cmp.Option that converts each proto.Message to a Message. 166// The transformation does not mutate nor alias any converted messages. 167// 168// The google.protobuf.Any message is automatically unmarshaled such that the 169// "value" field is a Message representing the underlying message value 170// assuming it could be resolved and properly unmarshaled. 171// 172// This does not directly transform higher-order composite Go types. 173// For example, []*foopb.Message is not transformed into []Message, 174// but rather the individual message elements of the slice are transformed. 175// 176// Note that there are currently no custom options for Transform, 177// but the use of an unexported type keeps the future open. 178func Transform(...option) cmp.Option { 179 // addrType returns a pointer to t if t isn't a pointer or interface. 180 addrType := func(t reflect.Type) reflect.Type { 181 if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr { 182 return t 183 } 184 return reflect.PtrTo(t) 185 } 186 187 // TODO: Should this transform protoreflect.Enum types to Enum as well? 188 return cmp.FilterPath(func(p cmp.Path) bool { 189 ps := p.Last() 190 if isMessageType(addrType(ps.Type())) { 191 return true 192 } 193 194 // Check whether the concrete values of an interface both satisfy 195 // the Message interface. 196 if ps.Type().Kind() == reflect.Interface { 197 vx, vy := ps.Values() 198 if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() { 199 return false 200 } 201 return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type())) 202 } 203 204 return false 205 }, cmp.Transformer("protocmp.Transform", func(v interface{}) Message { 206 // For user convenience, shallow copy the message value if necessary 207 // in order for it to implement the message interface. 208 if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) { 209 pv := reflect.New(rv.Type()) 210 pv.Elem().Set(rv) 211 v = pv.Interface() 212 } 213 214 m := protoimpl.X.MessageOf(v) 215 switch { 216 case m == nil: 217 return nil 218 case !m.IsValid(): 219 return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true} 220 default: 221 return transformMessage(m) 222 } 223 })) 224} 225 226func isMessageType(t reflect.Type) bool { 227 // Avoid transforming the Message itself. 228 if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) { 229 return false 230 } 231 return t.Implements(messageV1Type) || t.Implements(messageV2Type) 232} 233 234func transformMessage(m protoreflect.Message) Message { 235 mx := Message{} 236 mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)} 237 238 // Handle known and extension fields. 239 m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { 240 s := fd.TextName() 241 if fd.IsExtension() { 242 mt.xds[s] = fd 243 } 244 switch { 245 case fd.IsList(): 246 mx[s] = transformList(fd, v.List()) 247 case fd.IsMap(): 248 mx[s] = transformMap(fd, v.Map()) 249 default: 250 mx[s] = transformSingular(fd, v) 251 } 252 return true 253 }) 254 255 // Handle unknown fields. 256 for b := m.GetUnknown(); len(b) > 0; { 257 num, _, n := protowire.ConsumeField(b) 258 s := strconv.Itoa(int(num)) 259 b2, _ := mx[s].(protoreflect.RawFields) 260 mx[s] = append(b2, b[:n]...) 261 b = b[n:] 262 } 263 264 // Expand Any messages. 265 if mt.md.FullName() == genid.Any_message_fullname { 266 // TODO: Expose Transform option to specify a custom resolver? 267 s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string) 268 b, _ := mx[string(genid.Any_Value_field_name)].([]byte) 269 mt, err := protoregistry.GlobalTypes.FindMessageByURL(s) 270 if mt != nil && err == nil { 271 m2 := mt.New() 272 err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface()) 273 if err == nil { 274 mx[string(genid.Any_Value_field_name)] = transformMessage(m2) 275 } 276 } 277 } 278 279 mx[messageTypeKey] = mt 280 return mx 281} 282 283func transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} { 284 t := protoKindToGoType(fd.Kind()) 285 rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len()) 286 for i := 0; i < lv.Len(); i++ { 287 v := reflect.ValueOf(transformSingular(fd, lv.Get(i))) 288 rv.Index(i).Set(v) 289 } 290 return rv.Interface() 291} 292 293func transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} { 294 kfd := fd.MapKey() 295 vfd := fd.MapValue() 296 kt := protoKindToGoType(kfd.Kind()) 297 vt := protoKindToGoType(vfd.Kind()) 298 rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len()) 299 mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { 300 kv := reflect.ValueOf(transformSingular(kfd, k.Value())) 301 vv := reflect.ValueOf(transformSingular(vfd, v)) 302 rv.SetMapIndex(kv, vv) 303 return true 304 }) 305 return rv.Interface() 306} 307 308func transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} { 309 switch fd.Kind() { 310 case protoreflect.EnumKind: 311 return Enum{num: v.Enum(), ed: fd.Enum()} 312 case protoreflect.MessageKind, protoreflect.GroupKind: 313 return transformMessage(v.Message()) 314 case protoreflect.BytesKind: 315 // The protoreflect API does not specify whether an empty bytes is 316 // guaranteed to be nil or not. Always return non-nil bytes to avoid 317 // leaking information about the concrete proto.Message implementation. 318 if len(v.Bytes()) == 0 { 319 return []byte{} 320 } 321 return v.Bytes() 322 default: 323 return v.Interface() 324 } 325} 326 327func protoKindToGoType(k protoreflect.Kind) reflect.Type { 328 switch k { 329 case protoreflect.BoolKind: 330 return reflect.TypeOf(bool(false)) 331 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 332 return reflect.TypeOf(int32(0)) 333 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 334 return reflect.TypeOf(int64(0)) 335 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 336 return reflect.TypeOf(uint32(0)) 337 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 338 return reflect.TypeOf(uint64(0)) 339 case protoreflect.FloatKind: 340 return reflect.TypeOf(float32(0)) 341 case protoreflect.DoubleKind: 342 return reflect.TypeOf(float64(0)) 343 case protoreflect.StringKind: 344 return reflect.TypeOf(string("")) 345 case protoreflect.BytesKind: 346 return reflect.TypeOf([]byte(nil)) 347 case protoreflect.EnumKind: 348 return reflect.TypeOf(Enum{}) 349 case protoreflect.MessageKind, protoreflect.GroupKind: 350 return reflect.TypeOf(Message{}) 351 default: 352 panic("invalid kind") 353 } 354} 355