match.hh 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file match.hh
  3. /// @brief Prefix matching of array expression templates.
  4. // (c) Daniel Llorens - 2011-2013, 2015-2019
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #pragma once
  10. #include "ra/bootstrap.hh"
  11. namespace ra {
  12. inline constexpr
  13. bool gt_rank(rank_t ra, rank_t rb)
  14. {
  15. return rb==RANK_BAD
  16. ? 1
  17. : rb==RANK_ANY
  18. ? ra==RANK_ANY
  19. : ra==RANK_BAD
  20. ? 0
  21. : ra==RANK_ANY
  22. ? 1
  23. : ra>=rb;
  24. }
  25. inline constexpr
  26. bool gt_size(dim_t sa, dim_t sb)
  27. {
  28. return sb==DIM_BAD
  29. ? 1
  30. : sa==DIM_BAD
  31. ? 0
  32. : sb==DIM_ANY
  33. ? 1
  34. : (sa!=DIM_ANY && sa>=sb);
  35. }
  36. // TODO Allow infinite rank; need a special value of crank for that.
  37. inline constexpr
  38. rank_t dependent_cell_rank(rank_t rank, rank_t crank)
  39. {
  40. return crank>=0 ? crank // not dependent
  41. : rank==RANK_ANY ? RANK_ANY // defer
  42. : (rank+crank);
  43. }
  44. inline constexpr
  45. rank_t dependent_frame_rank(rank_t rank, rank_t crank)
  46. {
  47. return rank==RANK_ANY ? RANK_ANY // defer
  48. : crank>=0 ? (rank-crank) // not dependent
  49. : -crank;
  50. }
  51. inline constexpr
  52. dim_t chosen_size(dim_t sa, dim_t sb)
  53. {
  54. if (sa==DIM_BAD) {
  55. return sb;
  56. } else if (sb==DIM_BAD) {
  57. return sa;
  58. } else if (sa==DIM_ANY) {
  59. return sb;
  60. } else {
  61. return sa;
  62. }
  63. }
  64. // Abort if there is a static mismatch. Return 0 if if all the sizes are static. Return 1 if a runtime check is needed.
  65. template <class E>
  66. inline constexpr
  67. int check_expr_s()
  68. {
  69. using T = typename E::T;
  70. constexpr rank_t rs = E::rank_s();
  71. if constexpr (rs>=0) {
  72. constexpr auto fk =
  73. [](auto && fk, auto k_, auto valk)
  74. {
  75. // FIXME until something like P1045R1 = [](..., constexpr auto k_, ...)
  76. constexpr int k = k_;
  77. if constexpr (k<rs) {
  78. constexpr auto fi =
  79. [](auto && fi, auto i_, auto sk_, auto vali)
  80. {
  81. constexpr dim_t sk = sk_;
  82. constexpr int i = i_;
  83. if constexpr (i<mp::len<T>) {
  84. using Ti = std::decay_t<mp::ref<T, i>>;
  85. if constexpr (k<Ti::rank_s()) {
  86. constexpr dim_t si = Ti::size_s(k);
  87. static_assert(sk<0 || si<0 || si==sk, "mismatched static dimensions");
  88. return fi(fi, mp::int_t<i+1> {}, mp::int_t<chosen_size(sk, si)> {},
  89. mp::int_t<(1==vali || sk==DIM_ANY || si==DIM_ANY) ? 1 : 0> {});
  90. } else {
  91. return fi(fi, mp::int_t<i+1> {}, mp::int_t<sk> {}, vali);
  92. }
  93. } else {
  94. return vali;
  95. }
  96. };
  97. constexpr int vali = fi(fi, mp::int_t<0> {}, mp::int_t<DIM_BAD> {}, valk);
  98. return fk(fk, mp::int_t<k+1> {}, mp::int_t<vali> {});
  99. } else {
  100. return valk;
  101. }
  102. };
  103. return fk(fk, mp::int_t<0> {}, mp::int_t<0> {});
  104. } else {
  105. return 1;
  106. }
  107. }
  108. template <class E>
  109. inline constexpr
  110. bool check_expr(E const & e)
  111. {
  112. using T = typename E::T;
  113. rank_t rs = e.rank();
  114. for (int k=0; k!=rs; ++k) {
  115. auto fi =
  116. [&k, &e](auto && fi, auto i_, int sk)
  117. {
  118. constexpr int i = i_;
  119. if constexpr (i<mp::len<T>) {
  120. if (k<std::get<i>(e.t).rank()) {
  121. dim_t si = std::get<i>(e.t).size(k);
  122. RA_CHECK((sk==DIM_BAD || si==DIM_BAD || si==sk),
  123. " k ", k, " sk ", sk, " != ", si, ": mismatched dimensions");
  124. fi(fi, mp::int_t<i+1> {}, chosen_size(sk, si));
  125. } else {
  126. fi(fi, mp::int_t<i+1> {}, sk);
  127. }
  128. }
  129. };
  130. fi(fi, mp::int_t<0> {}, DIM_BAD);
  131. }
  132. // FIXME actually use this instead of relying on RA_CHECK throwing/aborting
  133. return true;
  134. }
  135. template <class T, class K=mp::iota<mp::len<T>>> struct MatchParent;
  136. template <class ... P, int ... I>
  137. struct MatchParent<std::tuple<P ...>, mp::int_list<I ...>>
  138. {
  139. using T = std::tuple<P ...>;
  140. T t;
  141. constexpr MatchParent(P ... p_): t(std::forward<P>(p_) ...)
  142. {
  143. if constexpr (check_expr_s<MatchParent>()) {
  144. RA_CHECK(check_expr(*this)); // TODO Maybe do this on ply?
  145. }
  146. }
  147. template <class T> struct box { using type = T; };
  148. // rank of largest subexpr. This is true for either prefix or suffix match.
  149. constexpr static rank_t rank_s()
  150. {
  151. return mp::fold_tuple(RANK_BAD, mp::map<box, T> {},
  152. [](rank_t r, auto a)
  153. {
  154. constexpr rank_t ar = ra::rank_s<typename decltype(a)::type>();
  155. return gt_rank(r, ar) ? r : ar;
  156. });
  157. }
  158. constexpr rank_t rank() const
  159. {
  160. if constexpr (constexpr rank_t rs=rank_s(); rs==RANK_ANY) {
  161. return mp::fold_tuple(RANK_BAD, t,
  162. [](rank_t r, auto && a)
  163. {
  164. rank_t ar = a.rank();
  165. assert(ar!=RANK_ANY); // cannot happen at runtime
  166. return gt_rank(r, ar) ? r : ar;
  167. });
  168. } else {
  169. return rs;
  170. }
  171. }
  172. // any size which is not DIM_BAD.
  173. constexpr static dim_t size_s(int k)
  174. {
  175. dim_t s = mp::fold_tuple(DIM_BAD, mp::map<box, T> {},
  176. [&k](dim_t s, auto a)
  177. {
  178. using A = std::decay_t<typename decltype(a)::type>;
  179. constexpr rank_t ar = A::rank_s();
  180. if (s!=DIM_BAD) {
  181. return s;
  182. } else if (ar>=0 && k>=ar) {
  183. return s;
  184. } else {
  185. dim_t zz = A::size_s(k);
  186. return zz;
  187. }
  188. });
  189. return s;
  190. }
  191. // do early exit with fold_tuple (and with size_s(k)).
  192. constexpr dim_t size(int k) const
  193. {
  194. if (dim_t ss=size_s(k); ss==DIM_ANY) {
  195. auto f = [this, &k](auto && f, auto i_)
  196. {
  197. constexpr int i = i_;
  198. if constexpr (i<std::tuple_size_v<T>) {
  199. auto const & a = std::get<i>(this->t);
  200. if (k<a.rank()) {
  201. dim_t as = a.size(k);
  202. if (as!=DIM_BAD) {
  203. assert(as!=DIM_ANY); // cannot happen at runtime
  204. return as;
  205. } else {
  206. return f(f, mp::int_t<i+1> {});
  207. }
  208. } else {
  209. return f(f, mp::int_t<i+1> {});
  210. }
  211. } else {
  212. assert(0);
  213. return DIM_BAD;
  214. }
  215. };
  216. return f(f, mp::int_t<0> {});
  217. } else {
  218. return ss;
  219. }
  220. }
  221. constexpr void adv(rank_t k, dim_t d)
  222. {
  223. (std::get<I>(t).adv(k, d), ...);
  224. }
  225. constexpr auto stride(int i) const
  226. {
  227. return std::make_tuple(std::get<I>(t).stride(i) ...);
  228. }
  229. };
  230. // forward decl in atom.hh. Split in MatchParent/Match to allow static keep_stride.
  231. // FIXME keep an eye on https://gcc.gnu.org/bugzilla/show_bug.cgi?id=96164
  232. template <class T, class K=mp::iota<mp::len<T>>> struct Match;
  233. template <class ... P, int ... I>
  234. requires (!(requires (dim_t d, rank_t i, rank_t j) { P::keep_stride(d, i, j); } && ...))
  235. struct Match<std::tuple<P ...>, mp::int_list<I ...>>: public MatchParent<std::tuple<P ...>, mp::int_list<I ...>>
  236. {
  237. using MatchParent<std::tuple<P ...>, mp::int_list<I ...>>::MatchParent;
  238. using MatchParent<std::tuple<P ...>, mp::int_list<I ...>>::t;
  239. constexpr bool keep_stride(dim_t st, int z, int j) const
  240. {
  241. return (std::get<I>(t).keep_stride(st, z, j) && ...);
  242. }
  243. };
  244. template <class ... P, int ... I>
  245. requires (requires (dim_t d, rank_t i, rank_t j) { P::keep_stride(d, i, j); } && ...)
  246. struct Match<std::tuple<P ...>, mp::int_list<I ...>>: public MatchParent<std::tuple<P ...>, mp::int_list<I ...>>
  247. {
  248. using MatchParent<std::tuple<P ...>, mp::int_list<I ...>>::MatchParent;
  249. constexpr static bool keep_stride(dim_t st, int z, int j)
  250. {
  251. return (std::decay_t<P>::keep_stride(st, z, j) && ...);
  252. }
  253. };
  254. } // namespace ra