expr.hh 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Expression templates with prefix matching.
  3. // (c) Daniel Llorens - 2011-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include <cassert>
  10. #include <functional>
  11. #include "bootstrap.hh"
  12. // --------------------
  13. // error handling. See examples/throw.cc for how to customize.
  14. // --------------------
  15. #include <iostream> // might not be needed with a different RA_ASSERT.
  16. #ifndef RA_ASSERT
  17. #define RA_ASSERT(cond, ...) \
  18. { \
  19. if (std::is_constant_evaluated()) { \
  20. assert(cond /* FIXME show args */); \
  21. } else { \
  22. if (!(cond)) [[unlikely]] { \
  23. std::cerr << ra::format("**** ra (", std::source_location::current(), "): ", ##__VA_ARGS__, " ****") << std::endl; \
  24. std::abort(); \
  25. } \
  26. } \
  27. }
  28. #endif
  29. #if defined(RA_DO_CHECK) && RA_DO_CHECK==0
  30. #define RA_CHECK( ... )
  31. #else
  32. #define RA_CHECK( ... ) RA_ASSERT( __VA_ARGS__ )
  33. #endif
  34. #define RA_AFTER_CHECK Yes
  35. namespace ra {
  36. constexpr bool inside(dim_t i, dim_t b) { return 0<=i && i<b; }
  37. // --------------------
  38. // assign ops for settable iterators. Might be different for e.g. Views.
  39. // --------------------
  40. // Forward to forbid misusing value y as ref [ra5].
  41. #define RA_ASSIGNOPS_LINE(OP) \
  42. for_each([](auto && y, auto && x) { RA_FWD(y) OP x; }, *this, x)
  43. #define RA_ASSIGNOPS(OP) \
  44. constexpr void operator OP(auto && x) { RA_ASSIGNOPS_LINE(OP); }
  45. // But see local ASSIGNOPS elsewhere.
  46. #define RA_ASSIGNOPS_DEFAULT_SET \
  47. FOR_EACH(RA_ASSIGNOPS, =, *=, +=, -=, /=)
  48. // Restate for expression classes since a template doesn't replace the copy assignment op.
  49. #define RA_ASSIGNOPS_SELF(TYPE) \
  50. TYPE & operator=(TYPE && x) { RA_ASSIGNOPS_LINE(=); return *this; } \
  51. TYPE & operator=(TYPE const & x) { RA_ASSIGNOPS_LINE(=); return *this; } \
  52. constexpr TYPE(TYPE && x) = default; \
  53. constexpr TYPE(TYPE const & x) = default;
  54. // --------------------
  55. // terminal types
  56. // --------------------
  57. // Rank-0 IteratorConcept. Can be used on foreign objects, or as alternative to the rank conjunction.
  58. // We still want f(scalar(C)) to be f(C) and not map(f, C), this is controlled by tomap/toreduce.
  59. template <class C>
  60. struct Scalar
  61. {
  62. C c;
  63. RA_ASSIGNOPS_DEFAULT_SET
  64. consteval static rank_t rank() { return 0; }
  65. constexpr static dim_t len_s(int k) { std::abort(); }
  66. constexpr static dim_t len(int k) { std::abort(); }
  67. constexpr static dim_t step(int k) { return 0; }
  68. constexpr static void adv(rank_t k, dim_t d) {}
  69. constexpr static bool keep_step(dim_t st, int z, int j) { return true; }
  70. constexpr decltype(auto) at(auto && j) const { return c; }
  71. constexpr C & operator*() requires (std::is_lvalue_reference_v<C>) { return c; } // [ra37]
  72. constexpr C const & operator*() requires (!std::is_lvalue_reference_v<C>) { return c; }
  73. constexpr C const & operator*() const { return c; } // [ra39]
  74. constexpr static int save() { return 0; }
  75. constexpr static void load(int) {}
  76. constexpr static void mov(dim_t d) {}
  77. };
  78. template <class C> constexpr auto
  79. scalar(C && c) { return Scalar<C> { RA_FWD(c) }; }
  80. template <class N> constexpr int
  81. maybe_any = []{
  82. if constexpr (is_constant<N>) {
  83. return N::value;
  84. } else {
  85. static_assert(std::is_integral_v<N> || !std::is_same_v<N, bool>);
  86. return ANY;
  87. }
  88. }();
// IteratorConcept for foreign rank 1 objects.
// Ptr<I, N> walks the bidirectional iterator I along its single axis. The static length nn is
// N::value for a constant N, ANY for a run-time length, or BAD, in which case at() does no range check.
template <std::bidirectional_iterator I, class N>
struct Ptr
{
    static_assert(is_constant<N> || 0==rank_s<N>());
    // Static length, or ANY when the length is only known at run time.
    constexpr static dim_t nn = maybe_any<N>;
    static_assert(nn==ANY || nn>=0 || nn==BAD);
    I i;
    // [[no_unique_address]] lets an empty N (e.g. a constant) take no storage.
    [[no_unique_address]] N const n = {};
    constexpr Ptr(I i, N n): i(i), n(n) {}
    RA_ASSIGNOPS_SELF(Ptr)
    RA_ASSIGNOPS_DEFAULT_SET
    consteval static rank_t rank() { return 1; }
    constexpr static dim_t len_s(int k) { return nn; } // len(k==0) or step(k>=0)
    // len() is static unless the length is only known at run time.
    constexpr static dim_t len(int k) requires (nn!=ANY) { return len_s(k); }
    constexpr dim_t len(int k) const requires (nn==ANY) { return n; }
    // Only axis 0 moves the iterator.
    constexpr static dim_t step(int k) { return k==0 ? 1 : 0; }
    constexpr void adv(rank_t k, dim_t d) { i += step(k) * d; }
    constexpr static bool keep_step(dim_t st, int z, int j) { return st*step(z)==step(j); }
    // Element at index j[0]; only available for random-access iterators.
    constexpr decltype(auto) at(auto && j) const requires (std::random_access_iterator<I>)
    {
        RA_CHECK(BAD==nn || inside(j[0], n), "Out of range for len[0]=", n, ": ", j[0], ".");
        return i[j[0]];
    }
    constexpr decltype(auto) operator*() const { return *i; }
    // The saved position is the iterator itself.
    constexpr auto save() const { return i; }
    constexpr void load(I ii) { i = ii; }
    constexpr void mov(dim_t d) { i += d; }
};
  118. template <class X> using iota_arg = std::conditional_t<is_constant<std::decay_t<X>> || is_scalar<std::decay_t<X>>, std::decay_t<X>, X>;
  119. template <class I, class N=dim_c<BAD>>
  120. constexpr auto
  121. ptr(I && i, N && n = N {})
  122. {
  123. // not decay_t bc of builtin arrays.
  124. if constexpr (std::ranges::bidirectional_range<std::remove_reference_t<I>>) {
  125. static_assert(std::is_same_v<dim_c<BAD>, N>, "Object has own length.");
  126. constexpr dim_t s = size_s<I>();
  127. if constexpr (ANY==s) {
  128. return ptr(std::begin(RA_FWD(i)), std::ssize(i));
  129. } else {
  130. return ptr(std::begin(RA_FWD(i)), ic<s>);
  131. }
  132. } else if constexpr (std::bidirectional_iterator<std::decay_t<I>>) {
  133. if constexpr (std::is_integral_v<N>) {
  134. RA_CHECK(n>=0, "Bad ptr length ", n, ".");
  135. }
  136. return Ptr<std::decay_t<I>, iota_arg<N>> { i, RA_FWD(n) };
  137. } else {
  138. static_assert(always_false<I>, "Bad type for ptr().");
  139. }
  140. }
  141. // Sequence and IteratorConcept for same. Iota isn't really a terminal, but its exprs must all have rank 0.
  142. // FIXME w is a custom Reframe mechanism inherited from TensorIndex. Generalize/unify
  143. // FIXME Sequence should be its own type, we can't represent a ct origin bc IteratorConcept interface takes up i.
  144. template <int w, class N_, class O, class S_>
  145. struct Iota
  146. {
  147. using N = std::decay_t<N_>;
  148. using S = std::decay_t<S_>;
  149. static_assert(w>=0);
  150. static_assert(is_constant<S> || 0==rank_s<S>());
  151. static_assert(is_constant<N> || 0==rank_s<N>());
  152. constexpr static dim_t nn = maybe_any<N>;
  153. static_assert(nn==ANY || nn>=0 || nn==BAD);
  154. [[no_unique_address]] N const n = {};
  155. O i = {};
  156. [[no_unique_address]] S const s = {};
  157. constexpr static S gets() requires (is_constant<S>) { return S {}; }
  158. constexpr O gets() const requires (!is_constant<S>) { return s; }
  159. consteval static rank_t rank() { return w+1; }
  160. constexpr static dim_t len_s(int k) { return k==w ? nn : BAD; } // len(0<=k<=w) or step(0<=k)
  161. constexpr static dim_t len(int k) requires (is_constant<N>) { return len_s(k); }
  162. constexpr dim_t len(int k) const requires (!is_constant<N>) { return k==w ? n : BAD; }
  163. constexpr static dim_t step(rank_t k) { return k==w ? 1 : 0; }
  164. constexpr void adv(rank_t k, dim_t d) { i += O(step(k) * d) * O(s); }
  165. constexpr static bool keep_step(dim_t st, int z, int j) { return st*step(z)==step(j); }
  166. constexpr auto at(auto && j) const
  167. {
  168. RA_CHECK(BAD==nn || inside(j[0], n), "Out of range for len[0]=", n, ": ", j[0], ".");
  169. return i + O(j[w])*O(s);
  170. }
  171. constexpr O operator*() const { return i; }
  172. constexpr O save() const { return i; }
  173. constexpr void load(O ii) { i = ii; }
  174. constexpr void mov(dim_t d) { i += O(d)*O(s); }
  175. };
  176. template <int w=0, class O=dim_t, class N=dim_c<BAD>, class S=dim_c<1>>
  177. constexpr auto
  178. iota(N && n = N {}, O && org = 0,
  179. S && s = [] {
  180. if constexpr (std::is_integral_v<S>) {
  181. return S(1);
  182. } else if constexpr (is_constant<S>) {
  183. static_assert(1==S::value);
  184. return S {};
  185. } else {
  186. static_assert(always_false<S>, "Bad step type for Iota.");
  187. }
  188. }())
  189. {
  190. if constexpr (std::is_integral_v<N>) {
  191. RA_CHECK(n>=0, "Bad iota length ", n, ".");
  192. }
  193. return Iota<w, iota_arg<N>, iota_arg<O>, iota_arg<S>> { RA_FWD(n), RA_FWD(org), RA_FWD(s) };
  194. }
  195. #define DEF_TENSORINDEX(w) constexpr auto JOIN(_, w) = iota<w>();
  196. FOR_EACH(DEF_TENSORINDEX, 0, 1, 2, 3, 4);
  197. #undef DEF_TENSORINDEX
// is_iota<T>: false by default (NOTE(review): presumably RA_IS_DEF defines is_iota on top of is_iota_def).
RA_IS_DEF(is_iota, false)
// True for axis-0 Iotas, but only when their static length is not BAD.
// BAD is excluded from beating to allow B = A(... i ...) to use B's len. FIXME find a way?
template <class N, class O, class S>
constexpr bool is_iota_def<Iota<0, N, O, S>> = (BAD != Iota<0, N, O, S>::nn);
  202. constexpr bool
  203. inside(is_iota auto const & i, dim_t l)
  204. {
  205. return (inside(i.i, l) && inside(i.i+(i.n-1)*i.s, l)) || (0==i.n /* don't bother */);
  206. }
  207. constexpr struct Len
  208. {
  209. consteval static rank_t rank() { return 0; }
  210. constexpr static dim_t len_s(int k) { std::abort(); }
  211. constexpr static dim_t len(int k) { std::abort(); }
  212. constexpr static dim_t step(int k) { std::abort(); }
  213. constexpr static void adv(rank_t k, dim_t d) { std::abort(); }
  214. constexpr static bool keep_step(dim_t st, int z, int j) { std::abort(); }
  215. constexpr dim_t operator*() const { std::abort(); }
  216. constexpr static int save() { std::abort(); }
  217. constexpr static void load(int) { std::abort(); }
  218. constexpr static void mov(dim_t d) { std::abort(); }
  219. } len;
// protect exprs with Len from reduction.
template <> constexpr bool is_special_def<Len> = true;
// has_len<T>: false by default here. Match's constructor skips its shape checks when any argument
// has_len. NOTE(review): presumably specialized elsewhere for exprs containing Len.
RA_IS_DEF(has_len, false);
  223. // --------------
  224. // making Iterators
  225. // --------------
  226. // TODO arbitrary exprs?
  227. template <int cr>
  228. constexpr auto
  229. iter(SliceConcept auto && a) { return RA_FWD(a).template iter<cr>(); }
// start(x) turns x into an iterator. This unconstrained overload fails at compile time for
// types that no other overload accepts.
constexpr void
start(auto && t) { static_assert(always_false<decltype(t)>, "Cannot start() type."); }
// is_fov objects (NOTE(review): presumably foreign vectors) become a Ptr over their range.
constexpr auto
start(is_fov auto && t) { return ra::ptr(RA_FWD(t)); }
// Braced lists become a Ptr over their elements, with a run-time length.
template <class T>
constexpr auto
start(std::initializer_list<T> v) { return ra::ptr(v.begin(), v.size()); }
// Scalars become rank-0 Scalar iterators.
constexpr auto
start(is_scalar auto && t) { return ra::scalar(RA_FWD(t)); }
// forward declare for Match; implemented in small.hh.
constexpr auto
start(is_builtin_array auto && t);
// neither CellBig nor CellSmall will retain rvalues [ra4].
constexpr auto
start(SliceConcept auto && t) { return iter<0>(RA_FWD(t)); }
// is_ra_scalar<A>: A is a Scalar of its own member type.
RA_IS_DEF(is_ra_scalar, (std::same_as<A, Scalar<decltype(std::declval<A>().c)>>))
// iterators need to be start()ed on each use [ra35].
// An lvalue iterator other than Scalar is returned as a copy.
template <class T> requires (is_iterator<T> && !is_ra_scalar<T>)
constexpr auto
start(T & t) { return t; }
// Other iterators are forwarded unchanged.
// FIXME const Iterator would still be unusable after start()
constexpr decltype(auto)
start(is_iterator auto && t) { return RA_FWD(t); }
  253. // --------------------
  254. // prefix match
  255. // --------------------
  256. constexpr rank_t
  257. choose_rank(rank_t ra, rank_t rb) { return BAD==rb ? ra : BAD==ra ? rb : ANY==ra ? ra : ANY==rb ? rb : std::max(ra, rb); }
  258. // pick first if mismatch (see below). FIXME maybe return invalid.
  259. constexpr dim_t
  260. choose_len(dim_t sa, dim_t sb) { return BAD==sa ? sb : BAD==sb ? sa : ANY==sa ? sb : sa; }
// Match<checkp, std::tuple<P ...>> holds the iterators P ... and gives them a common shape by
// prefix matching: the rank is the choose_rank of all argument ranks, and the length of each axis
// combines the lengths of the arguments that have that axis (see len_s). With checkp, the
// constructor checks that the arguments agree.
template <bool checkp, class T, class K=mp::iota<mp::len<T>>> struct Match;
template <bool checkp, IteratorConcept ... P, int ... I>
struct Match<checkp, std::tuple<P ...>, mp::int_list<I ...>>
{
    std::tuple<P ...> t;
    // rank of largest subexpr
    constexpr static rank_t rs = [] { rank_t r=BAD; return ((r=choose_rank(r, ra::rank_s<P>())), ...); }();
    // Static agreement check.
    // 0: fail, 1: rt, 2: pass
    // Fails if, on some axis, an argument's static length disagrees with the combined length.
    // Needs a run-time check if the rank is ANY, or if on some axis a length known only at run
    // time (ANY) would have to be compared with at least one other ANY or fixed length.
    consteval static int
    check_s()
    {
        if constexpr (sizeof...(P)<2) {
            return 2;
        } else if constexpr (ANY==rs) {
            return 1; // FIXME can be tightened to 2 if all args are rank 0 save one
        } else {
            bool tbc = false;
            for (int k=0; k<rs; ++k) {
                dim_t ls = len_s(k);
                if (((k<ra::rank_s<P>() && ls!=choose_len(std::decay_t<P>::len_s(k), ls)) || ...)) {
                    return 0;
                } else {
                    // Count the arguments with axis k whose length is ANY / fixed (>=0).
                    int anyk = ((k<ra::rank_s<P>() && (ANY==std::decay_t<P>::len_s(k))) + ...);
                    int fixk = ((k<ra::rank_s<P>() && (0<=std::decay_t<P>::len_s(k))) + ...);
                    tbc = tbc || (anyk>0 && anyk+fixk>1);
                }
            }
            return tbc ? 1 : 2;
        }
    }
    // Run-time agreement check; only does work when check_s() is 1. On mismatch it returns false,
    // after failing RA_CHECK with a message if checkp (and checks are enabled).
    constexpr bool
    check() const
    {
        if constexpr (sizeof...(P)<2) {
            return true;
        } else if constexpr (constexpr int c = check_s(); 0==c) {
            return false;
        } else if constexpr (1==c) {
            for (int k=0; k<rank(); ++k) {
                dim_t ls = len(k);
                if (((k<ra::rank(std::get<I>(t)) && ls!=choose_len(std::get<I>(t).len(k), ls)) || ...)) {
                    RA_CHECK(!checkp, "Mismatch on axis ", k, " [", (std::array { std::get<I>(t).len(k) ... }), "].");
                    return false;
                }
            }
        }
        return true;
    }
    // With checkp, check agreement statically and at run time, unless some argument has_len.
    constexpr
    Match(P ... p_): t(p_ ...) // [ra1]
    {
        // TODO Maybe on ply, would make checkp unnecessary, make agree_xxx() unnecessary.
        if constexpr (checkp && !(has_len<P> || ...)) {
            static_assert(check_s(), "Shape mismatch.");
            RA_CHECK(check());
        }
    }
    // Static rank when known.
    consteval static rank_t
    rank() requires (ANY!=rs)
    {
        return rs;
    }
    // Otherwise, combine the run-time ranks of the arguments.
    constexpr rank_t
    rank() const requires (ANY==rs)
    {
        rank_t r = BAD;
        ((r = choose_rank(r, ra::rank(std::get<I>(t)))), ...);
        assert(ANY!=r); // not at runtime
        return r;
    }
    // first nonnegative size, if none first ANY, if none then BAD
    constexpr static dim_t
    len_s(int k)
    {
        auto f = [&k]<class A>(dim_t s) {
            constexpr rank_t ar = ra::rank_s<A>();
            return (ar<0 || k<ar) ? choose_len(s, A::len_s(k)) : s;
        };
        dim_t s = BAD; ((s>=0 ? s : s = f.template operator()<std::decay_t<P>>(s)), ...);
        return s;
    }
    // Static len() when every P has a static len().
    constexpr static dim_t
    len(int k) requires (requires (int kk) { P::len(kk); } && ...)
    {
        return len_s(k);
    }
    // Otherwise combine the run-time lengths of the arguments, in the same way as len_s().
    constexpr dim_t
    len(int k) const requires (!(requires (int kk) { P::len(kk); } && ...))
    {
        auto f = [&k](dim_t s, auto const & a) {
            return k<ra::rank(a) ? choose_len(s, a.len(k)) : s;
        };
        dim_t s = BAD; ((s>=0 ? s : s = f(s, std::get<I>(t))), ...);
        assert(ANY!=s); // not at runtime
        return s;
    }
    // Tuple of the steps of every argument on axis i.
    constexpr auto
    step(int i) const
    {
        return std::make_tuple(std::get<I>(t).step(i) ...);
    }
    // Move every argument along axis k.
    constexpr void
    adv(rank_t k, dim_t d)
    {
        (std::get<I>(t).adv(k, d), ...);
    }
    // Conjunction of the arguments' keep_step; static when every P's keep_step is static.
    constexpr bool
    keep_step(dim_t st, int z, int j) const
    requires (!(requires (dim_t st, rank_t z, rank_t j) { P::keep_step(st, z, j); } && ...))
    {
        return (std::get<I>(t).keep_step(st, z, j) && ...);
    }
    constexpr static bool
    keep_step(dim_t st, int z, int j)
    requires (requires (dim_t st, rank_t z, rank_t j) { P::keep_step(st, z, j); } && ...)
    {
        return (std::decay_t<P>::keep_step(st, z, j) && ...);
    }
    // The position is the tuple of the argument positions.
    constexpr auto save() const { return std::make_tuple(std::get<I>(t).save() ...); }
    constexpr void load(auto const & pp) { ((std::get<I>(t).load(std::get<I>(pp))), ...); }
    constexpr void mov(auto const & s) { ((std::get<I>(t).mov(std::get<I>(s))), ...); }
};
// ---------------------------
// reframe
// ---------------------------
// Transpose variant for IteratorConcepts. As in transpose(), one names the destination axis for
// each original axis. However, axes may not be repeated. Used in the rank conjunction below.
// samestep<N, T>: the value N as a step of type T. If T is a tuple (like the result of Match::step),
// N is applied to each element.
template <dim_t N, class T> constexpr T samestep = N;
template <dim_t N, class ... T> constexpr std::tuple<T ...> samestep<N, std::tuple<T ...>> = { samestep<N, T> ... };
  390. // Dest is a list of destination axes [l0 l1 ... li ... l(rank(A)-1)].
  391. // The dimensions of the reframed A are numbered as [0 ... k ... max(l)-1].
  392. // If li = k for some i, then axis k of the reframed A moves on axis i of the original iterator A.
  393. // If not, then axis k of the reframed A is 'dead' and doesn't move the iterator.
  394. // TODO invalid for ANY, since Dest is compile time. [ra7]
  395. template <class Dest, IteratorConcept A>
  396. struct Reframe
  397. {
  398. A a;
  399. constexpr static int orig(int k) { return mp::int_list_index<Dest>(k); }
  400. consteval static rank_t rank() { return 1+mp::fold<mp::max, ic_t<-1>, Dest>::value; }
  401. constexpr static dim_t len_s(int k)
  402. {
  403. int l=orig(k);
  404. return l>=0 ? std::decay_t<A>::len_s(l) : BAD;
  405. }
  406. constexpr dim_t
  407. len(int k) const
  408. {
  409. int l=orig(k);
  410. return l>=0 ? a.len(l) : BAD;
  411. }
  412. constexpr auto
  413. step(int k) const
  414. {
  415. int l=orig(k);
  416. return l>=0 ? a.step(l) : samestep<0, decltype(a.step(l))>;
  417. }
  418. constexpr void
  419. adv(rank_t k, dim_t d)
  420. {
  421. int l=orig(k);
  422. if (l>=0) { a.adv(l, d); }
  423. }
  424. constexpr bool
  425. keep_step(dim_t st, int z, int j) const
  426. {
  427. int wz=orig(z), wj=orig(j);
  428. return wz>=0 && wj>=0 && a.keep_step(st, wz, wj);
  429. }
  430. constexpr decltype(auto)
  431. at(auto const & i) const
  432. {
  433. return a.at(mp::map_indices<dim_t, Dest>(i));
  434. }
  435. constexpr decltype(auto) operator*() const { return *a; }
  436. constexpr auto save() const { return a.save(); }
  437. constexpr void load(auto const & p) { a.load(p); }
  438. // FIXME only if Dest preserves axis order, which is how wrank works, but this limitation should be explicit.
  439. constexpr void mov(auto const & s) { a.mov(s); }
  440. };
  441. // Optimize no-op case. TODO If A is CellBig, etc. beat Dest on it, same for eventual transpose_expr<>.
  442. template <class Dest, class A>
  443. constexpr decltype(auto)
  444. reframe(A && a)
  445. {
  446. if constexpr (std::is_same_v<Dest, mp::iota<1+mp::fold<mp::max, ic_t<-1>, Dest>::value>>) {
  447. return RA_FWD(a);
  448. } else {
  449. return Reframe<Dest, A> { RA_FWD(a) };
  450. }
  451. }
  452. // ---------------------------
  453. // verbs and rank conjunction
  454. // ---------------------------
  455. template <class cranks_, class Op_>
  456. struct Verb
  457. {
  458. using cranks = cranks_;
  459. using Op = Op_;
  460. Op op;
  461. };
  462. RA_IS_DEF(is_verb, (std::is_same_v<A, Verb<typename A::cranks, typename A::Op>>))
  463. template <class cranks, class Op>
  464. constexpr auto
  465. wrank(cranks cranks_, Op && op) { return Verb<cranks, Op> { RA_FWD(op) }; }
  466. template <rank_t ... crank, class Op>
  467. constexpr auto
  468. wrank(Op && op) { return Verb<mp::int_list<crank ...>, Op> { RA_FWD(op) }; }
// Framematch<V, std::tuple<T ...>> computes, for op V applied to arguments of types T ..., the
// per-argument lists of destination axes R (used with reframe) and the op to apply (op()).
// R starts from mp::makelist<mp::len<T>, mp::nil> (presumably one empty list per argument), and
// skip is the first destination axis still free.
template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
struct Framematch_def;
template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
  473. template <class A, class B>
  474. struct max_i
  475. {
  476. constexpr static int value = (A::value == choose_rank(A::value, B::value)) ? 0 : 1;
  477. };
// Get a list (per argument) of lists of live axes. The last frame match is handled by standard prefix matching.
// For a Verb: each argument gets as many destination axes as it has live axes on this frame
// (rank minus axes already used minus its cell rank), then matching recurses into the inner op W
// with skip advanced by the entry of live selected with max_i.
template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
{
    static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
    // live = number of live axes on this frame, for each argument. // TODO crank negative, inf.
    using live = mp::int_list<(rank_s<Ti>() - mp::len<Ri> - crank::value) ...>;
    // Append this frame's axes to each argument's list (NOTE(review): presumably as skip, skip+1, ...).
    using frameaxes = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri> - crank::value), skip>> ...>;
    using FM = Framematch<W, std::tuple<Ti ...>, frameaxes, skip + mp::ref<live, mp::indexof<max_i, live>>::value>;
    using R = typename FM::R;
    // Unwrap the verb and continue with the inner op.
    template <class VV> constexpr static decltype(auto) op(VV && v) { return FM::op(RA_FWD(v).op); } // cf [ra31]
};
// Terminal case where V doesn't have rank (is a raw op()).
// Every remaining axis of each argument is appended to its list, and the op is used as is.
template <class V, class ... Ti, class ... Ri, rank_t skip>
struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
{
    static_assert(sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
    // TODO -crank::value when the actual verb rank is used (eg to use CellBig<... that_rank> instead of just begin()).
    using R = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri>), skip>> ...>;
    template <class VV> constexpr static decltype(auto) op(VV && v) { return RA_FWD(v); }
};
  499. // ---------------
  500. // explicit agreement checks
  501. // ---------------
  502. constexpr bool
  503. agree(auto && ... p) { return agree_(ra::start(RA_FWD(p)) ...); }
  504. // 0: fail, 1: rt, 2: pass
  505. constexpr int
  506. agree_s(auto && ... p) { return agree_s_(ra::start(RA_FWD(p)) ...); }
  507. template <class Op, class ... P> requires (is_verb<Op>)
  508. constexpr bool
  509. agree_op(Op && op, P && ... p) { return agree_verb(mp::iota<sizeof...(P)> {}, RA_FWD(op), RA_FWD(p) ...); }
  510. template <class Op, class ... P> requires (!is_verb<Op>)
  511. constexpr bool
  512. agree_op(Op && op, P && ... p) { return agree(RA_FWD(p) ...); }
  513. template <class ... P>
  514. constexpr bool
  515. agree_(P && ... p) { return (Match<false, std::tuple<P ...>> { RA_FWD(p) ... }).check(); }
  516. template <class ... P>
  517. constexpr int
  518. agree_s_(P && ... p) { return Match<false, std::tuple<P ...>>::check_s(); }
  519. template <class V, class ... T, int ... i>
  520. constexpr bool
  521. agree_verb(mp::int_list<i ...>, V && v, T && ... t)
  522. {
  523. using FM = Framematch<V, std::tuple<T ...>>;
  524. return agree_op(FM::op(RA_FWD(v)), reframe<mp::ref<typename FM::R, i>>(ra::start(RA_FWD(t))) ...);
  525. }
// ---------------------------
// operator expression
// ---------------------------
// Expr<Op, std::tuple<P ...>> applies op to the values of the iterators P ... . Shape and
// iteration come from Match with checkp = true, so construction checks agreement.
template <class Op, class T, class K=mp::iota<mp::len<T>>> struct Expr;
template <class Op, IteratorConcept ... P, int ... I>
struct Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tuple<P ...>>
{
    using Match_ = Match<true, std::tuple<P ...>>;
    using Match_::t, Match_::rs, Match_::rank;
    Op op;
    constexpr Expr(Op op_, P ... p_): Match_(p_ ...), op(op_) {} // [ra1]
    RA_ASSIGNOPS_SELF(Expr)
    RA_ASSIGNOPS_DEFAULT_SET
    // op applied to the arguments at index j / at their current position.
    constexpr decltype(auto) at(auto const & j) const { return std::invoke(op, std::get<I>(t).at(j) ...); }
    constexpr decltype(auto) operator*() const { return std::invoke(op, *std::get<I>(t) ...); }
    // needed for rs==ANY, which don't decay to scalar when used as operator arguments.
    // Conversion to the element type: accepted statically for rank 0, or rank 1 with static size 1.
    // For rank ANY it checks at run time that the rank is 0; any other case fails the static_assert.
    constexpr
    operator decltype(std::invoke(op, *std::get<I>(t) ...)) () const
    {
        if constexpr (0!=rs && (1!=rs || 1!=size_s<Expr>())) { // for coord types; so ct only
            static_assert(rs==ANY);
            RA_CHECK(0==rank(), "Bad scalar conversion from shape [", ra::noshape, ra::shape(*this), "].");
        }
        return *(*this);
    }
};
  552. template <class Op, IteratorConcept ... P>
  553. constexpr bool is_special_def<Expr<Op, std::tuple<P ...>>> = (is_special<P> || ...);
// Rank conjunction: reframe each argument t according to the Framematch of verb v, then build
// the expression of the op that Framematch unwraps from v.
template <class V, class ... T, int ... i>
constexpr auto
expr_verb(mp::int_list<i ...>, V && v, T && ... t)
{
    using FM = Framematch<V, std::tuple<T ...>>;
    return expr(FM::op(RA_FWD(v)), reframe<mp::ref<typename FM::R, i>>(RA_FWD(t)) ...);
}
  561. template <class Op, class ... P>
  562. constexpr auto
  563. expr(Op && op, P && ... p)
  564. {
  565. if constexpr (is_verb<Op>) {
  566. return expr_verb(mp::iota<sizeof...(P)> {}, RA_FWD(op), RA_FWD(p) ...);
  567. } else {
  568. return Expr<Op, std::tuple<P ...>> { RA_FWD(op), RA_FWD(p) ... };
  569. }
  570. }
  571. constexpr auto
  572. map(auto && op, auto && ... a) { return expr(RA_FWD(op), start(RA_FWD(a)) ...); }
// ---------------------------
// pick
// ---------------------------
// Common reference type of P::at(J) over all P; the return type of pick_at.
template <class T, class J> struct pick_at_type;
template <class ... P, class J> struct pick_at_type<std::tuple<P ...>, J>
{
    using type = std::common_reference_t<decltype(std::declval<P>().at(std::declval<J>())) ...>;
};
// t holds (selector, x0, x1, ...). Returns x[p0].at(j), recursing on I; the selector itself is
// skipped (mp::drop1 in the return type, std::get<I+1> here).
template <std::size_t I, class T, class J>
constexpr pick_at_type<mp::drop1<std::decay_t<T>>, J>::type
pick_at(std::size_t p0, T && t, J const & j)
{
    constexpr std::size_t N = mp::len<std::decay_t<T>> - 1;
    if constexpr (I < N) {
        return (p0==I) ? std::get<I+1>(t).at(j) : pick_at<I+1>(p0, t, j);
    } else {
        // Only reached when p0 >= N; abort even if RA_CHECK is disabled.
        RA_CHECK(p0 < N, "Bad pick ", p0, " with ", N, " arguments."); std::abort();
    }
}
// Common reference type of *P over all P; the return type of pick_star.
template <class T> struct pick_star_type;
template <class ... P> struct pick_star_type<std::tuple<P ...>>
{
    using type = std::common_reference_t<decltype(*std::declval<P>()) ...>;
};
// t holds (selector, x0, x1, ...). Returns *x[p0], recursing on I.
template <std::size_t I, class T>
constexpr pick_star_type<mp::drop1<std::decay_t<T>>>::type
pick_star(std::size_t p0, T && t)
{
    constexpr std::size_t N = mp::len<std::decay_t<T>> - 1;
    if constexpr (I < N) {
        return (p0==I) ? *(std::get<I+1>(t)) : pick_star<I+1>(p0, t);
    } else {
        // Only reached when p0 >= N; abort even if RA_CHECK is disabled.
        RA_CHECK(p0 < N, "Bad pick ", p0, " with ", N, " arguments."); std::abort();
    }
}
// Pick<std::tuple<P ...>>: elementwise selection. The value of the first iterator (the selector)
// chooses which of the remaining iterators provides the result. Shapes are matched with checkp = true.
template <class T, class K=mp::iota<mp::len<T>>> struct Pick;
template <IteratorConcept ... P, int ... I>
struct Pick<std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tuple<P ...>>
{
    using Match_ = Match<true, std::tuple<P ...>>;
    using Match_::t, Match_::rs, Match_::rank;
    static_assert(sizeof...(P)>1);
    constexpr Pick(P ... p_): Match_(p_ ...) {} // [ra1]
    RA_ASSIGNOPS_SELF(Pick)
    RA_ASSIGNOPS_DEFAULT_SET
    constexpr decltype(auto) at(auto const & j) const { return pick_at<0>(std::get<0>(t).at(j), t, j); }
    constexpr decltype(auto) operator*() const { return pick_star<0>(*std::get<0>(t), t); }
    // needed for xpr with rs==ANY, which don't decay to scalar when used as operator arguments.
    // Same rules as Expr's conversion: static for rank 0 or rank 1 of static size 1, run-time
    // rank check for rank ANY.
    constexpr
    operator decltype(pick_star<0>(*std::get<0>(t), t)) () const
    {
        if constexpr (0!=rs && (1!=rs || 1!=size_s<Pick>())) { // for coord types; so ct only
            static_assert(rs==ANY);
            RA_CHECK(0==rank(), "Bad scalar conversion from shape [", ra::noshape, ra::shape(*this), "].");
        }
        return *(*this);
    }
};
// A Pick is special as soon as one of its arguments is.
template <IteratorConcept ... P>
constexpr bool is_special_def<Pick<std::tuple<P ...>>> = (is_special<P> || ...);
// Deduction guide: an lvalue argument deduces P as a reference, so Pick refers to it instead of copying.
template <class ... P>
Pick(P && ... p) -> Pick<std::tuple<P ...>>;
  635. constexpr auto
  636. pick(auto && ... p) { return Pick { start(RA_FWD(p)) ... }; }
  637. } // namespace ra