1// Copyright 2021 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package nullable 6 7import ( 8 "reflect" 9 "testing" 10 11 "github.com/google/go-cmp/cmp" 12 "google.golang.org/protobuf/proto" 13 "google.golang.org/protobuf/reflect/protoreflect" 14 "google.golang.org/protobuf/runtime/protoimpl" 15 "google.golang.org/protobuf/testing/protocmp" 16) 17 18func Test(t *testing.T) { 19 for _, mt := range []protoreflect.MessageType{ 20 protoimpl.X.ProtoMessageV2Of((*Proto2)(nil)).ProtoReflect().Type(), 21 protoimpl.X.ProtoMessageV2Of((*Proto3)(nil)).ProtoReflect().Type(), 22 } { 23 t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) { 24 testEmptyMessage(t, mt.Zero(), false) 25 testEmptyMessage(t, mt.New(), true) 26 //testMethods(t, mt) 27 }) 28 } 29} 30 31var methodTestProtos = []protoreflect.MessageType{ 32 protoimpl.X.ProtoMessageV2Of((*Methods)(nil)).ProtoReflect().Type(), 33} 34 35func TestMethods(t *testing.T) { 36 for _, mt := range methodTestProtos { 37 t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) { 38 testMethods(t, mt) 39 }) 40 } 41} 42 43func testMethods(t *testing.T, mt protoreflect.MessageType) { 44 m1 := mt.New() 45 populated := testPopulateMessage(t, m1, 2) 46 b, err := proto.Marshal(m1.Interface()) 47 if err != nil { 48 t.Errorf("proto.Marshal error: %v", err) 49 } 50 if populated && len(b) == 0 { 51 t.Errorf("len(proto.Marshal) = 0, want >0") 52 } 53 m2 := mt.New() 54 if err := proto.Unmarshal(b, m2.Interface()); err != nil { 55 t.Errorf("proto.Unmarshal error: %v", err) 56 } 57 if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" { 58 t.Errorf("message mismatch:\n%v", diff) 59 } 60 proto.Reset(m2.Interface()) 61 testEmptyMessage(t, m2, true) 62 proto.Merge(m2.Interface(), m1.Interface()) 63 if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" { 64 t.Errorf("message mismatch:\n%v", diff) 65 } 66 proto.Merge(mt.New().Interface(), mt.Zero().Interface()) 67} 68 69func testEmptyMessage(t *testing.T, m protoreflect.Message, wantValid bool) { 70 numFields := func(m protoreflect.Message) (n int) { 71 m.Range(func(protoreflect.FieldDescriptor, protoreflect.Value) bool { 72 n++ 73 return true 74 }) 75 return n 76 } 77 78 md := m.Descriptor() 79 if gotValid := m.IsValid(); gotValid != wantValid { 80 t.Errorf("%v.IsValid = %v, want %v", md.FullName(), gotValid, wantValid) 81 } 82 m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { 83 t.Errorf("%v.Range iterated over field %v, want no iteration", md.FullName(), fd.Name()) 84 return true 85 }) 86 fds := md.Fields() 87 for i := 0; i < fds.Len(); i++ { 88 fd := fds.Get(i) 89 if m.Has(fd) { 90 t.Errorf("%v.Has(%v) = true, want false", md.FullName(), fd.Name()) 91 } 92 v := m.Get(fd) 93 switch { 94 case fd.IsList(): 95 if n := v.List().Len(); n > 0 { 96 t.Errorf("%v.Get(%v).List().Len() = %v, want 0", md.FullName(), fd.Name(), n) 97 } 98 ls := m.NewField(fd).List() 99 if fd.Message() != nil { 100 if n := numFields(ls.NewElement().Message()); n > 0 { 101 t.Errorf("%v.NewField(%v).List().NewElement().Message().Len() = %v, want 0", md.FullName(), fd.Name(), n) 102 } 103 } 104 case fd.IsMap(): 105 if n := v.Map().Len(); n > 0 { 106 t.Errorf("%v.Get(%v).Map().Len() = %v, want 0", md.FullName(), fd.Name(), n) 107 } 108 ms := m.NewField(fd).Map() 109 if fd.MapValue().Message() != nil { 110 if n := numFields(ms.NewValue().Message()); n > 0 { 111 t.Errorf("%v.NewField(%v).Map().NewValue().Message().Len() = %v, want 0", md.FullName(), fd.Name(), n) 112 } 113 } 114 case fd.Message() != nil: 115 if n := numFields(v.Message()); n > 0 { 116 t.Errorf("%v.Get(%v).Message().Len() = %v, want 0", md.FullName(), fd.Name(), n) 117 } 118 if n := numFields(m.NewField(fd).Message()); n > 0 { 119 t.Errorf("%v.NewField(%v).Message().Len() = %v, want 0", md.FullName(), fd.Name(), n) 120 } 121 default: 122 if !reflect.DeepEqual(v.Interface(), fd.Default().Interface()) { 123 t.Errorf("%v.Get(%v) = %v, want %v", md.FullName(), fd.Name(), v, fd.Default()) 124 } 125 m.NewField(fd) // should not panic 126 } 127 } 128 ods := md.Oneofs() 129 for i := 0; i < ods.Len(); i++ { 130 od := ods.Get(i) 131 if fd := m.WhichOneof(od); fd != nil { 132 t.Errorf("%v.WhichOneof(%v) = %v, want nil", md.FullName(), od.Name(), fd.Name()) 133 } 134 } 135 if b := m.GetUnknown(); b != nil { 136 t.Errorf("%v.GetUnknown() = %v, want nil", md.FullName(), b) 137 } 138} 139 140func testPopulateMessage(t *testing.T, m protoreflect.Message, depth int) bool { 141 if depth == 0 { 142 return false 143 } 144 md := m.Descriptor() 145 fds := md.Fields() 146 var populatedMessage bool 147 for i := 0; i < fds.Len(); i++ { 148 populatedField := true 149 fd := fds.Get(i) 150 m.Clear(fd) // should not panic 151 switch { 152 case fd.IsList(): 153 ls := m.Mutable(fd).List() 154 if fd.Message() == nil { 155 ls.Append(scalarValue(fd.Kind())) 156 } else { 157 populatedField = testPopulateMessage(t, ls.AppendMutable().Message(), depth-1) 158 } 159 case fd.IsMap(): 160 ms := m.Mutable(fd).Map() 161 if fd.MapValue().Message() == nil { 162 ms.Set( 163 scalarValue(fd.MapKey().Kind()).MapKey(), 164 scalarValue(fd.MapValue().Kind()), 165 ) 166 } else { 167 // NOTE: Map.Mutable does not work with non-nullable fields. 168 m2 := ms.NewValue().Message() 169 populatedField = testPopulateMessage(t, m2, depth-1) 170 ms.Set( 171 scalarValue(fd.MapKey().Kind()).MapKey(), 172 protoreflect.ValueOfMessage(m2), 173 ) 174 } 175 case fd.Message() != nil: 176 populatedField = testPopulateMessage(t, m.Mutable(fd).Message(), depth-1) 177 default: 178 m.Set(fd, scalarValue(fd.Kind())) 179 } 180 if populatedField && !m.Has(fd) { 181 t.Errorf("%v.Has(%v) = false, want true", md.FullName(), fd.Name()) 182 } 183 populatedMessage = populatedMessage || populatedField 184 } 185 m.SetUnknown(m.GetUnknown()) // should not panic 186 return populatedMessage 187} 188 189func scalarValue(k protoreflect.Kind) protoreflect.Value { 190 switch k { 191 case protoreflect.BoolKind: 192 return protoreflect.ValueOfBool(true) 193 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 194 return protoreflect.ValueOfInt32(-32) 195 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 196 return protoreflect.ValueOfInt64(-64) 197 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 198 return protoreflect.ValueOfUint32(32) 199 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 200 return protoreflect.ValueOfUint64(64) 201 case protoreflect.FloatKind: 202 return protoreflect.ValueOfFloat32(32.32) 203 case protoreflect.DoubleKind: 204 return protoreflect.ValueOfFloat64(64.64) 205 case protoreflect.StringKind: 206 return protoreflect.ValueOfString(string("string")) 207 case protoreflect.BytesKind: 208 return protoreflect.ValueOfBytes([]byte("bytes")) 209 case protoreflect.EnumKind: 210 return protoreflect.ValueOfEnum(1) 211 default: 212 panic("unknown kind: " + k.String()) 213 } 214} 215