// wrank.H
  1. // (c) Daniel Llorens - 2013-2017
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file wrank.H
  7. /// @brief Rank conjunction for expression templates.
  8. #pragma once
  9. #include "ra/expr.H"
// Bounds checking in this header is controlled by RA_CHECK_BOUNDS_RA_WRANK.
// If the global RA_CHECK_BOUNDS is defined, its value is used for the local
// switch; otherwise the local switch defaults to 1 unless already defined.
#ifdef RA_CHECK_BOUNDS
#define RA_CHECK_BOUNDS_RA_WRANK RA_CHECK_BOUNDS
#else
#ifndef RA_CHECK_BOUNDS_RA_WRANK
#define RA_CHECK_BOUNDS_RA_WRANK 1
#endif
#endif
// CHECK_BOUNDS(cond) expands to nothing when the switch is 0, and to
// assert(cond) for any other value.
#if RA_CHECK_BOUNDS_RA_WRANK==0
#define CHECK_BOUNDS( cond )
#else
#define CHECK_BOUNDS( cond ) assert( cond )
#endif
  22. // TODO Make it work with fixed size types.
  23. // TODO Make it work with var rank types.
  24. namespace ra {
  25. template <class cranks, class Op>
  26. struct Verb
  27. {
  28. using R = cranks;
  29. Op op;
  30. };
  31. template <class cranks, class Op> constexpr inline auto
  32. wrank(cranks cranks_, Op && op)
  33. {
  34. return Verb<cranks, Op> { std::forward<Op>(op) };
  35. }
  36. template <rank_t ... crank, class Op> constexpr inline auto
  37. wrank(Op && op)
  38. {
  39. return Verb<mp::int_list<crank ...>, Op> { std::forward<Op>(op) };
  40. }
  41. template <class A>
  42. struct ValidRank
  43. {
  44. using type = mp::int_t<(A::value>=0)>;
  45. };
  46. template <class R, class skip, class frank>
  47. struct AddFrameAxes
  48. {
  49. using type = mp::Append_<R, mp::Iota_<frank::value, skip::value>>;
  50. };
// Primary template, defined only through the partial specializations below.
// V is the (decayed) verb or raw op, T the tuple of argument types, R the
// per-argument axis lists accumulated so far, and skip the frame axis offset.
// R defaults to mp::MakeList_<mp::len<T>, mp::nil>, i.e. one entry per argument.
template <class V, class T, class R=mp::MakeList_<mp::len<T>, mp::nil>, rank_t skip=0>
struct Framematch_def;
// Entry point: same as Framematch_def, except that V is decayed first so that
// references and cv-qualifiers on the verb do not prevent a specialization match.
template <class V, class T, class R=mp::MakeList_<mp::len<T>, mp::nil>, rank_t skip=0>
using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
  55. // Get a list (per argument) of lists of live axes.
  56. // The last frame match is not done; that relies on rest axis handling of each argument (ignoring axis spec beyond their own rank). TODO Reexamine that.
  57. // Case where V has rank.
// Stage for a verb with cell ranks, Verb<std::tuple<crank ...>, W>.
// Ti are the argument types, Ri the axis lists accumulated for each argument by
// the enclosing stages, and skip the frame axis offset handed down by them.
template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
{
// There must be exactly one cell rank and one accumulated axis list per argument.
static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "bad args");
using T = std::tuple<Ti ...>;
using R_ = std::tuple<Ri ...>;
// TODO functions of arg rank, negative, inf.
// live = number of live axes on this frame, for each argument.
// Computed as the argument's static rank, minus the axes already assigned
// (mp::len<Ri>), minus the cell rank requested for that argument.
using live = mp::int_list<(std::decay_t<Ti>::rank_s()-mp::len<Ri>-crank::value) ...>;
// Every entry of live must be >= 0 (ValidRank), or the ranks do not fit.
static_assert(mp::Apply_<mp::And, mp::Map_<ValidRank, live>>::value, "bad ranks");
// select driver for this stage.
// NOTE(review): presumably the index of the largest entry of live; confirm against largest_i_tuple.
constexpr static int driver = largest_i_tuple<live>::value;
// add actual axes to result.
using skips = mp::MakeList_<sizeof...(Ti), mp::int_t<skip>>;
// Recurse into the inner op W with the extended axis lists; the offset for the
// next stage is advanced by the driver's live axis count.
using FM = Framematch<W, T, mp::Map_<AddFrameAxes, R_, skips, live>,
skip + mp::Ref_<live, driver>::value>;
using R = typename FM::R;
// Total depth: the driver's live axes on this stage plus the depth of the inner stages.
constexpr static int depth = mp::Ref_<live, driver>::value + FM::depth;
// drill down in V to get innermost Op (cf [ra31]).
template <class VV> static decltype(auto) op(VV && v) { return FM::op(std::forward<VV>(v).op); }
};
  79. // Terminal case where V doesn't have rank (is a raw op()).
