extensions.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. // Go support for Protocol Buffers - Google's data interchange format
  2. //
  3. // Copyright 2010 The Go Authors. All rights reserved.
  4. // https://github.com/golang/protobuf
  5. //
  6. // Redistribution and use in source and binary forms, with or without
  7. // modification, are permitted provided that the following conditions are
  8. // met:
  9. //
  10. // * Redistributions of source code must retain the above copyright
  11. // notice, this list of conditions and the following disclaimer.
  12. // * Redistributions in binary form must reproduce the above
  13. // copyright notice, this list of conditions and the following disclaimer
  14. // in the documentation and/or other materials provided with the
  15. // distribution.
  16. // * Neither the name of Google Inc. nor the names of its
  17. // contributors may be used to endorse or promote products derived from
  18. // this software without specific prior written permission.
  19. //
  20. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  21. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  22. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  23. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  24. // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  25. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  26. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  27. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  28. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  29. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  30. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  31. package proto
  32. /*
  33. * Types and routines for supporting protocol buffer extensions.
  34. */
  35. import (
  36. "errors"
  37. "fmt"
  38. "reflect"
  39. "strconv"
  40. "sync"
  41. )
  // ErrMissingExtension is the sentinel error GetExtension reports when the
// requested extension is neither stored in the message nor has a declared
// default value.
var (
	ErrMissingExtension = errors.New("proto: missing extension")
)
  // ExtensionRange describes one contiguous span of field numbers that a
// message reserves for extensions. Both bounds are inclusive. Generated
// messages report their ranges through ExtensionRangeArray.
type ExtensionRange struct {
	Start int32 // first field number of the range (inclusive)
	End   int32 // last field number of the range (inclusive)
}
  49. // extendableProto is an interface implemented by any protocol buffer generated by the current
  50. // proto compiler that may be extended.
  51. type extendableProto interface {
  52. Message
  53. ExtensionRangeArray() []ExtensionRange
  54. extensionsWrite() map[int32]Extension
  55. extensionsRead() (map[int32]Extension, sync.Locker)
  56. }
  57. // extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
  58. // version of the proto compiler that may be extended.
  59. type extendableProtoV1 interface {
  60. Message
  61. ExtensionRangeArray() []ExtensionRange
  62. ExtensionMap() map[int32]Extension
  63. }
  64. // extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
  65. type extensionAdapter struct {
  66. extendableProtoV1
  67. }
  68. func (e extensionAdapter) extensionsWrite() map[int32]Extension {
  69. return e.ExtensionMap()
  70. }
  71. func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
  72. return e.ExtensionMap(), notLocker{}
  73. }
  // notLocker satisfies sync.Locker, but both of its methods do nothing.
type notLocker struct{}

// Lock does nothing.
func (notLocker) Lock() {}

