wrank.H 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file wrank.H
  3. /// @brief Rank conjunction for expression templates.
  4. // (c) Daniel Llorens - 2013-2017, 2019
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #pragma once
  10. #include "ra/expr.H"
  // CHECK_BOUNDS(cond): checks cond through RA_ASSERT, unless bounds checking has been
  // explicitly turned off with RA_CHECK_BOUNDS=0, in which case it expands to nothing.
  // NOTE(review): nothing between this definition and the #undef at the end of this file
  // expands CHECK_BOUNDS; it may be removable.
  #if !defined(RA_CHECK_BOUNDS) || RA_CHECK_BOUNDS!=0
  #define CHECK_BOUNDS( cond ) RA_ASSERT( cond, 0 )
  #else
  #define CHECK_BOUNDS( cond )
  #endif
  16. // TODO Adopt frame matching as in Match (no driver).
  17. namespace ra {
  /// A verb: the callable op together with a type list of cell ranks (cranks), exposed as R.
  /// Built by wrank(). Op may be a reference type, in which case op is held by reference.
  template <class cranks, class Op>
  struct Verb
  {
      using R = cranks;
      Op op;
  };
  /// Wrap op into a Verb whose cell ranks are given by the type of the first argument.
  /// Only that type is used, so the parameter is left unnamed (no -Wunused-parameter).
  /// Op is deduced as a forwarding reference: an lvalue op is stored by reference,
  /// an rvalue op is moved into the Verb and stored by value.
  template <class cranks, class Op> inline constexpr
  auto wrank(cranks, Op && op)
  {
      return Verb<cranks, Op> { std::forward<Op>(op) };
  }
  29. template <rank_t ... crank, class Op> inline constexpr
  30. auto wrank(Op && op)
  31. {
  32. return Verb<mp::int_list<crank ...>, Op> { std::forward<Op>(op) };
  33. }
  34. template <class A> using ValidRank = mp::int_t<(A::value>=0)>;
  35. template <class R, class skip, class frank>
  36. using AddFrameAxes = mp::append<R, mp::iota<frank::value, skip::value>>;
  37. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  38. struct Framematch_def;
  39. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  40. using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
  41. // FIXME Replace the frame matching mechanism by driverless, per-axis Match
  42. template <class A, class B>
  43. struct max_i
  44. {
  45. constexpr static int value = gt_rank(A::value, B::value) ? 0 : 1; // 0 if ra wins, else 1
  46. };
  47. template <class T> using largest_i_tuple = mp::IndexOf<max_i, T>;
  48. // Get a list (per argument) of lists of live axes.
  49. // The last frame match is not done; that relies on rest axis handling of each argument (ignoring axis spec beyond their own rank). TODO Reexamine that.
  50. // Case where V has rank.
  51. template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
  52. struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  53. {
  54. static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "bad args");
  55. using T = std::tuple<Ti ...>;
  56. using R_ = std::tuple<Ri ...>;
  57. // TODO functions of arg rank, negative, inf.
  58. // live = number of live axes on this frame, for each argument.
  59. using live = mp::int_list<(std::decay_t<Ti>::rank_s()-mp::len<Ri>-crank::value) ...>;
  60. static_assert(mp::apply<mp::andb, mp::map<ValidRank, live>>::value, "bad ranks");
  61. // select driver for this stage.
  62. constexpr static int driver = largest_i_tuple<live>::value;
  63. // add actual axes to result.
  64. using skips = mp::makelist<sizeof...(Ti), mp::int_t<skip>>;
  65. using FM = Framematch<W, T, mp::map<AddFrameAxes, R_, skips, live>,
  66. skip + mp::ref<live, driver>::value>;
  67. using R = typename FM::R;
  68. constexpr static int depth = mp::ref<live, driver>::value + FM::depth;
  69. // drill down in V to get innermost Op (cf [ra31]).
  70. template <class VV> static decltype(auto) op(VV && v) { return FM::op(std::forward<VV>(v).op); }
  71. };
  72. // Terminal case where V doesn't have rank (is a raw op()).
  73. template <class V, class ... Ti, class ... Ri, rank_t skip>
  74. struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  75. {
  76. using R_ = std::tuple<Ri ...>;
  77. // TODO -crank::value when the actual verb rank is used (e.g. to use cell_iterator<A, that_rank> instead of just begin()).
  78. using live = mp::int_list<(std::decay_t<Ti>::rank_s()-mp::len<Ri>) ...>;
  79. static_assert(mp::apply<mp::andb, mp::map<ValidRank, live>>::value, "bad ranks");
  80. constexpr static int driver = largest_i_tuple<live>::value;
  81. using skips = mp::makelist<sizeof...(Ti), mp::int_t<skip>>;
  82. using R = mp::map<AddFrameAxes, R_, skips, live>;
  83. constexpr static int depth = mp::ref<live, driver>::value;
  84. template <class VV> static decltype(auto) op(VV && v) { return std::forward<VV>(v); }
  85. };
  // zerostride<T>::f(): a zero of type T, used by ApplyFrames as the stride of an axis the
  // wrapped argument does not have. Tuple-valued strides get a zero in every element
  // (recursively).
  template <class T>
  struct zerostride
  {
      constexpr static T f() { return static_cast<T>(0); }
  };
  template <class ... T>
  struct zerostride<std::tuple<T ...>>
  {
      constexpr static std::tuple<T ...> f()
      {
          return std::tuple<T ...> { zerostride<T>::f() ... };
      }
  };
  96. // Wraps each argument of an expression using wrank.
  97. template <class LiveAxes, int depth, class A>
  98. struct ApplyFrames
  99. {
  100. A a;
  101. constexpr static int live(int k) { return mp::int_list_index<LiveAxes>(k); }
  102. template <class I>
  103. constexpr decltype(auto) at(I const & i)
  104. {
  105. return a.at(mp::map_indices<std::array<dim_t, mp::len<LiveAxes>>, LiveAxes>(i));
  106. }
  107. constexpr dim_t size(int k) const
  108. {
  109. int l = live(k);
  110. return l>=0 ? a.size(l) : DIM_BAD;
  111. }
  112. constexpr void adv(rank_t k, dim_t d)
  113. {
  114. if (int l = live(k); l>=0) {
  115. a.adv(l, d);
  116. }
  117. }
  118. constexpr auto stride(int k) const
  119. {
  120. int l = live(k);
  121. return l>=0 ? a.stride(l) : zerostride<decltype(a.stride(l))>::f();
  122. }
  123. constexpr bool keep_stride(dim_t step, int z, int j) const
  124. {
  125. int wz = live(z);
  126. int wj = live(j);
  127. return wz>=0 && wj>=0 && a.keep_stride(step, wz, wj);
  128. }
  129. constexpr decltype(auto) flat() { return a.flat(); }
  130. constexpr static dim_t size_s(int k)
  131. {
  132. int l = live(k);
  133. return l>=0 ? std::decay_t<A>::size_s(l) : DIM_BAD;
  134. }
  135. constexpr static rank_t rank() { return depth; } // TODO Invalid for RANK_ANY [ra07]
  136. constexpr static rank_t rank_s() { return depth; } // TODO Invalid for RANK_ANY [ra07]
  137. };
  138. // No-op case. TODO Maybe apply to any Iota<n> where n<=depth.
  139. // TODO If A is cell_iterator, etc. beat LiveAxes directly on that... same for an eventual transpose_expr<>.
  140. template <class LiveAxes, int depth, class A>
  141. decltype(auto) applyframes(A && a)
  142. {
  143. if constexpr (std::is_same_v<LiveAxes, mp::iota<depth>>) {
  144. return std::forward<A>(a);
  145. } else {
  146. return ApplyFrames<LiveAxes, depth, A> { std::forward<A>(a) };
  147. }
  148. }
  149. template <class V, class ... T, int ... i> inline constexpr
  150. auto ryn(mp::int_list<i ...>, V && v, T && ... t)
  151. {
  152. using FM = Framematch<V, std::tuple<T ...>>;
  153. return expr(FM::op(std::forward<V>(v)),
  154. applyframes<mp::ref<typename FM::R, i>, FM::depth>(std::forward<T>(t)) ...);
  155. }
  156. // TODO partial specialization means no universal ref :-/
  157. #define DEF_EXPR_VERB(MOD) \
  158. template <class cranks, class Op, class ... P> inline constexpr \
  159. auto expr(Verb<cranks, Op> MOD v, P && ... t) \
  160. { \
  161. return ryn(mp::iota<sizeof...(P)> {}, std::forward<decltype(v)>(v), std::forward<P>(t) ...); \
  162. }
  163. FOR_EACH(DEF_EXPR_VERB, &&, &, const &)
  164. #undef DEF_EXPR_VERB
  165. // ---------------------------
  166. // from, after APL, like (from) in guile-ploy
  167. // TODO integrate with is_beatable shortcuts, operator() in the various array types.
  168. // ---------------------------
  169. template <class I>
  170. struct index_rank_
  171. {
  172. using type = mp::int_t<std::decay_t<I>::rank_s()>; // see ra_traits for ra::types (?)
  173. static_assert(type::value!=RANK_ANY, "dynamic rank unsupported");
  174. static_assert(size_s<I>()!=DIM_BAD, "undelimited extent subscript unsupported");
  175. };
  176. template <class I> using index_rank = typename index_rank_<I>::type;
  177. template <class II, int drop, class Op> inline constexpr
  178. decltype(auto) from_partial(Op && op)
  179. {
  180. if constexpr (drop==mp::len<II>) {
  181. return std::forward<Op>(op);
  182. } else {
  183. return wrank(mp::append<mp::makelist<drop, mp::int_t<0>>, mp::drop<II, drop>> {},
  184. from_partial<II, drop+1>(std::forward<Op>(op)));
  185. }
  186. }
  187. // TODO we should be able to do better by slicing at each dimension, etc. But verb<> only supports rank-0 for the innermost op.
  188. template <class A, class ... I> inline constexpr
  189. auto from(A && a, I && ... i)
  190. {
  191. if constexpr (0==sizeof...(i)) {
  192. return a();
  193. } else if constexpr (1==sizeof...(i)) {
  194. // support dynamic rank for 1 arg only (see test in test/from.C).
  195. return expr(std::forward<A>(a), start(std::forward<I>(i) ...));
  196. } else {
  197. using II = mp::map<index_rank, mp::tuple<decltype(start(std::forward<I>(i))) ...>>;
  198. return expr(from_partial<II, 1>(std::forward<A>(a)), start(std::forward<I>(i)) ...);
  199. }
  200. }
  201. } // namespace ra
  202. #undef CHECK_BOUNDS