// vfloat16_avx512.h — 16-wide AVX-512 single-precision float vector type (vfloat16)
  1. // Copyright 2009-2021 Intel Corporation
  2. // SPDX-License-Identifier: Apache-2.0
  3. #pragma once
  4. #define vboolf vboolf_impl
  5. #define vboold vboold_impl
  6. #define vint vint_impl
  7. #define vuint vuint_impl
  8. #define vllong vllong_impl
  9. #define vfloat vfloat_impl
  10. #define vdouble vdouble_impl
  11. namespace embree
  12. {
  13. /* 16-wide AVX-512 float type */
/* 16-wide AVX-512 float type */
template<>
struct vfloat<16>
{
  ALIGNED_STRUCT_(64); // 64-byte alignment so plain loads/stores map to aligned zmm accesses

  typedef vboolf16 Bool;
  typedef vint16 Int;
  typedef vfloat16 Float;

  enum { size = 16 }; // number of SIMD elements

  union { // data: one zmm register with per-lane float/int views
    __m512 v;
    float f[16];
    int i[16];
  };

  ////////////////////////////////////////////////////////////////////////////////
  /// Constructors, Assignment & Cast Operators
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline vfloat() {}
  __forceinline vfloat(const vfloat16& t) { v = t; }
  __forceinline vfloat16& operator =(const vfloat16& f) { v = f.v; return *this; }

  __forceinline vfloat(const __m512& t) { v = t; }
  __forceinline operator __m512() const { return v; }
  __forceinline operator __m256() const { return _mm512_castps512_ps256(v); } // lower 8 lanes, no-op cast
  __forceinline operator __m128() const { return _mm512_castps512_ps128(v); } // lower 4 lanes, no-op cast

  /* broadcast a scalar into all 16 lanes */
  __forceinline vfloat(float f) {
    v = _mm512_set1_ps(f);
  }

  /* replicate the 4-element pattern (a,b,c,d) into each 128-bit quarter */
  __forceinline vfloat(float a, float b, float c, float d) {
    v = _mm512_set4_ps(a, b, c, d);
  }

  /* broadcast a 4-wide vector into all four 128-bit quarters */
  __forceinline vfloat(const vfloat4& i) {
    v = _mm512_broadcast_f32x4(i);
  }

  /* concatenate four 4-wide vectors: lanes [a0..a3 b0..b3 c0..c3 d0..d3] */
  __forceinline vfloat(const vfloat4& a, const vfloat4& b, const vfloat4& c, const vfloat4& d) {
    v = _mm512_castps128_ps512(a);
    v = _mm512_insertf32x4(v, b, 1);
    v = _mm512_insertf32x4(v, c, 2);
    v = _mm512_insertf32x4(v, d, 3);
  }

  /* broadcast a everywhere, then overwrite the mask-selected lanes with broadcast b */
  __forceinline vfloat(const vboolf16& mask, const vfloat4& a, const vfloat4& b) {
    v = _mm512_broadcast_f32x4(a);
    v = _mm512_mask_broadcast_f32x4(v,mask,b);
  }

  /* broadcast an 8-wide vector into both 256-bit halves
     (routed through the f64x4 domain, which needs no AVX512DQ) */
  __forceinline vfloat(const vfloat8& i) {
    v = _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_castps_pd(i)));
  }

  /* concatenate two 8-wide vectors: lanes [a0..a7 b0..b7] */
  __forceinline vfloat(const vfloat8& a, const vfloat8& b) {
    v = _mm512_castps256_ps512(a);
#if defined(__AVX512DQ__)
    v = _mm512_insertf32x8(v, b, 1);
#else
    /* emulate insertf32x8 via the 64-bit double domain when AVX512DQ is unavailable */
    v = _mm512_castpd_ps(_mm512_insertf64x4(_mm512_castps_pd(v), _mm256_castps_pd(b), 1));
#endif
  }

  /* WARNING: due to f64x4 the mask is considered as an 8bit mask */
  /*__forceinline vfloat(const vboolf16& mask, const vfloat8& a, const vfloat8& b) {
    __m512d aa = _mm512_broadcast_f64x4(_mm256_castps_pd(a));
    aa = _mm512_mask_broadcast_f64x4(aa,mask,_mm256_castps_pd(b));
    v = _mm512_castpd_ps(aa);
  }*/

  /* value conversion from signed 32-bit integers */
  __forceinline explicit vfloat(const vint16& a) {
    v = _mm512_cvtepi32_ps(a);
  }

  /* value conversion from unsigned 32-bit integers */
  __forceinline explicit vfloat(const vuint16& a) {
    v = _mm512_cvtepu32_ps(a);
  }

  ////////////////////////////////////////////////////////////////////////////////
  /// Constants
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline vfloat(ZeroTy) : v(_mm512_setzero_ps()) {}
  __forceinline vfloat(OneTy) : v(_mm512_set1_ps(1.0f)) {}
  __forceinline vfloat(PosInfTy) : v(_mm512_set1_ps(pos_inf)) {}
  __forceinline vfloat(NegInfTy) : v(_mm512_set1_ps(neg_inf)) {}
  __forceinline vfloat(StepTy) : v(_mm512_set_ps(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)) {} // lane i holds float(i)
  __forceinline vfloat(NaNTy) : v(_mm512_set1_ps(nan)) {}
  __forceinline vfloat(UndefinedTy) : v(_mm512_undefined_ps()) {} // contents unspecified; avoids a zeroing instruction

  ////////////////////////////////////////////////////////////////////////////////
  /// Loads and Stores
  ////////////////////////////////////////////////////////////////////////////////

  /* load: ptr must be 64-byte aligned; loadu: no alignment requirement.
     Masked variants return 0.0f in inactive lanes. */
  static __forceinline vfloat16 load (const void* ptr) { return _mm512_load_ps((float*)ptr); }
  static __forceinline vfloat16 loadu(const void* ptr) { return _mm512_loadu_ps((float*)ptr); }

  static __forceinline vfloat16 load (const vboolf16& mask, const void* ptr) { return _mm512_mask_load_ps (_mm512_setzero_ps(),mask,(float*)ptr); }
  static __forceinline vfloat16 loadu(const vboolf16& mask, const void* ptr) { return _mm512_mask_loadu_ps(_mm512_setzero_ps(),mask,(float*)ptr); }

  /* store: ptr must be 64-byte aligned; storeu: unaligned.
     Masked variants leave memory of inactive lanes untouched. */
  static __forceinline void store (void* ptr, const vfloat16& v) { _mm512_store_ps ((float*)ptr,v); }
  static __forceinline void storeu(void* ptr, const vfloat16& v) { _mm512_storeu_ps((float*)ptr,v); }

  static __forceinline void store (const vboolf16& mask, void* ptr, const vfloat16& v) { _mm512_mask_store_ps ((float*)ptr,mask,v); }
  static __forceinline void storeu(const vboolf16& mask, void* ptr, const vfloat16& v) { _mm512_mask_storeu_ps((float*)ptr,mask,v); }

  /* non-temporal (streaming, cache-bypassing) store; ptr must be 64-byte aligned */
  static __forceinline void store_nt(void* __restrict__ ptr, const vfloat16& a) {
    _mm512_stream_ps((float*)ptr,a);
  }

  static __forceinline vfloat16 broadcast(const float* f) {
    return _mm512_set1_ps(*f);
  }

  /* gather: lane j reads *(float*)((char*)ptr + scale*index[j]) */
  template<int scale = 4>
  static __forceinline vfloat16 gather(const float* ptr, const vint16& index) {
    return _mm512_i32gather_ps(index, ptr, scale);
  }

  /* masked gather: inactive lanes return 0.0f and their addresses are not touched */
  template<int scale = 4>
  static __forceinline vfloat16 gather(const vboolf16& mask, const float* ptr, const vint16& index) {
    vfloat16 r = zero;
    return _mm512_mask_i32gather_ps(r, mask, index, ptr, scale);
  }

  /* scatter: lane j writes v[j] to (char*)ptr + scale*index[j] */
  template<int scale = 4>
  static __forceinline void scatter(float* ptr, const vint16& index, const vfloat16& v) {
    _mm512_i32scatter_ps(ptr, index, v, scale);
  }

  /* masked scatter: only active lanes write */
  template<int scale = 4>
  static __forceinline void scatter(const vboolf16& mask, float* ptr, const vint16& index, const vfloat16& v) {
    _mm512_mask_i32scatter_ps(ptr, mask, index, v, scale);
  }

  ////////////////////////////////////////////////////////////////////////////////
  /// Array Access
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline float& operator [](size_t index) { assert(index < 16); return f[index]; }
  __forceinline const float& operator [](size_t index) const { assert(index < 16); return f[index]; }
};
  129. ////////////////////////////////////////////////////////////////////////////////
  130. /// Unary Operators
  131. ////////////////////////////////////////////////////////////////////////////////
