// (c) Daniel Llorens - 2013-2017

// This library is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.

/// @file ply.H
/// @brief Traverse (ply) array or array expression or array statement.

// TODO Lots of room for improvement: small (fixed sizes) and large (tiling, etc. see eval.cc in Blitz++).

#pragma once
#include "ra/type.H"
#include <functional>
#include <cassert>

namespace ra {

static_assert(std::is_signed<rank_t>::value && std::is_signed<dim_t>::value, "bad rank_t or dim_t");

// --------------
// Run time order, two versions.
// --------------

// TODO See ply_ravel() for traversal order.
// TODO A(i0, i1 ...) can be partially applied as A(i0)(i1 ...) for faster indexing.
// TODO Traversal order should be a parameter, since some operations (e.g. output, ravel) require a specific order.

template <class A> inline
void ply_index(A && a)
{
    rank_t const rank = a.rank();
    auto ind(ra_traits<decltype(a.shape())>::make(rank, 0));
    dim_t sha[rank];
    rank_t order[rank];
    for (rank_t k=0; k<rank; ++k) {
        order[k] = rank-1-k;
        sha[k] = a.size(order[k]);
        if (sha[k]==0) {
            return;
        }
    }
    for (;;) {
        a.at(ind);
        for (int k=0; ; ++k) {
            if (k==rank) {
                return;
            } else if (++ind[order[k]]<sha[k]) {
                break;
            } else {
                ind[order[k]] = 0;
            }
        }
    }
}
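
// A minimal sketch of the interface ply_index() requires of its argument: rank(), shape(),
// size(k), and at() on a full index vector. PrintXpr is hypothetical, not part of the library,
// and assumes ra_traits is defined for the shape() type (here std::vector<dim_t>) and that
// <iostream> and <vector> are available.
#if 0
struct PrintXpr
{
    rank_t rank() const { return 2; }
    std::vector<dim_t> shape() const { return {2, 3}; }
    dim_t size(int k) const { return shape()[k]; }
    void at(std::vector<dim_t> const & i) { std::cout << "(" << i[0] << ", " << i[1] << ")\n"; }
};
// ply_index(PrintXpr {}); // visits (0, 0), (0, 1) ... (1, 2), with the last dim moving fastest.
#endif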

// Traverse array expression looking to ravel the inner loop.
// size() is only used on the driving argument (the one with the largest rank).
// adv(), stride(), keep_stride() and flat() are used on all the leaf arguments.
// The strides must give 0 for k>=their own rank, to allow frame matching.
// TODO Traversal order should be a parameter, since some operations (e.g. output, ravel) require a specific order.

template <class A> inline
void ply_ravel(A && a)
{
    static_assert(!has_tensorindex<A>, "bad plier for expr");
    rank_t rank = a.rank();
    rank_t order[rank];
    for (rank_t i=0; i<rank; ++i) {
        order[i] = rank-1-i;
    }
    switch (rank) {
    case 0: *(a.flat()); return;
    case 1: break;
    default: // TODO find a decent heuristic
        // if (rank>1) {
        //     std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
        //               { return a.size(order[i])<a.size(order[j]); });
        // }
        ;
    }
    // find the outermost compact dim.
    rank_t * ocd = order;
    auto ss = a.size(*ocd);
    for (--rank, ++ocd; rank>0 && a.keep_stride(ss, order[0], *ocd); --rank, ++ocd) {
        ss *= a.size(*ocd);
    }
    dim_t ind[rank], sha[rank];
    for (int k=0; k<rank; ++k) {
        ind[k] = 0;
        sha[k] = a.size(ocd[k]);
        if (sha[k]==0) { // the ravelled dimensions are already accounted for in ss.
            return;
        }
    }
    // each subexpression advances by its own stride in the compact dims, since those strides may differ.
    auto const ss0 = a.stride(order[0]);
    // TODO Blitz++ uses an explicit stack of end-of-dim p positions, and special-cases common/unit strides.
    for (;;) {
        dim_t s = ss;
        for (auto p=a.flat(); s>0; --s, p+=ss0) {
            *p;
        }
        for (int k=0; ; ++k) {
            if (k>=rank) {
                return;
            } else if (ind[k]<sha[k]-1) {
                ++ind[k];
                a.adv(ocd[k], 1);
                break;
            } else {
                ind[k] = 0;
                a.adv(ocd[k], 1-sha[k]);
            }
        }
    }
}
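
// A minimal sketch of the leaf protocol that ply_ravel() drives. Sum1D is hypothetical, not part
// of the library: a rank-1 'expression' over a raw array whose visit (*p) accumulates a sum. The
// exact keep_stride() semantics are an assumption here (ravelling is allowed when strides chain up).
#if 0
struct Sum1D
{
    double * p; dim_t n; double sum = 0;
    rank_t rank() const { return 1; }
    dim_t size(int k) const { return k==0 ? n : 1; }
    dim_t stride(int k) const { return k==0 ? 1 : 0; } // 0 past own rank, for frame matching.
    bool keep_stride(dim_t st, int z, int j) const { return st*stride(z)==stride(j); }
    void adv(rank_t k, dim_t d) { p += stride(k)*d; }
    struct Flat
    {
        double * q; double & s;
        void operator+=(dim_t d) { q += d; }
        void operator*() { s += *q; } // the visit itself carries the side effect.
    };
    Flat flat() { return Flat { p, sum }; }
};
// double x[] = { 1, 2, 3 }; Sum1D e { x, 3 }; ply_ravel(e); // e.sum==6.
#endif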

// -------------------------
// Compile time order. See bench-dot.C for use. Index version.
// -------------------------

template <class order, class A, class S>
std::enable_if_t<mp::len<order> == 0> inline
subindexf(A & a, S & i)
{
    a.at(i);
}

template <class order, class A, class S>
std::enable_if_t<(mp::len<order> > 0)> inline
subindexf(A & a, S & i_)
{
    dim_t & i = i_[mp::First_<order>::value];
    // looked up on every subloop, but not worth caching.
    dim_t const /* constexpr */ s = a.size(mp::First_<order>::value);
    for (i=0; i<s; ++i) {
        subindexf<mp::Drop1_<order>>(a, i_);
    }
}

template <class A> inline
void plyf_index(A && a)
{
    using Shape = std::decay_t<decltype(a.shape())>;
    Shape i(ra_traits<Shape>::make(a.rank(), 0));
    subindexf<mp::Iota_<std::decay_t<A>::rank_s()>>(a, i); // cf ply_index() for C order.
}
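
// For instance, with rank_s()==2 the recursion in subindexf<> above unrolls (conceptually) into:
#if 0
for (i_[0]=0; i_[0]<a.size(0); ++i_[0]) {
    for (i_[1]=0; i_[1]<a.size(1); ++i_[1]) {
        a.at(i_);
    }
}
#endif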

// -------------------------
// Compile time order. See bench-dot.C for use. No index version.
// With compile-time recursion by rank, one can use adv<k>, but order must also be compile-time.
// -------------------------

#ifdef RA_INLINE
#error bad definition
#endif
#define RA_INLINE inline /* __attribute__((always_inline)) inline */

template <class order, int ravel_rank, class A, class S> RA_INLINE
std::enable_if_t<mp::len<order> == ravel_rank>
subindex(A & a, dim_t s, S const & ss0)
{
    for (auto p=a.flat(); s>0; --s, p+=ss0) {
        *p;
    }
}

template <class order, int ravel_rank, class A, class S> RA_INLINE
std::enable_if_t<(mp::len<order> > ravel_rank)>
subindex(A & a, dim_t const s, S const & ss0)
{
    dim_t size = a.size(mp::First_<order>::value); // TODO Precompute these at the top
    for (dim_t i=0, iend=size; i<iend; ++i) {
        subindex<mp::Drop1_<order>, ravel_rank>(a, s, ss0);
        a.adv(mp::First_<order>::value, 1);
    }
    a.adv(mp::First_<order>::value, -size);
}

// until() converts the runtime jj into the compile-time j. TODO a.adv<k>().

template <class order, int j, class A, class S> RA_INLINE
std::enable_if_t<(mp::len<order> < j)>
until(int const jj, A & a, dim_t const s, S const & ss0)
{
    assert(0 && "rank too high");
}

template <class order, int j, class A, class S> RA_INLINE
std::enable_if_t<(mp::len<order> >= j)>
until(int const jj, A & a, dim_t const s, S const & ss0)
{
    if (jj==j) {
        subindex<order, j>(a, s, ss0);
    } else {
        until<order, j+1>(jj, a, s, ss0);
    }
}

template <class A> RA_INLINE
auto plyf(A && a) -> std::enable_if_t<(std::decay_t<A>::rank_s()<=0)>
{
    static_assert(!has_tensorindex<A>, "bad plier for expr");
    static_assert(std::decay_t<A>::rank_s()==0, "plyf needs static rank");
    *(a.flat());
}

template <class A> RA_INLINE
auto plyf(A && a) -> std::enable_if_t<(std::decay_t<A>::rank_s()==1)>
{
    static_assert(!has_tensorindex<A>, "bad plier for expr");
    subindex<mp::Iota_<1>, 1>(a, a.size(0), a.stride(0));
}

template <class A> RA_INLINE
auto plyf(A && a) -> std::enable_if_t<(std::decay_t<A>::rank_s()>1)>
{
    static_assert(!has_tensorindex<A>, "bad plier for expr");
    /* constexpr */ rank_t const rank = a.rank();
#if 0 // FIXME both s & j can be constexpr. Try again after gcc 7 (constexpr lambda, for j).
    // find the outermost compact dim.
    /* constexpr */ auto s = a.size(rank-1);
    int j = 1;
    while (j<rank && a.keep_stride(s, rank-1, rank-1-j)) {
        s *= a.size(rank-1-j);
        ++j;
    }
    // each subexpression advances by its own stride in the compact dims, since those strides may differ.
    // dispatch with the static j. Note that the traversal order here is the reverse of order[] in ply_ravel().
    until<mp::Iota_<std::decay_t<A>::rank_s()>, 0>(j, a, s, a.stride(rank-1));
#else
    // according to bench-dot.C, the unrolling above isn't worth it :-/ TODO
    /* constexpr */ auto s = a.size(rank-1);
    subindex<mp::Iota_<std::decay_t<A>::rank_s()>, 1>(a, s, a.stride(rank-1));
#endif
}
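
// Conceptually, for rank_s()==2 the subindex<mp::Iota_<2>, 1> call above reduces to a loop over
// dim 0 with dim 1 ravelled through flat() (a sketch, with s==a.size(1) and ss0==a.stride(1)):
#if 0
for (dim_t i=0; i<a.size(0); ++i) {
    dim_t s = a.size(1);
    for (auto p=a.flat(); s>0; --s, p+=a.stride(1)) { *p; }
    a.adv(0, 1);
}
a.adv(0, -a.size(0));
#endif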

#undef RA_INLINE

// ---------------------------
// Select best performance (or requirements) for each type.
// ---------------------------

template <class A> inline
std::enable_if_t<has_tensorindex<A>>
ply(A && a)
{
    ply_index(std::forward<A>(a));
}

template <class A> inline
std::enable_if_t<!has_tensorindex<A> && (std::decay_t<A>::size_s()==DIM_ANY)>
ply(A && a)
{
    ply_ravel(std::forward<A>(a));
}

template <class A> inline
std::enable_if_t<!has_tensorindex<A> && (std::decay_t<A>::size_s()!=DIM_ANY)>
ply(A && a)
{
    plyf(std::forward<A>(a));
}
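
// Usage sketch: ply() is normally reached through the library's expression constructors. The names
// ra::Unique, ra::expr and iter() come from other headers and appear here only for illustration.
#if 0
ra::Unique<double, 2> a({2, 3}, 0.);
ply(ra::expr([](double & x) { x += 1.; }, a.iter())); // size_s()==DIM_ANY here, so ply_ravel() runs.
#endif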

// ---------------------------
// Short-circuiting pliers. TODO These are reductions. How to do higher rank?
// ---------------------------

// BUG The traversal options should be the same as for the non-short-circuiting pliers.

template <class A, class DEF>
auto ply_index_exit(A && a, DEF && def)
{
    rank_t const rank = a.rank();
    auto ind(ra_traits<decltype(a.shape())>::make(rank, 0));
    dim_t sha[rank];
    rank_t order[rank];
    for (rank_t k=0; k<rank; ++k) {
        order[k] = rank-1-k;
        sha[k] = a.size(order[k]);
        if (sha[k]==0) {
            return def;
        }
    }
    for (;;) {
        auto what = a.at(ind);
        if (std::get<0>(what)) {
            return std::get<1>(what);
        }
        for (int k=0; ; ++k) {
            if (k==rank) {
                return def;
            } else if (++ind[order[k]]<sha[k]) {
                break;
            } else {
                ind[order[k]] = 0;
            }
        }
    }
}

template <class A, class DEF>
auto ply_index_exit_1(A && a, DEF && def)
{
    auto s = a.size(0);
    auto ss0 = a.stride(0);
    for (auto p=a.flat(); s>0; --s, p+=ss0) {
        auto what = *p;
        if (std::get<0>(what)) {
            return std::get<1>(what);
        }
    }
    return def;
}

template <class A, class DEF, std::enable_if_t<is_iterator<A> && (std::decay_t<A>::rank_s()!=1 || has_tensorindex<A>), int> = 0>
inline decltype(auto)
ply_exit(A && a, DEF && def)
{
    return ply_index_exit(std::forward<A>(a), std::forward<DEF>(def));
}

template <class A, class DEF, std::enable_if_t<is_iterator<A> && (std::decay_t<A>::rank_s()==1 && !has_tensorindex<A>), int> = 0>
inline decltype(auto)
ply_exit(A && a, DEF && def)
{
    return ply_index_exit_1(std::forward<A>(a), std::forward<DEF>(def));
}

template <class A, class DEF> inline decltype(auto)
early(A && a, DEF && def)
{
    return ply_exit(std::forward<A>(a), std::forward<DEF>(def));
}
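
// Usage sketch: the expression must yield (exit?, value) pairs that std::get<0>/std::get<1> above
// can unpack, e.g. std::tuple<bool, T>. The names ra::Unique, ra::expr and iter() are assumed from
// other headers; this is illustration, not library API.
#if 0
ra::Unique<int, 1> a = { 1, -2, 3 };
bool anyneg = early(ra::expr([](int x) { return std::make_tuple(x<0, true); }, a.iter()),
                    false); // stops at the first negative element.
#endif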

} // namespace ra