equals.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. // Copyright 2011 Aaron Jacobs. All Rights Reserved.
  2. // Author: aaronjjacobs@gmail.com (Aaron Jacobs)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. //
  8. // http://www.apache.org/licenses/LICENSE-2.0
  9. //
  10. // Unless required by applicable law or agreed to in writing, software
  11. // distributed under the License is distributed on an "AS IS" BASIS,
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. // See the License for the specific language governing permissions and
  14. // limitations under the License.
  15. package oglematchers
  16. import (
  17. "errors"
  18. "fmt"
  19. "math"
  20. "reflect"
  21. )
  22. // Equals(x) returns a matcher that matches values v such that v and x are
  23. // equivalent. This includes the case when the comparison v == x using Go's
  24. // built-in comparison operator is legal (except for structs, which this
  25. // matcher does not support), but for convenience the following rules also
  26. // apply:
  27. //
  28. // * Type checking is done based on underlying types rather than actual
  29. // types, so that e.g. two aliases for string can be compared:
  30. //
  31. // type stringAlias1 string
  32. // type stringAlias2 string
  33. //
  34. // a := "taco"
  35. // b := stringAlias1("taco")
  36. // c := stringAlias2("taco")
  37. //
  38. // ExpectTrue(a == b) // Legal, passes
  39. // ExpectTrue(b == c) // Illegal, doesn't compile
  40. //
  41. // ExpectThat(a, Equals(b)) // Passes
  42. // ExpectThat(b, Equals(c)) // Passes
  43. //
  44. // * Values of numeric type are treated as if they were abstract numbers, and
  45. // compared accordingly. Therefore Equals(17) will match int(17),
  46. // int16(17), uint(17), float32(17), complex64(17), and so on.
  47. //
  48. // If you want a stricter matcher that contains no such cleverness, see
  49. // IdenticalTo instead.
  50. //
  51. // Arrays are supported by this matcher, but do not participate in the
  52. // exceptions above. Two arrays compared with this matcher must have identical
  53. // types, and their element type must itself be comparable according to Go's ==
  54. // operator.
  55. func Equals(x interface{}) Matcher {
  56. v := reflect.ValueOf(x)
  57. // This matcher doesn't support structs.
  58. if v.Kind() == reflect.Struct {
  59. panic(fmt.Sprintf("oglematchers.Equals: unsupported kind %v", v.Kind()))
  60. }
  61. // The == operator is not defined for non-nil slices.
  62. if v.Kind() == reflect.Slice && v.Pointer() != uintptr(0) {
  63. panic(fmt.Sprintf("oglematchers.Equals: non-nil slice"))
  64. }
  65. return &equalsMatcher{v}
  66. }
  67. type equalsMatcher struct {
  68. expectedValue reflect.Value
  69. }
  70. ////////////////////////////////////////////////////////////////////////
  71. // Numeric types
  72. ////////////////////////////////////////////////////////////////////////
  // isSignedInteger reports whether v holds one of Go's signed integer kinds:
  // int, int8, int16, int32 or int64.
  func isSignedInteger(v reflect.Value) bool {
  	switch v.Kind() {
  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  		return true
  	}
  	return false
  }

  // isUnsignedInteger reports whether v holds one of Go's unsigned integer
  // kinds: uint, uint8, uint16, uint32, uint64 or uintptr.
  func isUnsignedInteger(v reflect.Value) bool {
  	switch v.Kind() {
  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  		return true
  	}
  	return false
  }

  // isInteger reports whether v holds any integer kind, signed or unsigned.
  func isInteger(v reflect.Value) bool {
  	if isSignedInteger(v) {
  		return true
  	}
  	return isUnsignedInteger(v)
  }

  // isFloat reports whether v holds a float32 or a float64.
  func isFloat(v reflect.Value) bool {
  	switch v.Kind() {
  	case reflect.Float32, reflect.Float64:
  		return true
  	default:
  		return false
  	}
  }

  // isComplex reports whether v holds a complex64 or a complex128.
  func isComplex(v reflect.Value) bool {
  	switch v.Kind() {
  	case reflect.Complex64, reflect.Complex128:
  		return true
  	default:
  		return false
  	}
  }
  92. func checkAgainstInt64(e int64, c reflect.Value) (err error) {
  93. err = errors.New("")
  94. switch {
  95. case isSignedInteger(c):
  96. if c.Int() == e {
  97. err = nil
  98. }
  99. case isUnsignedInteger(c):
  100. u := c.Uint()
  101. if u <= math.MaxInt64 && int64(u) == e {
  102. err = nil
  103. }
  104. // Turn around the various floating point types so that the checkAgainst*
  105. // functions for them can deal with precision issues.
  106. case isFloat(c), isComplex(c):
  107. return Equals(c.Interface()).Matches(e)
  108. default:
  109. err = NewFatalError("which is not numeric")
  110. }
  111. return
  112. }
  113. func checkAgainstUint64(e uint64, c reflect.Value) (err error) {
  114. err = errors.New("")
  115. switch {
  116. case isSignedInteger(c):
  117. i := c.Int()
  118. if i >= 0 && uint64(i) == e {
  119. err = nil
  120. }
  121. case isUnsignedInteger(c):
  122. if c.Uint() == e {
  123. err = nil
  124. }
  125. // Turn around the various floating point types so that the checkAgainst*
  126. // functions for them can deal with precision issues.
  127. case isFloat(c), isComplex(c):
  128. return Equals(c.Interface()).Matches(e)
  129. default:
  130. err = NewFatalError("which is not numeric")
  131. }
  132. return
  133. }
  134. func checkAgainstFloat32(e float32, c reflect.Value) (err error) {
  135. err = errors.New("")
  136. switch {
  137. case isSignedInteger(c):
  138. if float32(c.Int()) == e {
  139. err = nil
  140. }
  141. case isUnsignedInteger(c):
  142. if float32(c.Uint()) == e {
  143. err = nil
  144. }
  145. case isFloat(c):
  146. // Compare using float32 to avoid a false sense of precision; otherwise
  147. // e.g. Equals(float32(0.1)) won't match float32(0.1).
  148. if float32(c.Float()) == e {
  149. err = nil
  150. }
  151. case isComplex(c):
  152. comp := c.Complex()
  153. rl := real(comp)
  154. im := imag(comp)
  155. // Compare using float32 to avoid a false sense of precision; otherwise
  156. // e.g. Equals(float32(0.1)) won't match (0.1 + 0i).
  157. if im == 0 && float32(rl) == e {
  158. err = nil
  159. }
  160. default:
  161. err = NewFatalError("which is not numeric")
  162. }
  163. return
  164. }
  165. func checkAgainstFloat64(e float64, c reflect.Value) (err error) {
  166. err = errors.New("")
  167. ck := c.Kind()
  168. switch {
  169. case isSignedInteger(c):
  170. if float64(c.Int()) == e {
  171. err = nil
  172. }
  173. case isUnsignedInteger(c):
  174. if float64(c.Uint()) == e {
  175. err = nil
  176. }
  177. // If the actual value is lower precision, turn the comparison around so we
  178. // apply the low-precision rules. Otherwise, e.g. Equals(0.1) may not match
  179. // float32(0.1).
  180. case ck == reflect.Float32 || ck == reflect.Complex64:
  181. return Equals(c.Interface()).Matches(e)
  182. // Otherwise, compare with double precision.
  183. case isFloat(c):
  184. if c.Float() == e {
  185. err = nil
  186. }
  187. case isComplex(c):
  188. comp := c.Complex()
  189. rl := real(comp)
  190. im := imag(comp)
  191. if im == 0 && rl == e {
  192. err = nil
  193. }
  194. default:
  195. err = NewFatalError("which is not numeric")
  196. }
  197. return
  198. }
  199. func checkAgainstComplex64(e complex64, c reflect.Value) (err error) {
  200. err = errors.New("")
  201. realPart := real(e)
  202. imaginaryPart := imag(e)
  203. switch {
  204. case isInteger(c) || isFloat(c):
  205. // If we have no imaginary part, then we should just compare against the
  206. // real part. Otherwise, we can't be equal.
  207. if imaginaryPart != 0 {
  208. return
  209. }
  210. return checkAgainstFloat32(realPart, c)
  211. case isComplex(c):
  212. // Compare using complex64 to avoid a false sense of precision; otherwise
  213. // e.g. Equals(0.1 + 0i) won't match float32(0.1).
  214. if complex64(c.Complex()) == e {
  215. err = nil
  216. }
  217. default:
  218. err = NewFatalError("which is not numeric")
  219. }
  220. return
  221. }
  222. func checkAgainstComplex128(e complex128, c reflect.Value) (err error) {
  223. err = errors.New("")
  224. realPart := real(e)
  225. imaginaryPart := imag(e)
  226. switch {
  227. case isInteger(c) || isFloat(c):
  228. // If we have no imaginary part, then we should just compare against the
  229. // real part. Otherwise, we can't be equal.
  230. if imaginaryPart != 0 {
  231. return
  232. }
  233. return checkAgainstFloat64(realPart, c)
  234. case isComplex(c):
  235. if c.Complex() == e {
  236. err = nil
  237. }
  238. default:
  239. err = NewFatalError("which is not numeric")
  240. }
  241. return
  242. }
  243. ////////////////////////////////////////////////////////////////////////
  244. // Other types
  245. ////////////////////////////////////////////////////////////////////////
  246. func checkAgainstBool(e bool, c reflect.Value) (err error) {
  247. if c.Kind() != reflect.Bool {
  248. err = NewFatalError("which is not a bool")
  249. return
  250. }
  251. err = errors.New("")
  252. if c.Bool() == e {
  253. err = nil
  254. }
  255. return
  256. }
  257. func checkAgainstChan(e reflect.Value, c reflect.Value) (err error) {
  258. // Create a description of e's type, e.g. "chan int".
  259. typeStr := fmt.Sprintf("%s %s", e.Type().ChanDir(), e.Type().Elem())
  260. // Make sure c is a chan of the correct type.
  261. if c.Kind() != reflect.Chan ||
  262. c.Type().ChanDir() != e.Type().ChanDir() ||
  263. c.Type().Elem() != e.Type().Elem() {
  264. err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr))
  265. return
  266. }
  267. err = errors.New("")
  268. if c.Pointer() == e.Pointer() {
  269. err = nil
  270. }
  271. return
  272. }
  273. func checkAgainstFunc(e reflect.Value, c reflect.Value) (err error) {
  274. // Make sure c is a function.
  275. if c.Kind() != reflect.Func {
  276. err = NewFatalError("which is not a function")
  277. return
  278. }
  279. err = errors.New("")
  280. if c.Pointer() == e.Pointer() {
  281. err = nil
  282. }
  283. return
  284. }
  285. func checkAgainstMap(e reflect.Value, c reflect.Value) (err error) {
  286. // Make sure c is a map.
  287. if c.Kind() != reflect.Map {
  288. err = NewFatalError("which is not a map")
  289. return
  290. }
  291. err = errors.New("")
  292. if c.Pointer() == e.Pointer() {
  293. err = nil
  294. }
  295. return
  296. }
  297. func checkAgainstPtr(e reflect.Value, c reflect.Value) (err error) {
  298. // Create a description of e's type, e.g. "*int".
  299. typeStr := fmt.Sprintf("*%v", e.Type().Elem())
  300. // Make sure c is a pointer of the correct type.
  301. if c.Kind() != reflect.Ptr ||
  302. c.Type().Elem() != e.Type().Elem() {
  303. err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr))
  304. return
  305. }
  306. err = errors.New("")
  307. if c.Pointer() == e.Pointer() {
  308. err = nil
  309. }
  310. return
  311. }
  312. func checkAgainstSlice(e reflect.Value, c reflect.Value) (err error) {
  313. // Create a description of e's type, e.g. "[]int".
  314. typeStr := fmt.Sprintf("[]%v", e.Type().Elem())
  315. // Make sure c is a slice of the correct type.
  316. if c.Kind() != reflect.Slice ||
  317. c.Type().Elem() != e.Type().Elem() {
  318. err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr))
  319. return
  320. }
  321. err = errors.New("")
  322. if c.Pointer() == e.Pointer() {
  323. err = nil
  324. }
  325. return
  326. }
  327. func checkAgainstString(e reflect.Value, c reflect.Value) (err error) {
  328. // Make sure c is a string.
  329. if c.Kind() != reflect.String {
  330. err = NewFatalError("which is not a string")
  331. return
  332. }
  333. err = errors.New("")
  334. if c.String() == e.String() {
  335. err = nil
  336. }
  337. return
  338. }
  339. func checkAgainstArray(e reflect.Value, c reflect.Value) (err error) {
  340. // Create a description of e's type, e.g. "[2]int".
  341. typeStr := fmt.Sprintf("%v", e.Type())
  342. // Make sure c is the correct type.
  343. if c.Type() != e.Type() {
  344. err = NewFatalError(fmt.Sprintf("which is not %s", typeStr))
  345. return
  346. }
  347. // Check for equality.
  348. if e.Interface() != c.Interface() {
  349. err = errors.New("")
  350. return
  351. }
  352. return
  353. }
  354. func checkAgainstUnsafePointer(e reflect.Value, c reflect.Value) (err error) {
  355. // Make sure c is a pointer.
  356. if c.Kind() != reflect.UnsafePointer {
  357. err = NewFatalError("which is not a unsafe.Pointer")
  358. return
  359. }
  360. err = errors.New("")
  361. if c.Pointer() == e.Pointer() {
  362. err = nil
  363. }
  364. return
  365. }
  366. func checkForNil(c reflect.Value) (err error) {
  367. err = errors.New("")
  368. // Make sure it is legal to call IsNil.
  369. switch c.Kind() {
  370. case reflect.Invalid:
  371. case reflect.Chan:
  372. case reflect.Func:
  373. case reflect.Interface:
  374. case reflect.Map:
  375. case reflect.Ptr:
  376. case reflect.Slice:
  377. default:
  378. err = NewFatalError("which cannot be compared to nil")
  379. return
  380. }
  381. // Ask whether the value is nil. Handle a nil literal (kind Invalid)
  382. // specially, since it's not legal to call IsNil there.
  383. if c.Kind() == reflect.Invalid || c.IsNil() {
  384. err = nil
  385. }
  386. return
  387. }
  388. ////////////////////////////////////////////////////////////////////////
  389. // Public implementation
  390. ////////////////////////////////////////////////////////////////////////
  391. func (m *equalsMatcher) Matches(candidate interface{}) error {
  392. e := m.expectedValue
  393. c := reflect.ValueOf(candidate)
  394. ek := e.Kind()
  395. switch {
  396. case ek == reflect.Bool:
  397. return checkAgainstBool(e.Bool(), c)
  398. case isSignedInteger(e):
  399. return checkAgainstInt64(e.Int(), c)
  400. case isUnsignedInteger(e):
  401. return checkAgainstUint64(e.Uint(), c)
  402. case ek == reflect.Float32:
  403. return checkAgainstFloat32(float32(e.Float()), c)
  404. case ek == reflect.Float64:
  405. return checkAgainstFloat64(e.Float(), c)
  406. case ek == reflect.Complex64:
  407. return checkAgainstComplex64(complex64(e.Complex()), c)
  408. case ek == reflect.Complex128:
  409. return checkAgainstComplex128(complex128(e.Complex()), c)
  410. case ek == reflect.Chan:
  411. return checkAgainstChan(e, c)
  412. case ek == reflect.Func:
  413. return checkAgainstFunc(e, c)
  414. case ek == reflect.Map:
  415. return checkAgainstMap(e, c)
  416. case ek == reflect.Ptr:
  417. return checkAgainstPtr(e, c)
  418. case ek == reflect.Slice:
  419. return checkAgainstSlice(e, c)
  420. case ek == reflect.String:
  421. return checkAgainstString(e, c)
  422. case ek == reflect.Array:
  423. return checkAgainstArray(e, c)
  424. case ek == reflect.UnsafePointer:
  425. return checkAgainstUnsafePointer(e, c)
  426. case ek == reflect.Invalid:
  427. return checkForNil(c)
  428. }
  429. panic(fmt.Sprintf("equalsMatcher.Matches: unexpected kind: %v", ek))
  430. }
  431. func (m *equalsMatcher) Description() string {
  432. // Special case: handle nil.
  433. if !m.expectedValue.IsValid() {
  434. return "is nil"
  435. }
  436. return fmt.Sprintf("%v", m.expectedValue.Interface())
  437. }