/* bitwise reinterpretation (no value conversion) between float and integer lanes */
__forceinline vfloat16 asFloat(const vint16& a) { return _mm512_castsi512_ps(a); }
__forceinline vint16 asInt (const vfloat16& a) { return _mm512_castps_si512(a); }
__forceinline vuint16 asUInt (const vfloat16& a) { return _mm512_castps_si512(a); }

/* value conversions, routed through the explicit vint16/vfloat16 constructors */
__forceinline vint16 toInt (const vfloat16& a) { return vint16(a); }
__forceinline vfloat16 toFloat(const vint16& a) { return vfloat16(a); }

__forceinline vfloat16 operator +(const vfloat16& a) { return a; } // unary plus: identity
  138. __forceinline vfloat16 operator -(const vfloat16& a) { return _mm512_mul_ps(a,vfloat16(-1)); }
/* abs: clear the sign bit; signmsk: isolate the sign bit (all other bits zero) */
__forceinline vfloat16 abs (const vfloat16& a) { return _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a),_mm512_set1_epi32(0x7FFFFFFF))); }
__forceinline vfloat16 signmsk(const vfloat16& a) { return _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a),_mm512_set1_epi32(0x80000000))); }
/* approximate reciprocal: 14-bit hardware estimate refined with one
   Newton-Raphson iteration */
