// -*- mode: c++; coding: utf-8 -*-
/// @file ply.H
/// @brief Traverse (ply) array or array expression or array statement.

// (c) Daniel Llorens - 2013-2017
// This library is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.

// TODO Lots of room for improvement: small (fixed sizes) and large (tiling, etc. see eval.cc in Blitz++).

#pragma once
#include "ra/atom.H"
#include <functional>

namespace ra {

static_assert(std::is_signed_v<rank_t> && std::is_signed_v<dim_t>, "bad rank_t");

// --------------
// Run time order
// --------------

// Traverse array expression looking to ravel the inner loop.
// size(k) has a single value.
// adv(k), stride(k), keep_stride(step, k, l) and flat() are used on all the leaf arguments.
// The strides must give 0 for k>=their own rank, to allow frame matching.
// TODO Traversal order should be a parameter, since some operations (e.g. output, ravel) require a specific order.
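
// For illustration, a hypothetical model of the leaf-argument interface named
// above (a sketch only; these names are not types of this library):
//
//     struct TraversalSketch
//     {
//         rank_t rank() const;         // number of axes
//         dim_t size(int k) const;     // length of axis k
//         dim_t stride(int k) const;   // step along axis k; must be 0 for k>=rank
//         bool keep_stride(dim_t step, int k, int l) const;
//                                      // can axis l be raveled into axis k with this step?
//         void adv(rank_t k, dim_t d); // move the traversal cursor d places along axis k
//         auto flat();                 // cursor for the inner loop; supports *p, p += stride
//     };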

template <class A> inline
void ply_ravel(A && a)
{
    rank_t rank = a.rank();
    assert(rank>=0); // FIXME see test in [ra40].
    rank_t order[rank];
    for (rank_t i=0; i<rank; ++i) {
        order[i] = rank-1-i;
    }
    switch (rank) {
    case 0: *(a.flat()); return;
    case 1: break;
    default: // TODO find a decent heuristic
        // if (rank>1) {
        //     std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
        //               { return a.size(order[i])<a.size(order[j]); });
        // }
        ;
    }
    // find outermost compact dim.
    rank_t * ocd = order;
    auto ss = a.size(*ocd);
    for (--rank, ++ocd; rank>0 && a.keep_stride(ss, order[0], *ocd); --rank, ++ocd) {
        ss *= a.size(*ocd);
    }
    dim_t sha[rank], ind[rank];
    for (int k=0; k<rank; ++k) {
        ind[k] = 0;
        sha[k] = a.size(ocd[k]);
        if (sha[k]==0) { // for the ravelled dimensions ss takes care.
            return;
        }
        RA_ASSERT(sha[k]!=DIM_BAD, "undefined dim ", ocd[k]);
    }
    // all sub xpr strides advance in compact dims, as they might be different.
    auto const ss0 = a.stride(order[0]);
    // TODO Blitz++ uses explicit stack of end-of-dim p positions, has special cases for common/unit stride.
    for (;;) {
        dim_t s = ss;
        for (auto p=a.flat(); s>0; --s, p+=ss0) {
            *p;
        }
        for (int k=0; ; ++k) {
            if (k>=rank) {
                return;
            } else if (ind[k]<sha[k]-1) {
                ++ind[k];
                a.adv(ocd[k], 1);
                break;
            } else {
                ind[k] = 0;
                a.adv(ocd[k], 1-sha[k]);
            }
        }
    }
}
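
// For illustration, with assumed shapes: on a row-major rank-2 view of shape
// [2, 3] and strides {3, 1}, order = {1, 0}; keep_stride(3, 1, 0) holds, so
// both axes ravel and the whole traversal is one inner loop of ss = 6 steps of
// stride 1. If the same [2, 3] view is cut from a [2, 5] array (strides
// {5, 1}), keep_stride fails and the traversal falls back to a loop nest.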

// -------------------------
// Compile time order. See bench-dot.C for use. No index version.
// With compile-time recursion by rank, one can use adv<k>, but order must also be compile-time.
// -------------------------

#ifdef RA_INLINE
#error bad definition
#endif
#define RA_INLINE inline /* __attribute__((always_inline)) inline */

template <class order, int ravel_rank, class A, class S> RA_INLINE constexpr
void subindex(A & a, dim_t s, S const & ss0)
{
    if constexpr (mp::len<order> == ravel_rank) {
        for (auto p=a.flat(); s>0; --s, p+=ss0) {
            *p;
        }
    } else if constexpr (mp::len<order> > ravel_rank) {
        dim_t size = a.size(mp::first<order>::value); // TODO Precompute these at the top
        for (dim_t i=0, iend=size; i<iend; ++i) {
            subindex<mp::drop1<order>, ravel_rank>(a, s, ss0);
            a.adv(mp::first<order>::value, 1);
        }
        a.adv(mp::first<order>::value, -size);
    } else {
        abort();
    }
}
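
// For illustration, a sketch of what the recursion expands to for a rank-2
// argument with order = mp::iota<2> and ravel_rank = 1:
//
//     for (dim_t i=0; i<a.size(0); ++i) {
//         dim_t si = s;
//         for (auto p=a.flat(); si>0; --si, p+=ss0) { *p; } // axis 1, raveled
//         a.adv(0, 1);
//     }
//     a.adv(0, -a.size(0));
//
// The loop nest has compile-time depth, which the compiler can unroll.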

// until() converts runtime jj into compile time j. TODO a.adv<k>().
template <class order, int j, class A, class S> RA_INLINE constexpr
void until(int const jj, A & a, dim_t const s, S const & ss0)
{
    if constexpr (mp::len<order> < j) {
        assert(0 && "rank too high");
    } else if constexpr (mp::len<order> >= j) {
        if (jj==j) {
            subindex<order, j>(a, s, ss0);
        } else {
            until<order, j+1>(jj, a, s, ss0);
        }
    } else {
        abort();
    }
}
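
// For illustration: until<order, 0>(jj, ...) expands into a compile-time chain
// roughly equivalent to
//
//     if (jj==0) { subindex<order, 0>(a, s, ss0); }
//     else if (jj==1) { subindex<order, 1>(a, s, ss0); }
//     ... // and so on up to jj==mp::len<order>
//
// so a jj known only at run time still selects a subindex instantiation whose
// ravel rank is a compile-time constant.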

template <class A> RA_INLINE constexpr
auto plyf(A && a) -> std::enable_if_t<(std::decay_t<A>::rank_s()<=0)>
{
    static_assert(std::decay_t<A>::rank_s()==0, "plyf needs static rank");
    *(a.flat());
}

template <class A> RA_INLINE constexpr
auto plyf(A && a) -> std::enable_if_t<(std::decay_t<A>::rank_s()==1)>
{
    subindex<mp::iota<1>, 1>(a, a.size(0), a.stride(0));
}

// find the outermost compact dim.
template <class A>
constexpr auto ocd(A && a)
{
    rank_t const rank = a.rank();
    auto s = a.size(rank-1);
    int j = 1;
    while (j<rank && a.keep_stride(s, rank-1, rank-1-j)) {
        s *= a.size(rank-1-j);
        ++j;
    }
    return std::make_tuple(s, j);
}
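
// For illustration, with assumed shapes: on a row-major array of shape [4, 5]
// and strides {5, 1}, s starts at size(1) = 5; keep_stride(5, 1, 0) holds, so
// s becomes 20 and j becomes 2. ocd() returns (20, 2): the length of the
// raveled inner loop and the number of trailing axes it absorbs.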

template <class A> RA_INLINE constexpr
auto plyf(A && a) -> std::enable_if_t<(std::decay_t<A>::rank_s()>1)>
{
    constexpr rank_t rank = std::decay_t<A>::rank_s();
    // this can only be enabled when ocd() can be constexpr; size_s isn't enough b/c of keep_stride.
    // test/concrete.C has a case that shows this.
    // cf https://stackoverflow.com/questions/55288555
    if constexpr (0 && size_s<A>()>=0) {
        constexpr auto sj = ocd(a);
        constexpr auto s = std::get<0>(sj);
        constexpr auto j = std::get<1>(sj);
        // all sub xpr strides advance in compact dims, as they might be different.
        // send with static j. Note that the order here is the reverse of the order in ply_ravel.
        until<mp::iota<rank>, 0>(j, a, s, a.stride(rank-1));
    } else {
        // the unrolling above isn't worth it when s, j cannot be constexpr.
        auto s = a.size(rank-1);
        subindex<mp::iota<rank>, 1>(a, s, a.stride(rank-1));
    }
}

#undef RA_INLINE

// ---------------------------
// Select best performance (or requirements) for each type.
// ---------------------------

template <class A> inline constexpr
std::enable_if_t<(size_s<A>()==DIM_ANY)>
ply(A && a)
{
    ply_ravel(std::forward<A>(a));
}

template <class A> inline constexpr
std::enable_if_t<(size_s<A>()!=DIM_ANY)>
ply(A && a)
{
    plyf(std::forward<A>(a));
}
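
// For illustration, a sketch of how ply() serves as the entry point for a
// for_each-style wrapper (assuming the expr()/start() constructors of this
// library; the wrapper itself is hypothetical here):
//
//     template <class Op, class ... A> inline decltype(auto)
//     for_each(Op && op, A && ... a)
//     {
//         return ply(expr(std::forward<Op>(op), start(std::forward<A>(a)) ...));
//     }
//
// Expressions of dynamic size take the run time ply_ravel() path; expressions
// of static size take the compile time plyf() path.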

// ---------------------------
// Short-circuiting pliers.
// ---------------------------
// TODO Refactor with ply_ravel. Make exit available to plyf.
// TODO These are reductions. How to do higher rank?

template <class A, class DEF> inline
auto ply_ravel_exit(A && a, DEF && def)
{
    rank_t rank = a.rank();
    assert(rank>=0); // FIXME see test in [ra40].
    rank_t order[rank];
    for (rank_t i=0; i<rank; ++i) {
        order[i] = rank-1-i;
    }
    switch (rank) {
    case 0: {
        auto what = *(a.flat());
        if (std::get<0>(what)) {
            return std::get<1>(what);
        }
        return def;
    }
    case 1: break;
    default: // TODO find a decent heuristic
        // if (rank>1) {
        //     std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
        //               { return a.size(order[i])<a.size(order[j]); });
        // }
        ;
    }
    // find outermost compact dim.
    rank_t * ocd = order;
    auto ss = a.size(*ocd);
    for (--rank, ++ocd; rank>0 && a.keep_stride(ss, order[0], *ocd); --rank, ++ocd) {
        ss *= a.size(*ocd);
    }
    dim_t sha[rank], ind[rank];
    for (int k=0; k<rank; ++k) {
        ind[k] = 0;
        sha[k] = a.size(ocd[k]);
        if (sha[k]==0) { // for the ravelled dimensions ss takes care.
            return def;
        }
    }
    // all sub xpr strides advance in compact dims, as they might be different.
    auto const ss0 = a.stride(order[0]);
    // TODO Blitz++ uses explicit stack of end-of-dim p positions, has special cases for common/unit stride.
    for (;;) {
        dim_t s = ss;
        for (auto p=a.flat(); s>0; --s, p+=ss0) {
            auto what = *p;
            if (std::get<0>(what)) {
                return std::get<1>(what);
            }
        }
        for (int k=0; ; ++k) {
            if (k>=rank) {
                return def;
            } else if (ind[k]<sha[k]-1) {
                ++ind[k];
                a.adv(ocd[k], 1);
                break;
            } else {
                ind[k] = 0;
                a.adv(ocd[k], 1-sha[k]);
            }
        }
    }
}

template <class A, class DEF> inline decltype(auto)
early(A && a, DEF && def)
{
    return ply_ravel_exit(std::forward<A>(a), std::forward<DEF>(def));
}
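
// For illustration: the expression plied through early() must yield pairs
// (stop, value); traversal returns value at the first element where stop is
// true, and def otherwise. A short-circuiting any() could be sketched as
//
//     template <class A> inline bool
//     any(A && a)
//     {
//         return early(map([](bool x) { return std::make_tuple(x, x); },
//                          std::forward<A>(a)),
//                      false);
//     }
//
// where map() stands in for this library's expression constructors (the exact
// spelling is an assumption, not taken from this file).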

} // namespace ra