checkinit.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "google.golang.org/protobuf/internal/errors"
  7. "google.golang.org/protobuf/reflect/protoreflect"
  8. "google.golang.org/protobuf/runtime/protoiface"
  9. )
  10. // CheckInitialized returns an error if any required fields in m are not set.
  11. func CheckInitialized(m Message) error {
  12. // Treat a nil message interface as an "untyped" empty message,
  13. // which we assume to have no required fields.
  14. if m == nil {
  15. return nil
  16. }
  17. return checkInitialized(m.ProtoReflect())
  18. }
  19. // CheckInitialized returns an error if any required fields in m are not set.
  20. func checkInitialized(m protoreflect.Message) error {
  21. if methods := protoMethods(m); methods != nil && methods.CheckInitialized != nil {
  22. _, err := methods.CheckInitialized(protoiface.CheckInitializedInput{
  23. Message: m,
  24. })
  25. return err
  26. }
  27. return checkInitializedSlow(m)
  28. }
  29. func checkInitializedSlow(m protoreflect.Message) error {
  30. md := m.Descriptor()
  31. fds := md.Fields()
  32. for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
  33. fd := fds.ByNumber(nums.Get(i))
  34. if !m.Has(fd) {
  35. return errors.RequiredNotSet(string(fd.FullName()))
  36. }
  37. }
  38. var err error
  39. m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
  40. switch {
  41. case fd.IsList():
  42. if fd.Message() == nil {
  43. return true
  44. }
  45. for i, list := 0, v.List(); i < list.Len() && err == nil; i++ {
  46. err = checkInitialized(list.Get(i).Message())
  47. }
  48. case fd.IsMap():
  49. if fd.MapValue().Message() == nil {
  50. return true
  51. }
  52. v.Map().Range(func(key protoreflect.MapKey, v protoreflect.Value) bool {
  53. err = checkInitialized(v.Message())
  54. return err == nil
  55. })
  56. default:
  57. if fd.Message() == nil {
  58. return true
  59. }
  60. err = checkInitialized(v.Message())
  61. }
  62. return err == nil
  63. })
  64. return err
  65. }