__forceinline vfloat16 rcp(const vfloat16& a)
{
  const vfloat16 r = _mm512_rcp14_ps(a);
  return _mm512_fmadd_ps(r, _mm512_fnmadd_ps(a, r, vfloat16(1.0)), r); // computes r + r * (1 - a*r)
}

__forceinline vfloat16 sqr (const vfloat16& a) { return _mm512_mul_ps(a,a); }
__forceinline vfloat16 sqrt(const vfloat16& a) { return _mm512_sqrt_ps(a); } // full-precision IEEE sqrt

/* approximate reciprocal square root: 14-bit hardware estimate refined with one
   Newton-Raphson iteration, r' = 1.5*r + (a * -0.5) * r * (r*r) */
__forceinline vfloat16 rsqrt(const vfloat16& a)
{
  const vfloat16 r = _mm512_rsqrt14_ps(a);
  return _mm512_fmadd_ps(_mm512_set1_ps(1.5f), r,
                         _mm512_mul_ps(_mm512_mul_ps(_mm512_mul_ps(a, _mm512_set1_ps(-0.5f)), r), _mm512_mul_ps(r, r)));
}
  154. ////////////////////////////////////////////////////////////////////////////////
  155. /// Binary Operators
  156. ////////////////////////////////////////////////////////////////////////////////
/* lane-wise arithmetic; scalar overloads broadcast the scalar operand */
__forceinline vfloat16 operator +(const vfloat16& a, const vfloat16& b) { return _mm512_add_ps(a, b); }
__forceinline vfloat16 operator +(const vfloat16& a, float b) { return a + vfloat16(b); }
__forceinline vfloat16 operator +(float a, const vfloat16& b) { return vfloat16(a) + b; }

__forceinline vfloat16 operator -(const vfloat16& a, const vfloat16& b) { return _mm512_sub_ps(a, b); }
__forceinline vfloat16 operator -(const vfloat16& a, float b) { return a - vfloat16(b); }
__forceinline vfloat16 operator -(float a, const vfloat16& b) { return vfloat16(a) - b; }

__forceinline vfloat16 operator *(const vfloat16& a, const vfloat16& b) { return _mm512_mul_ps(a, b); }
__forceinline vfloat16 operator *(const vfloat16& a, float b) { return a * vfloat16(b); }
__forceinline vfloat16 operator *(float a, const vfloat16& b) { return vfloat16(a) * b; }

__forceinline vfloat16 operator /(const vfloat16& a, const vfloat16& b) { return _mm512_div_ps(a,b); } // full-precision divide (not rcp-based)
__forceinline vfloat16 operator /(const vfloat16& a, float b) { return a/vfloat16(b); }
__forceinline vfloat16 operator /(float a, const vfloat16& b) { return vfloat16(a)/b; }

/* bitwise operations on the raw float bit patterns */
__forceinline vfloat16 operator &(const vfloat16& a, const vfloat16& b) { return _mm512_and_ps(a,b); }
__forceinline vfloat16 operator |(const vfloat16& a, const vfloat16& b) { return _mm512_or_ps(a,b); }
__forceinline vfloat16 operator ^(const vfloat16& a, const vfloat16& b) {
  return _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a),_mm512_castps_si512(b)));
}

/* lane-wise min/max using the FP min/max instructions */
__forceinline vfloat16 min(const vfloat16& a, const vfloat16& b) { return _mm512_min_ps(a,b); }
__forceinline vfloat16 min(const vfloat16& a, float b) { return _mm512_min_ps(a,vfloat16(b)); }
__forceinline vfloat16 min(const float& a, const vfloat16& b) { return _mm512_min_ps(vfloat16(a),b); }

__forceinline vfloat16 max(const vfloat16& a, const vfloat16& b) { return _mm512_max_ps(a,b); }
__forceinline vfloat16 max(const vfloat16& a, float b) { return _mm512_max_ps(a,vfloat16(b)); }
__forceinline vfloat16 max(const float& a, const vfloat16& b) { return _mm512_max_ps(vfloat16(a),b); }

/* mini/maxi: min/max computed with *integer* compares on the float bit patterns.
   NOTE(review): IEEE-754 bit patterns are only monotone for non-negative values,
   so this ordering matches the float ordering only when inputs are >= 0 —
   presumably callers guarantee that; confirm before using with signed data. */
__forceinline vfloat16 mini(const vfloat16& a, const vfloat16& b) {
  const vint16 ai = _mm512_castps_si512(a);
  const vint16 bi = _mm512_castps_si512(b);
  const vint16 ci = _mm512_min_epi32(ai,bi);
  return _mm512_castsi512_ps(ci);
}

__forceinline vfloat16 maxi(const vfloat16& a, const vfloat16& b) {
  const vint16 ai = _mm512_castps_si512(a);
  const vint16 bi = _mm512_castps_si512(b);
  const vint16 ci = _mm512_max_epi32(ai,bi);
  return _mm512_castsi512_ps(ci);
}
  192. ////////////////////////////////////////////////////////////////////////////////
  193. /// Ternary Operators
  194. ////////////////////////////////////////////////////////////////////////////////
