// -*- mode: c++; coding: utf-8 -*-
/// @file ply.hh
/// @brief Traverse (ply) an array, array expression, or array statement.

// (c) Daniel Llorens - 2013-2019
// This library is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.

// TODO Lots of room for improvement: small (fixed sizes) and large (tiling, etc.; see eval.cc in Blitz++).

#pragma once
#include "ra/atom.hh"
#include <functional>
#include <iostream>

namespace ra {

// --------------
// Run time order
// --------------

// Traverse an array expression looking to ravel the inner loop.
// size(k) has a single (frame-matched) value for each k.
// adv(k), stride(k), keep_stride(st, k, l) and flat() are used on all the leaf arguments.
// The strides must give 0 for k>=their own rank, to allow frame matching.
// TODO Traversal order should be a parameter, since some operations (e.g. output, ravel) require a specific order.

template <RaIterator A>
inline void
ply_ravel(A && a)
{
    rank_t rank = a.rank();
    assert(rank>=0); // FIXME see test in [ra40].
    rank_t order[rank];
    for (rank_t i=0; i<rank; ++i) {
        order[i] = rank-1-i;
    }
    switch (rank) {
    case 0: *(a.flat()); return;
    case 1: break;
    default: // TODO find a decent heuristic
        // if (rank>1) {
        //     std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
        //               { return a.size(order[i])<a.size(order[j]); });
        // }
        ;
    }
// find the outermost compact dim.
    rank_t * ocd = order;
    auto ss = a.size(*ocd);
    for (--rank, ++ocd; rank>0 && a.keep_stride(ss, order[0], *ocd); --rank, ++ocd) {
        ss *= a.size(*ocd);
    }
    dim_t sha[rank], ind[rank];
    for (int k=0; k<rank; ++k) {
        ind[k] = 0;
        sha[k] = a.size(ocd[k]);
        if (sha[k]==0) { // ss takes care of the raveled dimensions.
            return;
        }
        RA_CHECK(sha[k]!=DIM_BAD, "undefined dim ", ocd[k]);
    }
// all sub xpr strides advance in compact dims, as they might be different.
    auto const ss0 = a.stride(order[0]);
// TODO Blitz++ uses an explicit stack of end-of-dim p positions, and has special cases for common/unit stride.
    for (;;) {
        dim_t s = ss;
        for (auto p=a.flat(); s>0; --s, p+=ss0) {
            *p;
        }
        for (int k=0; ; ++k) {
            if (k>=rank) {
                return;
            } else if (ind[k]<sha[k]-1) {
                ++ind[k];
                a.adv(ocd[k], 1);
                break;
            } else {
                ind[k] = 0;
                a.adv(ocd[k], 1-sha[k]);
            }
        }
    }
}
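
// Worked example (the shape and strides below are an illustration, not a requirement):
// for a contiguous rank-2 expression of shape (3, 4) with strides (4, 1), keep_stride
// succeeds across both dims, ss becomes 12, and the whole traversal collapses into the
// single inner loop over p. With column-major strides (1, 3) instead, no dims ravel:
// the outer for(;;) steps ind[] odometer-style through dim 0, running one inner loop
// of length 4 (stepping by stride(1)==3) per row.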

// ------------------
// Compile time order
// ------------------

#ifdef RA_INLINE
#error bad definition
#endif
#define RA_INLINE inline /* __attribute__((always_inline)) inline */

template <class order, int ravel_rank, class A, class S>
RA_INLINE constexpr void
subindex(A & a, dim_t s, S const & ss0)
{
    if constexpr (mp::len<order> == ravel_rank) {
        for (auto p=a.flat(); s>0; --s, p+=ss0) {
            *p;
        }
    } else if constexpr (mp::len<order> > ravel_rank) {
        dim_t size = a.size(mp::first<order>::value); // TODO precompute these at the top
        for (dim_t i=0, iend=size; i<iend; ++i) {
            subindex<mp::drop1<order>, ravel_rank>(a, s, ss0);
            a.adv(mp::first<order>::value, 1);
        }
        a.adv(mp::first<order>::value, -size);
    } else {
        abort();
    }
}
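
// For example (assuming a rank-2 expression with no static stride information),
// subindex<mp::iota<2>, 1>(a, a.size(1), a.stride(1)) expands at compile time to
//
//   for (dim_t i=0; i<a.size(0); ++i) {
//       dim_t s = a.size(1);
//       for (auto p=a.flat(); s>0; --s, p+=a.stride(1)) { *p; }
//       a.adv(0, 1);
//   }
//   a.adv(0, -a.size(0));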

// until() converts the runtime jj into the compile time j. TODO a.adv<k>().
template <class order, int j, class A, class S>
RA_INLINE constexpr void
until(int const jj, A & a, dim_t const s, S const & ss0)
{
    if constexpr (mp::len<order> < j) {
        assert(0 && "rank too high");
    } else if constexpr (mp::len<order> >= j) {
        if (jj==j) {
            subindex<order, j>(a, s, ss0);
        } else {
            until<order, j+1>(jj, a, s, ss0);
        }
    } else {
        abort();
    }
}
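
// For instance, with mp::len<order>==3 and jj==2, the call chain
// until<order, 0> -> until<order, 1> -> until<order, 2> compares jj against each
// static j in turn and ends in subindex<order, 2>(a, s, ss0), so the ravel count
// reaches subindex as a template argument.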

// Find the outermost compact dim.
template <class A>
constexpr auto
ocd()
{
    rank_t const rank = A::rank_s();
    auto s = A::size_s(rank-1);
    int j = 1;
    while (j<rank && A::keep_stride(s, rank-1, rank-1-j)) {
        s *= A::size_s(rank-1-j);
        ++j;
    }
    return std::make_tuple(s, j);
}
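
// For illustration (shape and strides assumed, not required): for a static rank-2
// type with sizes (3, 4) and row-major strides (4, 1), keep_stride(4, 1, 0) holds,
// so ocd() returns (12, 2) and the whole expression ravels into one loop of 12.
// With strides (1, 3) the check fails at j==1 and ocd() returns (4, 1).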

template <RaIterator A>
RA_INLINE constexpr void
plyf(A && a)
{
    constexpr rank_t rank = rank_s<A>();
    static_assert(rank>=0, "plyf needs static rank");
    if constexpr (rank_s<A>()==0) {
        *(a.flat());
    } else if constexpr (rank_s<A>()==1) {
        subindex<mp::iota<1>, 1>(a, a.size(0), a.stride(0));
// this branch can only be enabled when f() is constexpr; a static keep_stride implies all else is also static.
// the rank>1 requirement is important with static size operands [ra43].
    } else if constexpr (rank_s<A>()>1 && requires (dim_t d, rank_t i, rank_t j) { A::keep_stride(d, i, j); }) {
        constexpr auto sj = ocd<std::decay_t<A>>();
        constexpr auto s = std::get<0>(sj);
        constexpr auto j = std::get<1>(sj);
// all sub xpr strides advance in compact dims, as they might be different.
// dispatch with static j. Note that the order here is the inverse of the order in ply_ravel.
        until<mp::iota<rank_s<A>()>, 0>(j, a, s, a.stride(rank-1));
    } else {
// don't bother unrolling.
        auto s = a.size(rank-1);
        subindex<mp::iota<rank_s<A>()>, 1>(a, s, a.stride(rank-1));
    }
}
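
// E.g. (operand type assumed for illustration): for an expression over
// ra::Small<double, 2, 3>, rank_s is 2 and keep_stride is static, so the until()
// branch is taken with the constexpr pair (s, j)==(6, 2): the full 2x3 ravels.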

#undef RA_INLINE

// ---------------------------
// Select the best traversal (by performance or requirements) for each type.
// ---------------------------

template <RaIterator A>
inline constexpr void
ply(A && a)
{
    if constexpr (size_s<A>()==DIM_ANY) {
        ply_ravel(std::forward<A>(a));
    } else {
        plyf(std::forward<A>(a));
    }
}
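
// Usage sketch (expr() and start() are defined elsewhere in the library; the array
// and lambda are only an illustration):
//
//   ra::Big<int, 2> A({2, 3}, 0);
//   int i = 0;
//   ply(expr([&i](int & x) { x = i++; }, start(A))); // assign in traversal order
//
// Since Big has runtime sizes, size_s is DIM_ANY and ply dispatches to ply_ravel.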

// ---------------------------
// Short-circuiting pliers.
// ---------------------------

// TODO Refactor with ply_ravel. Make exit available to plyf.
// TODO These are reductions. How to do higher rank?
template <RaIterator A, class DEF>
inline auto
ply_ravel_exit(A && a, DEF && def)
{
    rank_t rank = a.rank();
    assert(rank>=0); // FIXME see test in [ra40].
    rank_t order[rank];
    for (rank_t i=0; i<rank; ++i) {
        order[i] = rank-1-i;
    }
    switch (rank) {
    case 0: {
        if (auto what = *(a.flat()); std::get<0>(what)) {
            return std::get<1>(what);
        }
        return def;
    }
    case 1: break;
    default: // TODO find a decent heuristic
        // if (rank>1) {
        //     std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
        //               { return a.size(order[i])<a.size(order[j]); });
        // }
        ;
    }
// find the outermost compact dim.
    rank_t * ocd = order;
    auto ss = a.size(*ocd);
    for (--rank, ++ocd; rank>0 && a.keep_stride(ss, order[0], *ocd); --rank, ++ocd) {
        ss *= a.size(*ocd);
    }
    dim_t sha[rank], ind[rank];
    for (int k=0; k<rank; ++k) {
        ind[k] = 0;
        sha[k] = a.size(ocd[k]);
        if (sha[k]==0) { // ss takes care of the raveled dimensions.
            return def;
        }
        RA_CHECK(sha[k]!=DIM_BAD, "undefined dim ", ocd[k]);
    }
// all sub xpr strides advance in compact dims, as they might be different.
    auto const ss0 = a.stride(order[0]);
// TODO Blitz++ uses an explicit stack of end-of-dim p positions, and has special cases for common/unit stride.
    for (;;) {
        dim_t s = ss;
        for (auto p=a.flat(); s>0; --s, p+=ss0) {
            if (auto what = *p; std::get<0>(what)) {
                return std::get<1>(what);
            }
        }
        for (int k=0; ; ++k) {
            if (k>=rank) {
                return def;
            } else if (ind[k]<sha[k]-1) {
                ++ind[k];
                a.adv(ocd[k], 1);
                break;
            } else {
                ind[k] = 0;
                a.adv(ocd[k], 1-sha[k]);
            }
        }
    }
}

template <RaIterator A, class DEF>
inline decltype(auto)
early(A && a, DEF && def)
{
    return ply_ravel_exit(std::forward<A>(a), std::forward<DEF>(def));
}
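
// Each *p must yield a tuple-like (stop, value) pair; traversal returns value at the
// first stop==true, else def. For example, an any() reduction can be built on early()
// along these lines (map() is defined elsewhere in the library; sketch only):
//
//   template <class A>
//   bool any(A && a)
//   {
//       return early(map([](bool x) { return std::make_tuple(x, x); }, std::forward<A>(a)), false);
//   }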

} // namespace ra