// Final stage: V is not a Verb, so it is used as the op itself and the
// recursion stops here. All axes not yet assigned to an argument are live.
template <class V, class ... Ti, class ... Ri, rank_t skip>
struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
{
using R_ = std::tuple<Ri ...>;
// TODO -crank::value when the actual verb rank is used (e.g. to use ra_iterator<A, that_rank> instead of just begin()).
// live = for each argument, its static rank minus the axes already assigned (mp::len<Ri>).
using live = mp::int_list<(std::decay_t<Ti>::rank_s()-mp::len<Ri>) ...>;
// Every entry of live must be >= 0 (ValidRank).
static_assert(mp::Apply_<mp::And, mp::Map_<ValidRank, live>>::value, "bad ranks");
// NOTE(review): presumably the index of the largest entry of live; confirm against largest_i_tuple.
constexpr static int driver = largest_i_tuple<live>::value;
using skips = mp::MakeList_<sizeof...(Ti), mp::int_t<skip>>;
// Final per-argument axis lists: the accumulated ones extended with this frame's axes.
using R = mp::Map_<AddFrameAxes, R_, skips, live>;
// Depth contributed by this stage is the driver's live axis count.
constexpr static int depth = mp::Ref_<live, driver>::value;
// Nothing left to unwrap: the op is returned as passed.
template <class VV> static decltype(auto) op(VV && v) { return std::forward<VV>(v); }
};
  93. template <class T>
  94. struct zerostride
  95. {
  96. constexpr static T f() { return T(0); }
  97. };
  98. template <class ... Ti>
  99. struct zerostride<std::tuple<Ti ...>>
  100. {
  101. constexpr static std::tuple<Ti ...> f() { return std::make_tuple(zerostride<Ti>::f() ...); }
  102. };
  103. // Wraps each argument of an expression using wrank.
  104. // no shape(), size_s(), rank_s(), rank() -> this is above.
  105. template <class LiveAxes, int depth, class A>
  106. struct ApplyFrames
  107. {
  108. A a;
  109. constexpr static int live(int k) { return mp::on_tuple<LiveAxes>::index(k); }
  110. template <class I>
  111. constexpr decltype(auto) at(I const & i)
  112. {
  113. return a.at(mp::map_indices<LiveAxes, std::array<dim_t, mp::len<LiveAxes>>>::f(i));
  114. }
  115. constexpr dim_t size(int k) const
  116. {
  117. int l = live(k);
  118. return l>=0 ? a.size(l) : DIM_BAD;
  119. }
  120. constexpr void adv(rank_t k, dim_t d)
  121. {
  122. int l = live(k);
  123. if (l>=0) {
  124. a.adv(l, d);
  125. }
  126. }
  127. constexpr auto stride(int k) const
  128. {
  129. int l = live(k);
  130. return l>=0 ? a.stride(l) : zerostride<decltype(a.stride(l))>::f();
  131. }
  132. constexpr bool keep_stride(dim_t step, int z, int j) const
  133. {
  134. int wz = live(z);
  135. int wj = live(j);
  136. return wz>=0 && wj>=0 && a.keep_stride(step, wz, wj);
  137. }
  138. constexpr decltype(auto) flat() { return a.flat(); }
  139. constexpr decltype(auto) flat() const { return a.flat(); }
  140. };
  141. template <class LiveAxes, int depth, class Enable=void>
  142. struct applyframes
  143. {
  144. template <class A>
  145. static decltype(auto) f(A && a)
  146. {
  147. return ApplyFrames<LiveAxes, depth, A> { std::forward<A>(a) };
  148. }
  149. };
  150. // No-op case. TODO Maybe apply to any Iota<n> where n<=depth.
  151. // TODO If A is ra_iterator, etc. beat LiveAxes directly on that... same for an eventual transpose_expr<>.
  152. template <class LiveAxes, int depth>
  153. struct applyframes<LiveAxes, depth, std::enable_if_t<std::is_same<LiveAxes, mp::Iota_<depth>>::value>>
  154. {
  155. template <class A>
  156. static decltype(auto) f(A && a)
  157. {
  158. return std::forward<A>(a);
  159. }
  160. };
  161. // like Expr, except don't do driver selection here, but leave it to the args, as with Expr::adv(k, d). The args may need to be ApplyFrames... don't know yet.
  162. // forward decl in atom.H.
  163. template <class FM, class Op, class ... P, int ... I>
  164. struct Ryn<FM, Op, std::tuple<P ...>, std::integer_sequence<int, I ...>>
  165. {
  166. Op op;
  167. std::tuple<P ...> t;
  168. template <int iarg>
  169. bool check()
  170. {
  171. for (int k=0; k!=rank(); ++k) { // TODO with static rank or sizes, can peval.
  172. dim_t s0 = size(k);
  173. dim_t sk = std::get<iarg>(t).size(k);
  174. if (sk!=s0 && sk!=DIM_BAD) { // TODO See Expr::check(); maybe just sk>=0.
  175. return false;
  176. }
  177. }
  178. return true;
  179. }
  180. constexpr Ryn(Op op_, P ... t_): op(std::forward<Op>(op_)), t(std::forward<P>(t_) ...)
  181. {
  182. CHECK_BOUNDS(check<I>() && ... && "mismatched shapes");
  183. }
  184. template <class J>
  185. constexpr decltype(auto) at(J const & i)
  186. {
  187. return op(std::get<I>(t).at(i) ...);
  188. }
  189. constexpr void adv(rank_t k, dim_t d)
  190. {
  191. (std::get<I>(t).adv(k, d), ...);
  192. }
  193. constexpr bool keep_stride(dim_t step, int z, int j) const
  194. {
  195. return (std::get<I>(t).keep_stride(step, z, j) && ...);
  196. }
  197. constexpr auto stride(int i) const
  198. {
  199. return std::make_tuple(std::get<I>(t).stride(i) ...);
  200. }
  201. constexpr auto flat()
  202. {
  203. return ra::flat(op, std::get<I>(t).flat() ...);
  204. }
  205. constexpr auto flat() const { return flat(); }
  206. // Use the first arg that gives size(k)>=0; valid by ApplyFrame.
  207. // TODO if k were static, we could pick the driving arg from axisdrivers. Only need bool from that.
  208. template <int iarg=0>
  209. std::enable_if_t<(iarg<sizeof...(P)), dim_t>
  210. constexpr size(int k) const
  211. {
  212. dim_t s = std::get<iarg>(t).size(k);
  213. return s>=0 ? s : size<iarg+1>(k);
  214. }
  215. template <int iarg>
  216. std::enable_if_t<(iarg==sizeof...(P)), dim_t>
  217. constexpr size(int k) const
  218. {
  219. abort(); return DIM_BAD;
  220. }
  221. constexpr static dim_t size_s() { return DIM_ANY; } // BUG
  222. constexpr static rank_t rank() { return FM::depth; } // TODO Invalid for RANK_ANY
  223. constexpr static rank_t rank_s() { return FM::depth; } // TODO Invalid for RANK_ANY
  224. constexpr auto shape() const
  225. {
  226. std::array<dim_t, FM::depth> s {};
  227. for (int k=0; k!=FM::depth; ++k) {
  228. s[k] = size(k);
  229. CHECK_BOUNDS(s[k]!=DIM_BAD);
  230. }
  231. return s;
  232. }
  233. // forward to make sure value y is not misused as ref. Cf. test-ra-8.C.
  234. #define DEF_RYN_ASSIGNOPS(OP) \
  235. template <class X> void operator OP(X && x) \
  236. { for_each([](auto && y, auto && x) { std::forward<decltype(y)>(y) OP x; }, *this, x); }
  237. FOR_EACH(DEF_RYN_ASSIGNOPS, =, *=, +=, -=, /=)
  238. #undef DEF_RYN_ASSIGNOPS
  239. };
  240. template <class FM, class Op, class ... P> inline
  241. constexpr auto ryn(Op && op, P && ... t)
  242. {
  243. return Ryn<FM, Op, std::tuple<P ...>> { std::forward<Op>(op), std::forward<P>(t) ... }; // (note 1)
  244. }
  245. template <class K>
  246. struct number_ryn;
  247. template <int ... I>
  248. struct number_ryn<std::integer_sequence<int, I ...>>
  249. {
  250. template <class V, class ... P> constexpr static
  251. auto f(V && v, P && ... t)
  252. {
  253. using FM = Framematch<V, std::tuple<P ...>>;
  254. return ryn<FM>(FM::op(std::forward<V>(v)), applyframes<mp::Ref_<typename FM::R, I>, FM::depth>::f(std::forward<P>(t)) ...);
  255. }
  256. };
  257. // TODO partial specialization means no universal ref :-/