/* fused multiply-add variants (single rounding):
   madd = a*b+c, msub = a*b-c, nmadd = -(a*b)+c, nmsub = -(a*b)-c */
__forceinline vfloat16 madd (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmadd_ps(a,b,c); }
__forceinline vfloat16 msub (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmsub_ps(a,b,c); }
__forceinline vfloat16 nmadd(const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fnmadd_ps(a,b,c); }
__forceinline vfloat16 nmsub(const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fnmsub_ps(a,b,c); }
  199. ////////////////////////////////////////////////////////////////////////////////
  200. /// Assignment Operators
  201. ////////////////////////////////////////////////////////////////////////////////
/* compound assignment, defined in terms of the binary operators above */
__forceinline vfloat16& operator +=(vfloat16& a, const vfloat16& b) { return a = a + b; }
__forceinline vfloat16& operator +=(vfloat16& a, float b) { return a = a + b; }

__forceinline vfloat16& operator -=(vfloat16& a, const vfloat16& b) { return a = a - b; }
__forceinline vfloat16& operator -=(vfloat16& a, float b) { return a = a - b; }

__forceinline vfloat16& operator *=(vfloat16& a, const vfloat16& b) { return a = a * b; }
__forceinline vfloat16& operator *=(vfloat16& a, float b) { return a = a * b; }

__forceinline vfloat16& operator /=(vfloat16& a, const vfloat16& b) { return a = a / b; }
__forceinline vfloat16& operator /=(vfloat16& a, float b) { return a = a / b; }
  210. ////////////////////////////////////////////////////////////////////////////////
  211. /// Comparison Operators + Select
  212. ////////////////////////////////////////////////////////////////////////////////
/* lane-wise comparisons; each produces a 16-bit predicate mask (vboolf16) */
__forceinline vboolf16 operator ==(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_EQ); }
__forceinline vboolf16 operator ==(const vfloat16& a, float b) { return a == vfloat16(b); }
__forceinline vboolf16 operator ==(float a, const vfloat16& b) { return vfloat16(a) == b; }

__forceinline vboolf16 operator !=(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_NE); }
__forceinline vboolf16 operator !=(const vfloat16& a, float b) { return a != vfloat16(b); }
__forceinline vboolf16 operator !=(float a, const vfloat16& b) { return vfloat16(a) != b; }

__forceinline vboolf16 operator < (const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LT); }
__forceinline vboolf16 operator < (const vfloat16& a, float b) { return a < vfloat16(b); }
__forceinline vboolf16 operator < (float a, const vfloat16& b) { return vfloat16(a) < b; }

__forceinline vboolf16 operator >=(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GE); }
__forceinline vboolf16 operator >=(const vfloat16& a, float b) { return a >= vfloat16(b); }
__forceinline vboolf16 operator >=(float a, const vfloat16& b) { return vfloat16(a) >= b; }

__forceinline vboolf16 operator > (const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GT); }
__forceinline vboolf16 operator > (const vfloat16& a, float b) { return a > vfloat16(b); }
__forceinline vboolf16 operator > (float a, const vfloat16& b) { return vfloat16(a) > b; }

__forceinline vboolf16 operator <=(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LE); }
__forceinline vboolf16 operator <=(const vfloat16& a, float b) { return a <= vfloat16(b); }
__forceinline vboolf16 operator <=(float a, const vfloat16& b) { return vfloat16(a) <= b; }

/* named comparison functions (same predicates as the operators above) */
__forceinline vboolf16 eq(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_EQ); }
__forceinline vboolf16 ne(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_NE); }
__forceinline vboolf16 lt(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LT); }
__forceinline vboolf16 ge(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GE); }
__forceinline vboolf16 gt(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GT); }
__forceinline vboolf16 le(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LE); }

/* masked comparisons: result bit is set only where mask is set AND the predicate holds */
__forceinline vboolf16 eq(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_EQ); }
__forceinline vboolf16 ne(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_NE); }
__forceinline vboolf16 lt(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_LT); }
__forceinline vboolf16 ge(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_GE); }
__forceinline vboolf16 gt(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_GT); }
__forceinline vboolf16 le(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_LE); }

/* per-lane select: lane j = s[j] ? t[j] : f[j]
   (note blend argument order: mask-set lanes take the *second* operand) */
__forceinline vfloat16 select(const vboolf16& s, const vfloat16& t, const vfloat16& f) {
  return _mm512_mask_blend_ps(s, f, t);
}

/* linear interpolation a + t*(b-a); exact at t=0, may differ slightly from b at t=1 */
__forceinline vfloat16 lerp(const vfloat16& a, const vfloat16& b, const vfloat16& t) {
  return madd(t,b-a,a);
}
  249. ////////////////////////////////////////////////////////////////////////////////
  250. /// Rounding Functions
  251. ////////////////////////////////////////////////////////////////////////////////
/* round toward negative infinity */
__forceinline vfloat16 floor(const vfloat16& a) {
  return _mm512_floor_ps(a);
}

/* round toward positive infinity */
__forceinline vfloat16 ceil (const vfloat16& a) {
  return _mm512_ceil_ps(a);
}

