extension.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "google.golang.org/protobuf/reflect/protoreflect"
  7. )
  8. // HasExtension reports whether an extension field is populated.
  9. // It returns false if m is invalid or if xt does not extend m.
  10. func HasExtension(m Message, xt protoreflect.ExtensionType) bool {
  11. // Treat nil message interface as an empty message; no populated fields.
  12. if m == nil {
  13. return false
  14. }
  15. // As a special-case, we reports invalid or mismatching descriptors
  16. // as always not being populated (since they aren't).
  17. if xt == nil || m.ProtoReflect().Descriptor() != xt.TypeDescriptor().ContainingMessage() {
  18. return false
  19. }
  20. return m.ProtoReflect().Has(xt.TypeDescriptor())
  21. }
  22. // ClearExtension clears an extension field such that subsequent
  23. // HasExtension calls return false.
  24. // It panics if m is invalid or if xt does not extend m.
  25. func ClearExtension(m Message, xt protoreflect.ExtensionType) {
  26. m.ProtoReflect().Clear(xt.TypeDescriptor())
  27. }
  28. // GetExtension retrieves the value for an extension field.
  29. // If the field is unpopulated, it returns the default value for
  30. // scalars and an immutable, empty value for lists or messages.
  31. // It panics if xt does not extend m.
  32. func GetExtension(m Message, xt protoreflect.ExtensionType) interface{} {
  33. // Treat nil message interface as an empty message; return the default.
  34. if m == nil {
  35. return xt.InterfaceOf(xt.Zero())
  36. }
  37. return xt.InterfaceOf(m.ProtoReflect().Get(xt.TypeDescriptor()))
  38. }
  39. // SetExtension stores the value of an extension field.
  40. // It panics if m is invalid, xt does not extend m, or if type of v
  41. // is invalid for the specified extension field.
  42. func SetExtension(m Message, xt protoreflect.ExtensionType, v interface{}) {
  43. xd := xt.TypeDescriptor()
  44. pv := xt.ValueOf(v)
  45. // Specially treat an invalid list, map, or message as clear.
  46. isValid := true
  47. switch {
  48. case xd.IsList():
  49. isValid = pv.List().IsValid()
  50. case xd.IsMap():
  51. isValid = pv.Map().IsValid()
  52. case xd.Message() != nil:
  53. isValid = pv.Message().IsValid()
  54. }
  55. if !isValid {
  56. m.ProtoReflect().Clear(xd)
  57. return
  58. }
  59. m.ProtoReflect().Set(xd, pv)
  60. }
  61. // RangeExtensions iterates over every populated extension field in m in an
  62. // undefined order, calling f for each extension type and value encountered.
  63. // It returns immediately if f returns false.
  64. // While iterating, mutating operations may only be performed
  65. // on the current extension field.
  66. func RangeExtensions(m Message, f func(protoreflect.ExtensionType, interface{}) bool) {
  67. // Treat nil message interface as an empty message; nothing to range over.
  68. if m == nil {
  69. return
  70. }
  71. m.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
  72. if fd.IsExtension() {
  73. xt := fd.(protoreflect.ExtensionTypeDescriptor).Type()
  74. vi := xt.InterfaceOf(v)
  75. return f(xt, vi)
  76. }
  77. return true
  78. })
  79. }