// Defines expr() for a Verb taken with reference qualifier MOD. The arguments
// are numbered 0 .. sizeof...(P)-1 and the call is passed on to number_ryn::f,
// which does the frame matching. One overload is generated per qualifier in the
// FOR_EACH list below.
#define DEF_EXPR_VERB(MOD) \
template <class cranks, class Op, class ... P> inline constexpr \
auto expr(Verb<cranks, Op> MOD v, P && ... t) \
{ \
return number_ryn<std::make_integer_sequence<int, sizeof...(P)>>::f(std::forward<decltype(v)>(v), std::forward<P>(t) ...); \
}
FOR_EACH(DEF_EXPR_VERB, &&, &, const &)
#undef DEF_EXPR_VERB
  266. // ---------------------------
  267. // from, after APL, like (from) in guile-ploy
  268. // TODO integrate with is_beatable shortcuts, operator() in the various array types.
  269. // ---------------------------
  270. template <class I>
  271. struct index_rank
  272. {
  273. using type = mp::int_t<std::decay_t<I>::rank_s()>; // see ra_traits for ra::types (?)
  274. static_assert(type::value!=RANK_ANY, "dynamic rank unsupported");
  275. static_assert(std::decay_t<I>::size_s()!=DIM_BAD, "undelimited extent subscript unsupported");
  276. };
  277. template <class II, int drop, class Enable=void>
  278. struct from_partial
  279. {
  280. template <class Op>
  281. static decltype(auto) make(Op && op)
  282. {
  283. return wrank(mp::Append_<mp::MakeList_<drop, mp::int_t<0>>, mp::Drop_<II, drop>> {},
  284. from_partial<II, drop+1>::make(std::forward<Op>(op)));
  285. }
  286. };
  287. template <class II, int drop>
  288. struct from_partial<II, drop, std::enable_if_t<drop==mp::len<II>>>
  289. {
  290. template <class Op>
  291. static decltype(auto) make(Op && op)
  292. {
  293. return std::forward<Op>(op);
  294. }
  295. };
  296. // FIXME the general case fails in from_partial.
  297. template <class A> inline constexpr
  298. auto from(A && a)
  299. {
  300. return a();
  301. }
  302. // Support dynamic rank for 1 arg only (see test in test-from.C).
  303. template <class A, class I0> inline constexpr
  304. auto from(A && a, I0 && i0)
  305. {
  306. return expr(std::forward<A>(a), start(std::forward<I0>(i0)));
  307. }
  308. // TODO we should be able to do better by slicing at each dimension, etc. But verb<> only supports rank-0 for the innermost op.
  309. template <class A, class ... I> inline constexpr
  310. auto from(A && a, I && ... i)
  311. {
  312. using II = mp::Map_<index_rank, mp::tuple<decltype(start(std::forward<I>(i))) ...>>;
  313. return expr(from_partial<II, 1>::make(std::forward<A>(a)), start(std::forward<I>(i)) ...);
  314. }
  315. } // namespace ra
  316. #undef CHECK_BOUNDS
  317. #undef RA_CHECK_BOUNDS_RA_WRANK