intrinsics_test.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. // License: GPLv3 Copyright: 2024, Kovid Goyal, <kovid at kovidgoyal.net>
  2. package simdstring
  3. import (
  4. "bytes"
  5. "fmt"
  6. "kitty/tools/utils"
  7. "runtime"
  8. "strings"
  9. "testing"
  10. "unsafe"
  11. "github.com/google/go-cmp/cmp"
  12. )
  13. var _ = fmt.Print
  14. func test_load(src []byte) []byte {
  15. ans := make([]byte, len(src))
  16. if len(src) == 16 {
  17. test_load_asm_128(src, ans)
  18. } else {
  19. test_load_asm_256(src, ans)
  20. }
  21. return ans
  22. }
  23. func test_set1_epi8(b byte, sz int) []byte {
  24. ans := make([]byte, sz)
  25. if sz == 16 {
  26. test_set1_epi8_asm_128(b, ans)
  27. } else {
  28. test_set1_epi8_asm_256(b, ans)
  29. }
  30. return ans
  31. }
  32. func test_cmpeq_epi8(a, b []byte) []byte {
  33. ans := make([]byte, len(a))
  34. if len(ans) == 16 {
  35. test_cmpeq_epi8_asm_128(a, b, ans)
  36. } else {
  37. test_cmpeq_epi8_asm_256(a, b, ans)
  38. }
  39. return ans
  40. }
  41. func test_cmplt_epi8(t *testing.T, a, b []byte) []byte {
  42. ans := make([]byte, len(a))
  43. var prev []byte
  44. for which := 0; which < 3; which++ {
  45. if len(ans) == 16 {
  46. test_cmplt_epi8_asm_128(a, b, which, ans)
  47. } else {
  48. test_cmplt_epi8_asm_256(a, b, which, ans)
  49. }
  50. if prev != nil {
  51. if s := cmp.Diff(prev, ans); s != "" {
  52. t.Fatalf("cmplt returned different result for which=%d\n%s", which, s)
  53. }
  54. }
  55. prev = bytes.Clone(ans)
  56. }
  57. return ans
  58. }
  59. func test_or(a, b []byte) []byte {
  60. ans := make([]byte, len(a))
  61. if len(ans) == 16 {
  62. test_or_asm_128(a, b, ans)
  63. } else {
  64. test_or_asm_256(a, b, ans)
  65. }
  66. return ans
  67. }
  68. func test_jump_if_zero(a []byte) int {
  69. if len(a) == 16 {
  70. return test_jump_if_zero_asm_128(a)
  71. }
  72. return test_jump_if_zero_asm_256(a)
  73. }
  74. func test_count_to_match(a []byte, b byte) int {
  75. if len(a) == 16 {
  76. return test_count_to_match_asm_128(a, b)
  77. }
  78. return test_count_to_match_asm_256(a, b)
  79. }
  80. func ordered_bytes(size int) []byte {
  81. ans := make([]byte, size)
  82. for i := range ans {
  83. ans[i] = byte(i)
  84. }
  85. return ans
  86. }
  87. func broadcast_byte(b byte, size int) []byte {
  88. return bytes.Repeat([]byte{b}, size)
  89. }
  90. func get_sizes(t *testing.T) []int {
  91. sizes := []int{}
  92. if Have128bit {
  93. sizes = append(sizes, 16)
  94. }
  95. if Have256bit {
  96. sizes = append(sizes, 32)
  97. }
  98. if len(sizes) == 0 {
  99. t.Skip("skipping as no SIMD available at runtime")
  100. }
  101. return sizes
  102. }
  103. func addressof_data(b []byte) uintptr {
  104. return uintptr(unsafe.Pointer(&b[0]))
  105. }
  106. func aligned_slice(sz, alignment int) ([]byte, []byte) {
  107. ans := make([]byte, sz+alignment+512)
  108. a := addressof_data(ans)
  109. a &= uintptr(alignment - 1)
  110. extra := uintptr(alignment) - a
  111. utils.Memset(ans, '<')
  112. utils.Memset(ans[extra+uintptr(sz):], '>')
  113. return ans[extra : extra+uintptr(sz)], ans
  114. }
  115. func TestSIMDStringOps(t *testing.T) {
  116. sizes := get_sizes(t)
  117. test := func(haystack []byte, a, b byte, align_offset int) {
  118. var actual int
  119. sh, _ := aligned_slice(len(haystack)+align_offset, 64)
  120. sh = sh[align_offset:]
  121. copy(sh, haystack)
  122. haystack = sh
  123. expected := index_byte2_scalar(haystack, a, b)
  124. for _, sz := range sizes {
  125. switch sz {
  126. case 16:
  127. actual = index_byte2_asm_128(haystack, a, b)
  128. case 32:
  129. actual = index_byte2_asm_256(haystack, a, b)
  130. }
  131. if actual != expected {
  132. t.Fatalf("Failed to find '%c' or '%c' in: %#v at align: %d (expected: %d != actual: %d) at size: %d",
  133. a, b, string(haystack), addressof_data(haystack)&uintptr(sz-1), expected, actual, sz)
  134. }
  135. }
  136. }
  137. // test alignment issues
  138. q := []byte("abc")
  139. for sz := 0; sz < 32; sz++ {
  140. test(q, '<', '>', sz)
  141. test(q, ' ', 'b', sz)
  142. test(q, '<', 'a', sz)
  143. test(q, '<', 'b', sz)
  144. test(q, 'c', '>', sz)
  145. }
  146. tests := func(h string, a, b byte) {
  147. for _, sz := range []int{0, 16, 32, 64, 79} {
  148. q := strings.Repeat(" ", sz) + h
  149. for sz := 0; sz < 32; sz++ {
  150. test([]byte(q), a, b, sz)
  151. }
  152. }
  153. }
  154. test(nil, '<', '>', 1)
  155. test([]byte{}, '<', '>', 1)
  156. tests("", '<', '>')
  157. tests("a", 0, 0)
  158. tests("a", '<', '>')
  159. tests("dsdfsfa", '1', 'a')
  160. tests("xa", 'a', 'a')
  161. tests("bbb", 'a', '1')
  162. tests("bba", 'a', '<')
  163. tests("baa", '>', 'a')
  164. c0test := func(haystack []byte) {
  165. var actual int
  166. safe_haystack := append(bytes.Repeat([]byte{'<'}, 64), haystack...)
  167. safe_haystack = append(safe_haystack, bytes.Repeat([]byte{'>'}, 64)...)
  168. haystack = safe_haystack[64 : 64+len(haystack)]
  169. expected := index_c0_scalar(haystack)
  170. for _, sz := range sizes {
  171. switch sz {
  172. case 16:
  173. actual = index_c0_asm_128(haystack)
  174. case 32:
  175. actual = index_c0_asm_256(haystack)
  176. }
  177. if actual != expected {
  178. t.Fatalf("C0 char index failed in: %#v (%d != %d) at size: %d", string(haystack), expected, actual, sz)
  179. }
  180. }
  181. }
  182. c0tests := func(h string) {
  183. c0test([]byte(h))
  184. for _, sz := range []int{16, 32, 64, 79} {
  185. q := strings.Repeat(" ", sz) + h
  186. c0test([]byte(q))
  187. }
  188. }
  189. c0tests("a\nfgdfgd\r")
  190. c0tests("")
  191. c0tests("abcdef")
  192. c0tests("afsgdfg\x7f")
  193. c0tests("afgd\x1bfgd\t")
  194. c0tests("a\x00")
  195. index_test := func(haystack []byte, needle byte) {
  196. var actual int
  197. expected := index_byte_scalar(haystack, needle)
  198. for _, sz := range sizes {
  199. switch sz {
  200. case 16:
  201. actual = index_byte_asm_128(haystack, needle)
  202. case 32:
  203. actual = index_byte_asm_256(haystack, needle)
  204. }
  205. if actual != expected {
  206. t.Fatalf("index failed in: %#v (%d != %d) at size: %d with needle: %#v", string(haystack), expected, actual, sz, needle)
  207. }
  208. }
  209. }
  210. index_test([]byte("abc"), 'x')
  211. index_test([]byte("abc"), 'b')
  212. }
  213. func TestIntrinsics(t *testing.T) {
  214. switch runtime.GOARCH {
  215. case "amd64":
  216. if !HasSIMD128Code {
  217. t.Fatal("SIMD 128bit code not built")
  218. }
  219. if !HasSIMD256Code {
  220. t.Fatal("SIMD 256bit code not built")
  221. }
  222. case "arm64":
  223. if !HasSIMD128Code {
  224. t.Fatal("SIMD 128bit code not built")
  225. }
  226. if !Have128bit {
  227. t.Fatal("SIMD 128bit support not available at runtime")
  228. }
  229. }
  230. ae := func(sz int, func_name string, a, b any) {
  231. if s := cmp.Diff(a, b); s != "" {
  232. t.Fatalf("%s failed with size: %d\n%s", func_name, sz, s)
  233. }
  234. }
  235. tests := []func(int){}
  236. tests = append(tests, func(sz int) {
  237. a := ordered_bytes(sz)
  238. ae(sz, `load_test`, a, test_load(a))
  239. })
  240. tests = append(tests, func(sz int) {
  241. for _, b := range []byte{1, 0b110111, 0xff, 0, ' '} {
  242. ae(sz, `set1_epi8_test`, broadcast_byte(b, sz), test_set1_epi8(b, sz))
  243. }
  244. ae(sz, `set1_epi8_test`, broadcast_byte(0xff, sz), test_set1_epi8(11, sz))
  245. })
  246. tests = append(tests, func(sz int) {
  247. a := ordered_bytes(sz)
  248. b := ordered_bytes(sz)
  249. ans := test_cmpeq_epi8(a, b)
  250. ae(sz, `cmpeq_epi8_test`, broadcast_byte(0xff, sz), ans)
  251. lt := func(a, b []byte) []byte {
  252. ans := make([]byte, len(a))
  253. for i := range ans {
  254. if int8(a[i]) < int8(b[i]) {
  255. ans[i] = 0xff
  256. }
  257. }
  258. return ans
  259. }
  260. ae(sz, "cmplt_epi8_test with equal vecs of non-negative numbers", lt(a, b), test_cmplt_epi8(t, a, b))
  261. a = broadcast_byte(1, sz)
  262. b = broadcast_byte(2, sz)
  263. ae(sz, "cmplt_epi8_test with 1 and 2", lt(a, b), test_cmplt_epi8(t, a, b))
  264. ae(sz, "cmplt_epi8_test with 2 and 1", lt(b, a), test_cmplt_epi8(t, b, a))
  265. a = broadcast_byte(0xff, sz)
  266. b = broadcast_byte(0, sz)
  267. ae(sz, "cmplt_epi8_test with -1 and 0", lt(a, b), test_cmplt_epi8(t, a, b))
  268. })
  269. tests = append(tests, func(sz int) {
  270. a := make([]byte, sz)
  271. b := make([]byte, sz)
  272. c := make([]byte, sz)
  273. a[0] = 0xff
  274. b[0] = 0xff
  275. b[1] = 0xff
  276. a[sz-1] = 1
  277. b[sz-1] = 2
  278. for i := range c {
  279. c[i] = a[i] | b[i]
  280. }
  281. ans := test_or(a, b)
  282. ae(sz, `or_test`, c, ans)
  283. })
  284. tests = append(tests, func(sz int) {
  285. a := make([]byte, sz)
  286. if e := test_jump_if_zero(a); e != 0 {
  287. t.Fatalf("Did not detect zero register")
  288. }
  289. for i := 0; i < sz; i++ {
  290. a = make([]byte, sz)
  291. a[i] = 1
  292. if e := test_jump_if_zero(a); e != 1 {
  293. t.Fatalf("Did not detect non-zero register")
  294. }
  295. }
  296. })
  297. tests = append(tests, func(sz int) {
  298. a := ordered_bytes(sz)
  299. if e := test_count_to_match(a, 77); e != -1 {
  300. t.Fatalf("Unexpectedly found byte at: %d", e)
  301. }
  302. for i := 0; i < sz; i++ {
  303. if e := test_count_to_match(a, byte(i)); e != i {
  304. t.Fatalf("Failed to find the byte: %d (%d != %d)", i, i, e)
  305. }
  306. }
  307. a[7] = 0x34
  308. if e := test_count_to_match(a, 0x34); e != 7 {
  309. t.Fatalf("Failed to find the byte: %d (%d != %d)", 0x34, 7, e)
  310. }
  311. })
  312. sizes := get_sizes(t)
  313. for _, sz := range sizes {
  314. for _, test := range tests {
  315. test(sz)
  316. }
  317. }
  318. }