/* round to nearest integer (float result), exceptions suppressed */
__forceinline vfloat16 round (const vfloat16& a) {
  return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
}

/* floor and convert to int32 in one instruction (round toward -inf) */
__forceinline vint16 floori (const vfloat16& a) {
  return _mm512_cvt_roundps_epi32(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
}
  264. ////////////////////////////////////////////////////////////////////////////////
  265. /// Movement/Shifting/Shuffling Functions
  266. ////////////////////////////////////////////////////////////////////////////////
/* interleave the low/high halves of each 128-bit quarter of a and b */
__forceinline vfloat16 unpacklo(const vfloat16& a, const vfloat16& b) { return _mm512_unpacklo_ps(a, b); }
__forceinline vfloat16 unpackhi(const vfloat16& a, const vfloat16& b) { return _mm512_unpackhi_ps(a, b); }

/* broadcast element i within each 128-bit quarter */
template<int i>
__forceinline vfloat16 shuffle(const vfloat16& v) {
  return _mm512_permute_ps(v, _MM_SHUFFLE(i, i, i, i));
}

/* permute the 4 elements of each 128-bit quarter by (i0,i1,i2,i3) */
template<int i0, int i1, int i2, int i3>
__forceinline vfloat16 shuffle(const vfloat16& v) {
  return _mm512_permute_ps(v, _MM_SHUFFLE(i3, i2, i1, i0));
}

/* broadcast 128-bit quarter i to all four quarters */
template<int i>
__forceinline vfloat16 shuffle4(const vfloat16& v) {
  return _mm512_shuffle_f32x4(v, v ,_MM_SHUFFLE(i, i, i, i));
}

/* permute whole 128-bit quarters by (i0,i1,i2,i3) */
template<int i0, int i1, int i2, int i3>
__forceinline vfloat16 shuffle4(const vfloat16& v) {
  return _mm512_shuffle_f32x4(v, v, _MM_SHUFFLE(i3, i2, i1, i0));
}

/* interleave4_even: gather the even 128-bit quarters of a and b -> [A0 B0 A2 B2]
   (mask 0xcc picks qwords 2,3,6,7; imm 0x4e swaps 128-bit halves within each 256-bit lane) */
__forceinline vfloat16 interleave4_even(const vfloat16& a, const vfloat16& b) {
  return _mm512_castsi512_ps(_mm512_mask_permutex_epi64(_mm512_castps_si512(a), mm512_int2mask(0xcc), _mm512_castps_si512(b), (_MM_PERM_ENUM)0x4e));
}

/* interleave4_odd: gather the odd 128-bit quarters of a and b -> [A1 B1 A3 B3] */
__forceinline vfloat16 interleave4_odd(const vfloat16& a, const vfloat16& b) {
  return _mm512_castsi512_ps(_mm512_mask_permutex_epi64(_mm512_castps_si512(b), mm512_int2mask(0x33), _mm512_castps_si512(a), (_MM_PERM_ENUM)0x4e));
}

/* arbitrary full-width lane permutation: result lane j = v[index[j]] */
__forceinline vfloat16 permute(vfloat16 v, __m512i index) {
  return _mm512_castsi512_ps(_mm512_permutexvar_epi32(index, _mm512_castps_si512(v)));
}

/* reverse the order of all 16 lanes */
__forceinline vfloat16 reverse(const vfloat16& v) {
  return permute(v,_mm512_setr_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
}
  297. template<int i>
  298. __forceinline vfloat16 align_shift_right(const vfloat16& a, const vfloat16& b) {
  299. return _mm512_castsi512_ps(_mm512_alignr_epi32(_mm512_castps_si512(a),_mm512_castps_si512(b),i));
  300. };
  301. template<int i>
  302. __forceinline vfloat16 mask_align_shift_right(const vboolf16& mask, vfloat16& c, const vfloat16& a, const vfloat16& b) {
  303. return _mm512_castsi512_ps(_mm512_mask_alignr_epi32(_mm512_castps_si512(c),mask,_mm512_castps_si512(a),_mm512_castps_si512(b),i));
  304. };
/* shift lanes toward higher indices by one: result[j] = a[j-1], result[0] = 0 */
__forceinline vfloat16 shift_left_1(const vfloat16& a) {
  vfloat16 z = zero;
  return mask_align_shift_right<15>(0xfffe,z,a,a); // mask 0xfffe keeps the zero in lane 0
}

/* shift lanes toward lower indices by one: result[j] = x[j+1], result[15] = 0 */
__forceinline vfloat16 shift_right_1(const vfloat16& x) {
  return align_shift_right<1>(zero,x);
}
/* return lane 0 as a scalar */
__forceinline float toScalar(const vfloat16& v) { return mm512_cvtss_f32(v); }

/* replace 128-bit quarter i of a with b */
template<int i> __forceinline vfloat16 insert4(const vfloat16& a, const vfloat4& b) { return _mm512_insertf32x4(a, b, i); }

/* extract the i-th N-wide sub-vector; index-0 specializations are free casts */
template<int N, int i>
vfloat<N> extractN(const vfloat16& v);

template<> __forceinline vfloat4 extractN<4,0>(const vfloat16& v) { return _mm512_castps512_ps128(v); }
template<> __forceinline vfloat4 extractN<4,1>(const vfloat16& v) { return _mm512_extractf32x4_ps(v, 1); }
template<> __forceinline vfloat4 extractN<4,2>(const vfloat16& v) { return _mm512_extractf32x4_ps(v, 2); }
template<> __forceinline vfloat4 extractN<4,3>(const vfloat16& v) { return _mm512_extractf32x4_ps(v, 3); }

/* NOTE(review): _mm512_extractf32x8_ps requires AVX512DQ, but unlike the
   vfloat8 pair constructor above these uses are not #ifdef-guarded — confirm
   the build always targets a DQ-capable ISA when these are instantiated. */
template<> __forceinline vfloat8 extractN<8,0>(const vfloat16& v) { return _mm512_castps512_ps256(v); }
template<> __forceinline vfloat8 extractN<8,1>(const vfloat16& v) { return _mm512_extractf32x8_ps(v, 1); }

template<int i> __forceinline vfloat4 extract4 (const vfloat16& v) { return _mm512_extractf32x4_ps(v, i); }
template<> __forceinline vfloat4 extract4<0>(const vfloat16& v) { return _mm512_castps512_ps128(v); }

template<int i> __forceinline vfloat8 extract8 (const vfloat16& v) { return _mm512_extractf32x8_ps(v, i); }
template<> __forceinline vfloat8 extract8<0>(const vfloat16& v) { return _mm512_castps512_ps256(v); }
  326. ////////////////////////////////////////////////////////////////////////////////
  327. /// Transpose
  328. ////////////////////////////////////////////////////////////////////////////////
/* transpose four 4x4 sub-matrices in parallel: within each 128-bit quarter,
   (r0,r1,r2,r3) columns become (c0,c1,c2,c3) rows, via the classic
   unpacklo/unpackhi butterfly */
__forceinline void transpose(const vfloat16& r0, const vfloat16& r1, const vfloat16& r2, const vfloat16& r3,
                             vfloat16& c0, vfloat16& c1, vfloat16& c2, vfloat16& c3)
{
  vfloat16 a0a2_b0b2 = unpacklo(r0, r2);
  vfloat16 c0c2_d0d2 = unpackhi(r0, r2);
  vfloat16 a1a3_b1b3 = unpacklo(r1, r3);
  vfloat16 c1c3_d1d3 = unpackhi(r1, r3);
  c0 = unpacklo(a0a2_b0b2, a1a3_b1b3);
  c1 = unpackhi(a0a2_b0b2, a1a3_b1b3);
  c2 = unpacklo(c0c2_d0d2, c1c3_d1d3);
  c3 = unpackhi(c0c2_d0d2, c1c3_d1d3);
}
/* transpose a 16x4 matrix given as sixteen vfloat4 rows: packs groups of four
   rows into vfloat16s and defers to the 4x4-butterfly transpose above */
__forceinline void transpose(const vfloat4& r0, const vfloat4& r1, const vfloat4& r2, const vfloat4& r3,
                             const vfloat4& r4, const vfloat4& r5, const vfloat4& r6, const vfloat4& r7,
                             const vfloat4& r8, const vfloat4& r9, const vfloat4& r10, const vfloat4& r11,
                             const vfloat4& r12, const vfloat4& r13, const vfloat4& r14, const vfloat4& r15,
                             vfloat16& c0, vfloat16& c1, vfloat16& c2, vfloat16& c3)
{
  return transpose(vfloat16(r0, r4, r8, r12), vfloat16(r1, r5, r9, r13), vfloat16(r2, r6, r10, r14), vfloat16(r3, r7, r11, r15),
                   c0, c1, c2, c3);
}
/* transpose two 8x8 sub-matrices in parallel: two 4x4-block transposes followed
   by even/odd 128-bit-quarter interleaves to stitch the halves together */
__forceinline void transpose(const vfloat16& r0, const vfloat16& r1, const vfloat16& r2, const vfloat16& r3,
                             const vfloat16& r4, const vfloat16& r5, const vfloat16& r6, const vfloat16& r7,
                             vfloat16& c0, vfloat16& c1, vfloat16& c2, vfloat16& c3,
                             vfloat16& c4, vfloat16& c5, vfloat16& c6, vfloat16& c7)
{
  vfloat16 a0a1a2a3_e0e1e2e3, b0b1b2b3_f0f1f2f3, c0c1c2c3_g0g1g2g3, d0d1d2d3_h0h1h2h3;
  transpose(r0, r1, r2, r3, a0a1a2a3_e0e1e2e3, b0b1b2b3_f0f1f2f3, c0c1c2c3_g0g1g2g3, d0d1d2d3_h0h1h2h3);
  vfloat16 a4a5a6a7_e4e5e6e7, b4b5b6b7_f4f5f6f7, c4c5c6c7_g4g5g6g7, d4d5d6d7_h4h5h6h7;
  transpose(r4, r5, r6, r7, a4a5a6a7_e4e5e6e7, b4b5b6b7_f4f5f6f7, c4c5c6c7_g4g5g6g7, d4d5d6d7_h4h5h6h7);
  c0 = interleave4_even(a0a1a2a3_e0e1e2e3, a4a5a6a7_e4e5e6e7);
  c1 = interleave4_even(b0b1b2b3_f0f1f2f3, b4b5b6b7_f4f5f6f7);
  c2 = interleave4_even(c0c1c2c3_g0g1g2g3, c4c5c6c7_g4g5g6g7);
  c3 = interleave4_even(d0d1d2d3_h0h1h2h3, d4d5d6d7_h4h5h6h7);
  c4 = interleave4_odd (a0a1a2a3_e0e1e2e3, a4a5a6a7_e4e5e6e7);
  c5 = interleave4_odd (b0b1b2b3_f0f1f2f3, b4b5b6b7_f4f5f6f7);
  c6 = interleave4_odd (c0c1c2c3_g0g1g2g3, c4c5c6c7_g4g5g6g7);
  c7 = interleave4_odd (d0d1d2d3_h0h1h2h3, d4d5d6d7_h4h5h6h7);
}
/* transpose a 16x8 matrix given as sixteen vfloat8 rows: pairs rows into
   vfloat16s and defers to the 8x8 transpose above */
__forceinline void transpose(const vfloat8& r0, const vfloat8& r1, const vfloat8& r2, const vfloat8& r3,
                             const vfloat8& r4, const vfloat8& r5, const vfloat8& r6, const vfloat8& r7,
                             const vfloat8& r8, const vfloat8& r9, const vfloat8& r10, const vfloat8& r11,
                             const vfloat8& r12, const vfloat8& r13, const vfloat8& r14, const vfloat8& r15,
                             vfloat16& c0, vfloat16& c1, vfloat16& c2, vfloat16& c3,
                             vfloat16& c4, vfloat16& c5, vfloat16& c6, vfloat16& c7)
{
  return transpose(vfloat16(r0, r8), vfloat16(r1, r9), vfloat16(r2, r10), vfloat16(r3, r11),
                   vfloat16(r4, r12), vfloat16(r5, r13), vfloat16(r6, r14), vfloat16(r7, r15),
                   c0, c1, c2, c3, c4, c5, c6, c7);
}
  379. ////////////////////////////////////////////////////////////////////////////////
  380. /// Reductions
  381. ////////////////////////////////////////////////////////////////////////////////
/* log2-step tree reductions: vreduce_*K reduces within each K-lane group and
   broadcasts the result to all lanes of the group; vreduce_* reduces all 16 */
__forceinline vfloat16 vreduce_add2(vfloat16 x) { return x + shuffle<1,0,3,2>(x); }
__forceinline vfloat16 vreduce_add4(vfloat16 x) { x = vreduce_add2(x); return x + shuffle<2,3,0,1>(x); }
__forceinline vfloat16 vreduce_add8(vfloat16 x) { x = vreduce_add4(x); return x + shuffle4<1,0,3,2>(x); }
__forceinline vfloat16 vreduce_add (vfloat16 x) { x = vreduce_add8(x); return x + shuffle4<2,3,0,1>(x); }

__forceinline vfloat16 vreduce_min2(vfloat16 x) { return min(x, shuffle<1,0,3,2>(x)); }
__forceinline vfloat16 vreduce_min4(vfloat16 x) { x = vreduce_min2(x); return min(x, shuffle<2,3,0,1>(x)); }
__forceinline vfloat16 vreduce_min8(vfloat16 x) { x = vreduce_min4(x); return min(x, shuffle4<1,0,3,2>(x)); }
__forceinline vfloat16 vreduce_min (vfloat16 x) { x = vreduce_min8(x); return min(x, shuffle4<2,3,0,1>(x)); }

__forceinline vfloat16 vreduce_max2(vfloat16 x) { return max(x, shuffle<1,0,3,2>(x)); }
__forceinline vfloat16 vreduce_max4(vfloat16 x) { x = vreduce_max2(x); return max(x, shuffle<2,3,0,1>(x)); }
__forceinline vfloat16 vreduce_max8(vfloat16 x) { x = vreduce_max4(x); return max(x, shuffle4<1,0,3,2>(x)); }
__forceinline vfloat16 vreduce_max (vfloat16 x) { x = vreduce_max8(x); return max(x, shuffle4<2,3,0,1>(x)); }

/* scalar reductions: full vector reduction, then extract lane 0 */
__forceinline float reduce_add(const vfloat16& v) { return toScalar(vreduce_add(v)); }
__forceinline float reduce_min(const vfloat16& v) { return toScalar(vreduce_min(v)); }
__forceinline float reduce_max(const vfloat16& v) { return toScalar(vreduce_max(v)); }
/* index of the first lane holding the minimum value (integer equality against
   the broadcast reduction result, then bit-scan-forward on the mask) */
__forceinline size_t select_min(const vfloat16& v) {
  return bsf(_mm512_kmov(_mm512_cmp_epi32_mask(_mm512_castps_si512(v),_mm512_castps_si512(vreduce_min(v)),_MM_CMPINT_EQ)));
}

/* index of the first lane holding the maximum value */
__forceinline size_t select_max(const vfloat16& v) {
  return bsf(_mm512_kmov(_mm512_cmp_epi32_mask(_mm512_castps_si512(v),_mm512_castps_si512(vreduce_max(v)),_MM_CMPINT_EQ)));
}

/* masked variant: inactive lanes are treated as +inf; if no valid lane matched
   (e.g. all valid lanes are +inf) fall back to the first valid lane */
__forceinline size_t select_min(const vboolf16& valid, const vfloat16& v)
{
  const vfloat16 a = select(valid,v,vfloat16(pos_inf));
  const vbool16 valid_min = valid & (a == vreduce_min(a));
  return bsf(movemask(any(valid_min) ? valid_min : valid));
}

/* masked variant: inactive lanes are treated as -inf; same fallback as above */
__forceinline size_t select_max(const vboolf16& valid, const vfloat16& v)
{
  const vfloat16 a = select(valid,v,vfloat16(neg_inf));
  const vbool16 valid_max = valid & (a == vreduce_max(a));
  return bsf(movemask(any(valid_max) ? valid_max : valid));
}
/* inclusive prefix (scan) operations in log2(16)=4 steps, using
   align_shift_right to shift in the operation's identity element
   (0 for +, +inf for min, -inf for max) */

/* inclusive prefix sum: result[j] = a[0] + ... + a[j] */
__forceinline vfloat16 prefix_sum(const vfloat16& a)
{
  const vfloat16 z(zero);
  vfloat16 v = a;
  v = v + align_shift_right<16-1>(v,z); // add value shifted up by 1 lane
  v = v + align_shift_right<16-2>(v,z); // ... by 2 lanes
  v = v + align_shift_right<16-4>(v,z); // ... by 4 lanes
  v = v + align_shift_right<16-8>(v,z); // ... by 8 lanes
  return v;
}

/* inclusive suffix sum: result[j] = a[j] + ... + a[15] */
__forceinline vfloat16 reverse_prefix_sum(const vfloat16& a)
{
  const vfloat16 z(zero);
  vfloat16 v = a;
  v = v + align_shift_right<1>(z,v); // add value shifted down by 1 lane
  v = v + align_shift_right<2>(z,v);
  v = v + align_shift_right<4>(z,v);
  v = v + align_shift_right<8>(z,v);
  return v;
}

/* inclusive prefix minimum: result[j] = min(a[0..j]) */
__forceinline vfloat16 prefix_min(const vfloat16& a)
{
  const vfloat16 z(pos_inf); // +inf is the identity of min
  vfloat16 v = a;
  v = min(v,align_shift_right<16-1>(v,z));
  v = min(v,align_shift_right<16-2>(v,z));
  v = min(v,align_shift_right<16-4>(v,z));
  v = min(v,align_shift_right<16-8>(v,z));
  return v;
}

/* inclusive prefix maximum: result[j] = max(a[0..j]) */
__forceinline vfloat16 prefix_max(const vfloat16& a)
{
  const vfloat16 z(neg_inf); // -inf is the identity of max
  vfloat16 v = a;
  v = max(v,align_shift_right<16-1>(v,z));
  v = max(v,align_shift_right<16-2>(v,z));
  v = max(v,align_shift_right<16-4>(v,z));
  v = max(v,align_shift_right<16-8>(v,z));
  return v;
}

/* inclusive suffix minimum: result[j] = min(a[j..15]) */
__forceinline vfloat16 reverse_prefix_min(const vfloat16& a)
{
  const vfloat16 z(pos_inf);
  vfloat16 v = a;
  v = min(v,align_shift_right<1>(z,v));
  v = min(v,align_shift_right<2>(z,v));
  v = min(v,align_shift_right<4>(z,v));
  v = min(v,align_shift_right<8>(z,v));
  return v;
}

/* inclusive suffix maximum: result[j] = max(a[j..15]) */
__forceinline vfloat16 reverse_prefix_max(const vfloat16& a)
{
  const vfloat16 z(neg_inf);
  vfloat16 v = a;
  v = max(v,align_shift_right<1>(z,v));
  v = max(v,align_shift_right<2>(z,v));
  v = max(v,align_shift_right<4>(z,v));
  v = max(v,align_shift_right<8>(z,v));
  return v;
}
/* reciprocal that avoids inf results: zero lanes are replaced with
   min_rcp_input before the approximate rcp */
__forceinline vfloat16 rcp_safe(const vfloat16& a) {
  return rcp(select(a != vfloat16(zero), a, vfloat16(min_rcp_input)));
}
  478. ////////////////////////////////////////////////////////////////////////////////
  479. /// Output Operators
  480. ////////////////////////////////////////////////////////////////////////////////
  481. __forceinline embree_ostream operator <<(embree_ostream cout, const vfloat16& v)
  482. {
  483. cout << "<" << v[0];
  484. for (int i=1; i<16; i++) cout << ", " << v[i];
  485. cout << ">";
  486. return cout;
  487. }
  488. }
  489. #undef vboolf
  490. #undef vboold
  491. #undef vint
  492. #undef vuint
  493. #undef vllong
  494. #undef vfloat
  495. #undef vdouble