// Unlock does nothing.
func (notLocker) Unlock() {}
  78. // extendable returns the extendableProto interface for the given generated proto message.
  79. // If the proto message has the old extension format, it returns a wrapper that implements
  80. // the extendableProto interface.
  81. func extendable(p interface{}) (extendableProto, bool) {
  82. if ep, ok := p.(extendableProto); ok {
  83. return ep, ok
  84. }
  85. if ep, ok := p.(extendableProtoV1); ok {
  86. return extensionAdapter{ep}, ok
  87. }
  88. return nil, false
  89. }
  90. // XXX_InternalExtensions is an internal representation of proto extensions.
  91. //
  92. // Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
  93. // thus gaining the unexported 'extensions' method, which can be called only from the proto package.
  94. //
  95. // The methods of XXX_InternalExtensions are not concurrency safe in general,
  96. // but calls to logically read-only methods such as has and get may be executed concurrently.
  97. type XXX_InternalExtensions struct {
  98. // The struct must be indirect so that if a user inadvertently copies a
  99. // generated message and its embedded XXX_InternalExtensions, they
  100. // avoid the mayhem of a copied mutex.
  101. //
  102. // The mutex serializes all logically read-only operations to p.extensionMap.
  103. // It is up to the client to ensure that write operations to p.extensionMap are
  104. // mutually exclusive with other accesses.
  105. p *struct {
  106. mu sync.Mutex
  107. extensionMap map[int32]Extension
  108. }
  109. }
  110. // extensionsWrite returns the extension map, creating it on first use.
  111. func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
  112. if e.p == nil {
  113. e.p = new(struct {
  114. mu sync.Mutex
  115. extensionMap map[int32]Extension
  116. })
  117. e.p.extensionMap = make(map[int32]Extension)
  118. }
  119. return e.p.extensionMap
  120. }
  121. // extensionsRead returns the extensions map for read-only use. It may be nil.
  122. // The caller must hold the returned mutex's lock when accessing Elements within the map.
  123. func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
  124. if e.p == nil {
  125. return nil, nil
  126. }
  127. return e.p.extensionMap, &e.p.mu
  128. }
  129. var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
  130. var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
  131. // ExtensionDesc represents an extension specification.
  132. // Used in generated code from the protocol compiler.
  133. type ExtensionDesc struct {
  134. ExtendedType Message // nil pointer to the type that is being extended
  135. ExtensionType interface{} // nil pointer to the extension type
  136. Field int32 // field number
  137. Name string // fully-qualified name of extension, for text formatting
  138. Tag string // protobuf tag style
  139. Filename string // name of the file in which the extension is defined
  140. }
  141. func (ed *ExtensionDesc) repeated() bool {
  142. t := reflect.TypeOf(ed.ExtensionType)
  143. return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
  144. }
  145. // Extension represents an extension in a message.
  146. type Extension struct {
  147. // When an extension is stored in a message using SetExtension
  148. // only desc and value are set. When the message is marshaled
  149. // enc will be set to the encoded form of the message.
  150. //
  151. // When a message is unmarshaled and contains extensions, each
  152. // extension will have only enc set. When such an extension is
  153. // accessed using GetExtension (or GetExtensions) desc and value
  154. // will be set.
  155. desc *ExtensionDesc
  156. value interface{}
  157. enc []byte
  158. }
  159. // SetRawExtension is for testing only.
  160. func SetRawExtension(base Message, id int32, b []byte) {
  161. epb, ok := extendable(base)
  162. if !ok {
  163. return
  164. }
  165. extmap := epb.extensionsWrite()
  166. extmap[id] = Extension{enc: b}
  167. }
  168. // isExtensionField returns true iff the given field number is in an extension range.
  169. func isExtensionField(pb extendableProto, field int32) bool {
  170. for _, er := range pb.ExtensionRangeArray() {
  171. if er.Start <= field && field <= er.End {
  172. return true
  173. }
  174. }
  175. return false
  176. }
  177. // checkExtensionTypes checks that the given extension is valid for pb.
  178. func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
  179. var pbi interface{} = pb
  180. // Check the extended type.
  181. if ea, ok := pbi.(extensionAdapter); ok {
  182. pbi = ea.extendableProtoV1
  183. }
  184. if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
  185. return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
  186. }
  187. // Check the range.
  188. if !isExtensionField(pb, extension.Field) {
  189. return errors.New("proto: bad extension number; not in declared ranges")
  190. }
  191. return nil
  192. }
  193. // extPropKey is sufficient to uniquely identify an extension.
  194. type extPropKey struct {
  195. base reflect.Type
  196. field int32
  197. }
  198. var extProp = struct {
  199. sync.RWMutex
  200. m map[extPropKey]*Properties
  201. }{
  202. m: make(map[extPropKey]*Properties),
  203. }
  204. func extensionProperties(ed *ExtensionDesc) *Properties {
  205. key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
  206. extProp.RLock()
  207. if prop, ok := extProp.m[key]; ok {
  208. extProp.RUnlock()
  209. return prop
  210. }
  211. extProp.RUnlock()
  212. extProp.Lock()
  213. defer extProp.Unlock()
  214. // Check again.
  215. if prop, ok := extProp.m[key]; ok {
  216. return prop
  217. }
  218. prop := new(Properties)
  219. prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
  220. extProp.m[key] = prop
  221. return prop
  222. }
  223. // encode encodes any unmarshaled (unencoded) extensions in e.
  224. func encodeExtensions(e *XXX_InternalExtensions) error {
  225. m, mu := e.extensionsRead()
  226. if m == nil {
  227. return nil // fast path
  228. }
  229. mu.Lock()
  230. defer mu.Unlock()
  231. return encodeExtensionsMap(m)
  232. }
  233. // encode encodes any unmarshaled (unencoded) extensions in e.
  234. func encodeExtensionsMap(m map[int32]Extension) error {
  235. for k, e := range m {
  236. if e.value == nil || e.desc == nil {
  237. // Extension is only in its encoded form.
  238. continue
  239. }
  240. // We don't skip extensions that have an encoded form set,
  241. // because the extension value may have been mutated after
  242. // the last time this function was called.
  243. et := reflect.TypeOf(e.desc.ExtensionType)
  244. props := extensionProperties(e.desc)
  245. p := NewBuffer(nil)
  246. // If e.value has type T, the encoder expects a *struct{ X T }.
  247. // Pass a *T with a zero field and hope it all works out.
  248. x := reflect.New(et)
  249. x.Elem().Set(reflect.ValueOf(e.value))
  250. if err := props.enc(p, props, toStructPointer(x)); err != nil {
  251. return err
  252. }
  253. e.enc = p.buf
  254. m[k] = e
  255. }
  256. return nil
  257. }
  258. func extensionsSize(e *XXX_InternalExtensions) (n int) {
  259. m, mu := e.extensionsRead()
  260. if m == nil {
  261. return 0
  262. }
  263. mu.Lock()
  264. defer mu.Unlock()
  265. return extensionsMapSize(m)
  266. }
  267. func extensionsMapSize(m map[int32]Extension) (n int) {
  268. for _, e := range m {
  269. if e.value == nil || e.desc == nil {
  270. // Extension is only in its encoded form.
  271. n += len(e.enc)
  272. continue
  273. }
  274. // We don't skip extensions that have an encoded form set,
  275. // because the extension value may have been mutated after
  276. // the last time this function was called.
  277. et := reflect.TypeOf(e.desc.ExtensionType)
  278. props := extensionProperties(e.desc)
  279. // If e.value has type T, the encoder expects a *struct{ X T }.
  280. // Pass a *T with a zero field and hope it all works out.
  281. x := reflect.New(et)
  282. x.Elem().Set(reflect.ValueOf(e.value))
  283. n += props.size(props, toStructPointer(x))
  284. }
  285. return
  286. }
  287. // HasExtension returns whether the given extension is present in pb.
  288. func HasExtension(pb Message, extension *ExtensionDesc) bool {
  289. // TODO: Check types, field numbers, etc.?
  290. epb, ok := extendable(pb)
  291. if !ok {
  292. return false
  293. }
  294. extmap, mu := epb.extensionsRead()
  295. if extmap == nil {
  296. return false
  297. }
  298. mu.Lock()
  299. _, ok = extmap[extension.Field]
  300. mu.Unlock()
  301. return ok
  302. }
  303. // ClearExtension removes the given extension from pb.
  304. func ClearExtension(pb Message, extension *ExtensionDesc) {
  305. epb, ok := extendable(pb)
  306. if !ok {
  307. return
  308. }
  309. // TODO: Check types, field numbers, etc.?
  310. extmap := epb.extensionsWrite()
  311. delete(extmap, extension.Field)
  312. }
  313. // GetExtension parses and returns the given extension of pb.
  314. // If the extension is not present and has no default value it returns ErrMissingExtension.
  315. func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
  316. epb, ok := extendable(pb)
  317. if !ok {
  318. return nil, errors.New("proto: not an extendable proto")
  319. }
  320. if err := checkExtensionTypes(epb, extension); err != nil {
  321. return nil, err
  322. }
  323. emap, mu := epb.extensionsRead()
  324. if emap == nil {
  325. return defaultExtensionValue(extension)
  326. }
  327. mu.Lock()
  328. defer mu.Unlock()
  329. e, ok := emap[extension.Field]
  330. if !ok {
  331. // defaultExtensionValue returns the default value or
  332. // ErrMissingExtension if there is no default.
  333. return defaultExtensionValue(extension)
  334. }
  335. if e.value != nil {
  336. // Already decoded. Check the descriptor, though.
  337. if e.desc != extension {
  338. // This shouldn't happen. If it does, it means that
  339. // GetExtension was called twice with two different
  340. // descriptors with the same field number.
  341. return nil, errors.New("proto: descriptor conflict")
  342. }
  343. return e.value, nil
  344. }
  345. v, err := decodeExtension(e.enc, extension)
  346. if err != nil {
  347. return nil, err
  348. }
  349. // Remember the decoded version and drop the encoded version.
  350. // That way it is safe to mutate what we return.
  351. e.value = v
  352. e.desc = extension
  353. e.enc = nil
  354. emap[extension.Field] = e
  355. return e.value, nil
  356. }
  357. // defaultExtensionValue returns the default value for extension.
  358. // If no default for an extension is defined ErrMissingExtension is returned.
  359. func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
  360. t := reflect.TypeOf(extension.ExtensionType)
  361. props := extensionProperties(extension)
  362. sf, _, err := fieldDefault(t, props)
  363. if err != nil {
  364. return nil, err
  365. }
  366. if sf == nil || sf.value == nil {
  367. // There is no default value.
  368. return nil, ErrMissingExtension
  369. }
  370. if t.Kind() != reflect.Ptr {
  371. // We do not need to return a Ptr, we can directly return sf.value.
  372. return sf.value, nil
  373. }
  374. // We need to return an interface{} that is a pointer to sf.value.
  375. value := reflect.New(t).Elem()
  376. value.Set(reflect.New(value.Type().Elem()))
  377. if sf.kind == reflect.Int32 {
  378. // We may have an int32 or an enum, but the underlying data is int32.
  379. // Since we can't set an int32 into a non int32 reflect.value directly
  380. // set it as a int32.
  381. value.Elem().SetInt(int64(sf.value.(int32)))
  382. } else {
  383. value.Elem().Set(reflect.ValueOf(sf.value))
  384. }
  385. return value.Interface(), nil
  386. }
  387. // decodeExtension decodes an extension encoded in b.
  388. func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
  389. o := NewBuffer(b)
  390. t := reflect.TypeOf(extension.ExtensionType)
  391. props := extensionProperties(extension)
  392. // t is a pointer to a struct, pointer to basic type or a slice.
  393. // Allocate a "field" to store the pointer/slice itself; the
  394. // pointer/slice will be stored here. We pass
  395. // the address of this field to props.dec.
  396. // This passes a zero field and a *t and lets props.dec
  397. // interpret it as a *struct{ x t }.
  398. value := reflect.New(t).Elem()
  399. for {
  400. // Discard wire type and field number varint. It isn't needed.
  401. if _, err := o.DecodeVarint(); err != nil {
  402. return nil, err
  403. }
  404. if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
  405. return nil, err
  406. }
  407. if o.index >= len(o.buf) {
  408. break
  409. }
  410. }
  411. return value.Interface(), nil
  412. }
  413. // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
  414. // The returned slice has the same length as es; missing extensions will appear as nil elements.
  415. func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
  416. epb, ok := extendable(pb)
  417. if !ok {
  418. return nil, errors.New("proto: not an extendable proto")
  419. }
  420. extensions = make([]interface{}, len(es))
  421. for i, e := range es {
  422. extensions[i], err = GetExtension(epb, e)
  423. if err == ErrMissingExtension {
  424. err = nil
  425. }
  426. if err != nil {
  427. return
  428. }
  429. }
  430. return
  431. }
  432. // ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
  433. // For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
  434. // just the Field field, which defines the extension's field number.
  435. func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
  436. epb, ok := extendable(pb)
  437. if !ok {
  438. return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb)
  439. }
  440. registeredExtensions := RegisteredExtensions(pb)
  441. emap, mu := epb.extensionsRead()
  442. if emap == nil {
  443. return nil, nil
  444. }
  445. mu.Lock()
  446. defer mu.Unlock()
  447. extensions := make([]*ExtensionDesc, 0, len(emap))
  448. for extid, e := range emap {
  449. desc := e.desc
  450. if desc == nil {
  451. desc = registeredExtensions[extid]
  452. if desc == nil {
  453. desc = &ExtensionDesc{Field: extid}
  454. }
  455. }
  456. extensions = append(extensions, desc)
  457. }
  458. return extensions, nil
  459. }
  460. // SetExtension sets the specified extension of pb to the specified value.
  461. func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
  462. epb, ok := extendable(pb)
  463. if !ok {
  464. return errors.New("proto: not an extendable proto")
  465. }
  466. if err := checkExtensionTypes(epb, extension); err != nil {
  467. return err
  468. }
  469. typ := reflect.TypeOf(extension.ExtensionType)
  470. if typ != reflect.TypeOf(value) {
  471. return errors.New("proto: bad extension value type")
  472. }
  473. // nil extension values need to be caught early, because the
  474. // encoder can't distinguish an ErrNil due to a nil extension
  475. // from an ErrNil due to a missing field. Extensions are
  476. // always optional, so the encoder would just swallow the error
  477. // and drop all the extensions from the encoded message.
  478. if reflect.ValueOf(value).IsNil() {
  479. return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
  480. }
  481. extmap := epb.extensionsWrite()
  482. extmap[extension.Field] = Extension{desc: extension, value: value}
  483. return nil
  484. }
  485. // ClearAllExtensions clears all extensions from pb.
  486. func ClearAllExtensions(pb Message) {
  487. epb, ok := extendable(pb)
  488. if !ok {
  489. return
  490. }
  491. m := epb.extensionsWrite()
  492. for k := range m {
  493. delete(m, k)
  494. }
  495. }
  496. // A global registry of extensions.
  497. // The generated code will register the generated descriptors by calling RegisterExtension.
  498. var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
  499. // RegisterExtension is called from the generated code.
  500. func RegisterExtension(desc *ExtensionDesc) {
  501. st := reflect.TypeOf(desc.ExtendedType).Elem()
  502. m := extensionMaps[st]
  503. if m == nil {
  504. m = make(map[int32]*ExtensionDesc)
  505. extensionMaps[st] = m
  506. }
  507. if _, ok := m[desc.Field]; ok {
  508. panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
  509. }
  510. m[desc.Field] = desc
  511. }
  512. // RegisteredExtensions returns a map of the registered extensions of a
  513. // protocol buffer struct, indexed by the extension number.
  514. // The argument pb should be a nil pointer to the struct type.
  515. func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
  516. return extensionMaps[reflect.TypeOf(pb).Elem()]
  517. }