equal.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "bytes"
  7. "math"
  8. "reflect"
  9. "google.golang.org/protobuf/encoding/protowire"
  10. "google.golang.org/protobuf/reflect/protoreflect"
  11. )
  12. // Equal reports whether two messages are equal.
  13. // If two messages marshal to the same bytes under deterministic serialization,
  14. // then Equal is guaranteed to report true.
  15. //
  16. // Two messages are equal if they belong to the same message descriptor,
  17. // have the same set of populated known and extension field values,
  18. // and the same set of unknown fields values. If either of the top-level
  19. // messages are invalid, then Equal reports true only if both are invalid.
  20. //
  21. // Scalar values are compared with the equivalent of the == operator in Go,
  22. // except bytes values which are compared using bytes.Equal and
  23. // floating point values which specially treat NaNs as equal.
  24. // Message values are compared by recursively calling Equal.
  25. // Lists are equal if each element value is also equal.
  26. // Maps are equal if they have the same set of keys, where the pair of values
  27. // for each key is also equal.
  28. func Equal(x, y Message) bool {
  29. if x == nil || y == nil {
  30. return x == nil && y == nil
  31. }
  32. if reflect.TypeOf(x).Kind() == reflect.Ptr && x == y {
  33. // Avoid an expensive comparison if both inputs are identical pointers.
  34. return true
  35. }
  36. mx := x.ProtoReflect()
  37. my := y.ProtoReflect()
  38. if mx.IsValid() != my.IsValid() {
  39. return false
  40. }
  41. return equalMessage(mx, my)
  42. }
  43. // equalMessage compares two messages.
  44. func equalMessage(mx, my protoreflect.Message) bool {
  45. if mx.Descriptor() != my.Descriptor() {
  46. return false
  47. }
  48. nx := 0
  49. equal := true
  50. mx.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool {
  51. nx++
  52. vy := my.Get(fd)
  53. equal = my.Has(fd) && equalField(fd, vx, vy)
  54. return equal
  55. })
  56. if !equal {
  57. return false
  58. }
  59. ny := 0
  60. my.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool {
  61. ny++
  62. return true
  63. })
  64. if nx != ny {
  65. return false
  66. }
  67. return equalUnknown(mx.GetUnknown(), my.GetUnknown())
  68. }
  69. // equalField compares two fields.
  70. func equalField(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
  71. switch {
  72. case fd.IsList():
  73. return equalList(fd, x.List(), y.List())
  74. case fd.IsMap():
  75. return equalMap(fd, x.Map(), y.Map())
  76. default:
  77. return equalValue(fd, x, y)
  78. }
  79. }
  80. // equalMap compares two maps.
  81. func equalMap(fd protoreflect.FieldDescriptor, x, y protoreflect.Map) bool {
  82. if x.Len() != y.Len() {
  83. return false
  84. }
  85. equal := true
  86. x.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool {
  87. vy := y.Get(k)
  88. equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
  89. return equal
  90. })
  91. return equal
  92. }
  93. // equalList compares two lists.
  94. func equalList(fd protoreflect.FieldDescriptor, x, y protoreflect.List) bool {
  95. if x.Len() != y.Len() {
  96. return false
  97. }
  98. for i := x.Len() - 1; i >= 0; i-- {
  99. if !equalValue(fd, x.Get(i), y.Get(i)) {
  100. return false
  101. }
  102. }
  103. return true
  104. }
  105. // equalValue compares two singular values.
  106. func equalValue(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
  107. switch fd.Kind() {
  108. case protoreflect.BoolKind:
  109. return x.Bool() == y.Bool()
  110. case protoreflect.EnumKind:
  111. return x.Enum() == y.Enum()
  112. case protoreflect.Int32Kind, protoreflect.Sint32Kind,
  113. protoreflect.Int64Kind, protoreflect.Sint64Kind,
  114. protoreflect.Sfixed32Kind, protoreflect.Sfixed64Kind:
  115. return x.Int() == y.Int()
  116. case protoreflect.Uint32Kind, protoreflect.Uint64Kind,
  117. protoreflect.Fixed32Kind, protoreflect.Fixed64Kind:
  118. return x.Uint() == y.Uint()
  119. case protoreflect.FloatKind, protoreflect.DoubleKind:
  120. fx := x.Float()
  121. fy := y.Float()
  122. if math.IsNaN(fx) || math.IsNaN(fy) {
  123. return math.IsNaN(fx) && math.IsNaN(fy)
  124. }
  125. return fx == fy
  126. case protoreflect.StringKind:
  127. return x.String() == y.String()
  128. case protoreflect.BytesKind:
  129. return bytes.Equal(x.Bytes(), y.Bytes())
  130. case protoreflect.MessageKind, protoreflect.GroupKind:
  131. return equalMessage(x.Message(), y.Message())
  132. default:
  133. return x.Interface() == y.Interface()
  134. }
  135. }
  136. // equalUnknown compares unknown fields by direct comparison on the raw bytes
  137. // of each individual field number.
  138. func equalUnknown(x, y protoreflect.RawFields) bool {
  139. if len(x) != len(y) {
  140. return false
  141. }
  142. if bytes.Equal([]byte(x), []byte(y)) {
  143. return true
  144. }
  145. mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
  146. my := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
  147. for len(x) > 0 {
  148. fnum, _, n := protowire.ConsumeField(x)
  149. mx[fnum] = append(mx[fnum], x[:n]...)
  150. x = x[n:]
  151. }
  152. for len(y) > 0 {
  153. fnum, _, n := protowire.ConsumeField(y)
  154. my[fnum] = append(my[fnum], y[:n]...)
  155. y = y[n:]
  156. }
  157. return reflect.DeepEqual(mx, my)
  158. }