ply.hh 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Expression traversal.
  3. // (c) Daniel Llorens - 2013-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. // TODO Make traversal order a parameter, some operations (e.g. output, ravel) require specific orders.
  9. // TODO Better heuristic for traversal order.
  10. // TODO Tiling, etc. (see eval.cc in Blitz++).
  11. // TODO Unit step case?
  12. // TODO std::execution::xxx-policy
  13. // TODO Validate output argument strides.
  14. #pragma once
  15. #include "expr.hh"
  16. namespace ra {
  17. template <class A>
  18. constexpr decltype(auto)
  19. FLAT(A && a)
  20. {
  21. if constexpr (is_scalar<A>) {
  22. return RA_FWD(a); // avoid dangling temp in this case [ra8] (?? maybe unnecessary)
  23. } else if constexpr (is_iterator<A>) {
  24. return *a; // no need to start() for one
  25. } else {
  26. return *(ra::start(RA_FWD(a)));
  27. }
  28. }
  // Decayed type of what FLAT() yields for an A.
  // FIXME do we really want to drop const? See use in concrete_type.
  template <class A>
  using value_t = typename std::decay<decltype(FLAT(std::declval<A>()))>::type;
  31. // ---------------------
  32. // replace Len in expr tree.
  33. // ---------------------
  34. template <>
  35. constexpr bool has_len_def<Len> = true;
  36. template <IteratorConcept ... P>
  37. constexpr bool has_len_def<Pick<std::tuple<P ...>>> = (has_len<P> || ...);
  38. template <class Op, IteratorConcept ... P>
  39. constexpr bool has_len_def<Expr<Op, std::tuple<P ...>>> = (has_len<P> || ...);
  40. template <int w, class N, class O, class S>
  41. constexpr bool has_len_def<Iota<w, N, O, S>> = (has_len<N> || has_len<O> || has_len<S>);
  42. template <class I, class N>
  43. constexpr bool has_len_def<Ptr<I, N>> = has_len<N>;
  // Generic case of the Len replacement: there is nothing to replace in E_, so f() hands
  // the expression back exactly as it was received and ignores the length.
  template <class E_>
  struct WithLen
  {
      // constant/scalar appear in Iota args. dots_t and insert_t appear in subscripts. FIXME restrict to known cases
      template <class Ln, class E>
      static constexpr decltype(auto)
      f(Ln ln, E && e)
      {
          return RA_FWD(e);
      }
  };
  54. template <>
  55. struct WithLen<Len>
  56. {
  57. template <class Ln, class E> constexpr static decltype(auto)
  58. f(Ln ln, E && e)
  59. {
  60. return Scalar<Ln>(ln);
  61. }
  62. };
  63. template <class Op, IteratorConcept ... P, int ... I> requires (has_len<P> || ...)
  64. struct WithLen<Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>>
  65. {
  66. template <class Ln, class E> constexpr static decltype(auto)
  67. f(Ln ln, E && e)
  68. {
  69. return expr(RA_FWD(e).op, WithLen<std::decay_t<P>>::f(ln, std::get<I>(RA_FWD(e).t)) ...);
  70. }
  71. };
  72. template <IteratorConcept ... P, int ... I> requires (has_len<P> || ...)
  73. struct WithLen<Pick<std::tuple<P ...>, mp::int_list<I ...>>>
  74. {
  75. template <class Ln, class E> constexpr static decltype(auto)
  76. f(Ln ln, E && e)
  77. {
  78. return pick(WithLen<std::decay_t<P>>::f(ln, std::get<I>(RA_FWD(e).t)) ...);
  79. }
  80. };
  81. template <int w, class N, class O, class S> requires (has_len<N> || has_len<O> || has_len<S>)
  82. struct WithLen<Iota<w, N, O, S>>
  83. {
  84. template <class Ln, class E> constexpr static decltype(auto)
  85. f(Ln ln, E && e)
  86. {
  87. // usable iota types must be either is_constant or is_scalar.
  88. return iota<w>(FLAT(WithLen<std::decay_t<N>>::f(ln, RA_FWD(e).n)),
  89. FLAT(WithLen<std::decay_t<O>>::f(ln, RA_FWD(e).i)),
  90. FLAT(WithLen<std::decay_t<S>>::f(ln, RA_FWD(e).s)));
  91. }
  92. };
  93. template <class I, class N> requires (has_len<N>)
  94. struct WithLen<Ptr<I, N>>
  95. {
  96. template <class Ln, class E> constexpr static decltype(auto)
  97. f(Ln ln, E && e)
  98. {
  99. return ptr(RA_FWD(e).i, FLAT(WithLen<std::decay_t<N>>::f(ln, RA_FWD(e).n)));
  100. }
  101. };
  102. template <class Ln, class E>
  103. constexpr decltype(auto)
  104. with_len(Ln ln, E && e)
  105. {
  106. static_assert(std::is_integral_v<std::decay_t<Ln>> || is_constant<std::decay_t<Ln>>);
  107. return WithLen<std::decay_t<E>>::f(ln, RA_FWD(e));
  108. }
  109. // --------------
  110. // ply, run time order/rank.
  111. // --------------
  112. struct Nop {};
  113. // step() must give 0 for k>=their own rank, to allow frame matching.
  114. template <IteratorConcept A, class Early = Nop>
  115. constexpr auto
  116. ply_ravel(A && a, Early && early = Nop {})
  117. {
  118. rank_t rank = ra::rank(a);
  119. // must avoid 0-length vlas [ra40].
  120. if (0>=rank) {
  121. if (0>rank) [[unlikely]] { std::abort(); }
  122. if constexpr (requires {early.def;}) {
  123. return (*a).value_or(early.def);
  124. } else {
  125. *a;
  126. return;
  127. }
  128. }
  129. // inside first. FIXME better heuristic - but first need a way to force row-major
  130. rank_t order[rank];
  131. for (rank_t i=0; i<rank; ++i) {
  132. order[i] = rank-1-i;
  133. }
  134. dim_t sha[rank], ind[rank] = {};
  135. // find outermost compact dim.
  136. rank_t * ocd = order;
  137. dim_t ss = a.len(*ocd);
  138. for (--rank, ++ocd; rank>0 && a.keep_step(ss, order[0], *ocd); --rank, ++ocd) {
  139. ss *= a.len(*ocd);
  140. }
  141. for (int k=0; k<rank; ++k) {
  142. // ss takes care of the raveled dimensions ss.
  143. if (0>=(sha[k]=a.len(ocd[k]))) {
  144. if (0>sha[k]) [[unlikely]] { std::abort(); }
  145. if constexpr (requires {early.def;}) {
  146. return early.def;
  147. } else {
  148. return;
  149. }
  150. }
  151. }
  152. auto ss0 = a.step(order[0]);
  153. for (;;) {
  154. auto place = a.save();
  155. for (dim_t s=ss; --s>=0; a.mov(ss0)) {
  156. if constexpr (requires {early.def;}) {
  157. if (auto stop = *a) {
  158. return stop.value();
  159. }
  160. } else {
  161. *a;
  162. }
  163. }
  164. a.load(place); // FIXME wasted if k=0. Cf test/iota.cc
  165. for (int k=0; ; ++k) {
  166. if (k>=rank) {
  167. if constexpr (requires {early.def;}) {
  168. return early.def;
  169. } else {
  170. return;
  171. }
  172. } else if (++ind[k]<sha[k]) {
  173. a.adv(ocd[k], 1);
  174. break;
  175. } else {
  176. ind[k] = 0;
  177. a.adv(ocd[k], 1-sha[k]);
  178. }
  179. }
  180. }
  181. }
  182. // -------------------------
  183. // ply, compile time order/rank.
  184. // -------------------------
  185. template <auto order, int k, int urank, class A, class S, class Early>
  186. constexpr auto
  187. subply(A & a, dim_t s, S const & ss0, Early & early)
  188. {
  189. if constexpr (k < urank) {
  190. auto place = a.save();
  191. for (; --s>=0; a.mov(ss0)) {
  192. if constexpr (requires {early.def;}) {
  193. if (auto stop = *a) {
  194. return stop;
  195. }
  196. } else {
  197. *a;
  198. }
  199. }
  200. a.load(place); // FIXME wasted if k was 0 at the top
  201. } else {
  202. dim_t size = a.len(order[k]); // TODO precompute above
  203. for (dim_t i=0; i<size; ++i) {
  204. if constexpr (requires {early.def;}) {
  205. if (auto stop = subply<order, k-1, urank>(a, s, ss0, early)) {
  206. return stop;
  207. }
  208. } else {
  209. subply<order, k-1, urank>(a, s, ss0, early);
  210. }
  211. a.adv(order[k], 1);
  212. }
  213. a.adv(order[k], -size);
  214. }
  215. if constexpr (requires {early.def;}) {
  216. return static_cast<decltype(*a)>(std::nullopt);
  217. } else {
  218. return;
  219. }
  220. }
  221. // possible pessimization in ply_fixed(). See bench-dot [ra43]
  222. #ifndef RA_STATIC_UNROLL
  223. #define RA_STATIC_UNROLL 0
  224. #endif
  225. template <IteratorConcept A, class Early = Nop>
  226. constexpr decltype(auto)
  227. ply_fixed(A && a, Early && early = Nop {})
  228. {
  229. constexpr rank_t rank = rank_s<A>();
  230. static_assert(0<=rank, "ply_fixed needs static rank");
  231. // inside first. FIXME better heuristic - but first need a way to force row-major
  232. constexpr /* static P2647 gcc13 */ auto order = mp::tuple_values<int, mp::reverse<mp::iota<rank>>>();
  233. if constexpr (0==rank) {
  234. if constexpr (requires {early.def;}) {
  235. return (*a).value_or(early.def);
  236. } else {
  237. *a;
  238. return;
  239. }
  240. } else {
  241. auto ss0 = a.step(order[0]);
  242. // static keep_step implies all else is static.
  243. if constexpr (RA_STATIC_UNROLL && rank>1 && requires (dim_t st, rank_t z, rank_t j) { A::keep_step(st, z, j); }) {
  244. // find outermost compact dim.
  245. constexpr auto sj = [&order]
  246. {
  247. dim_t ss = A::len_s(order[0]);
  248. int j = 1;
  249. for (; j<rank && A::keep_step(ss, order[0], order[j]); ++j) {
  250. ss *= A::len_s(order[j]);
  251. }
  252. return std::make_tuple(ss, j);
  253. } ();
  254. if constexpr (requires {early.def;}) {
  255. return (subply<order, rank-1, std::get<1>(sj)>(a, std::get<0>(sj), ss0, early)).value_or(early.def);
  256. } else {
  257. subply<order, rank-1, std::get<1>(sj)>(a, std::get<0>(sj), ss0, early);
  258. }
  259. } else {
  260. // not worth unrolling.
  261. if constexpr (requires {early.def;}) {
  262. return (subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early)).value_or(early.def);
  263. } else {
  264. subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early);
  265. }
  266. }
  267. }
  268. }
  269. // ---------------------------
  270. // ply, best for each type
  271. // ---------------------------
  272. template <IteratorConcept A, class Early = Nop>
  273. constexpr decltype(auto)
  274. ply(A && a, Early && early = Nop {})
  275. {
  276. static_assert(!has_len<A>, "len used outside subscript context.");
  277. static_assert(0<=rank_s<A>() || ANY==rank_s<A>());
  278. if constexpr (ANY==size_s<A>()) {
  279. return ply_ravel(RA_FWD(a), RA_FWD(early));
  280. } else {
  281. return ply_fixed(RA_FWD(a), RA_FWD(early));
  282. }
  283. }
  // Apply `op` across the arguments `a ...`: the map() expression built from them is
  // traversed with ply() and nothing is returned.
  template <class Op, class ... A>
  constexpr void
  for_each(Op && op, A && ... a)
  {
      ply(map(RA_FWD(op),
              RA_FWD(a) ...));
  }
  290. // ---------------------------
  291. // ply, short-circuiting
  292. // ---------------------------
  // Carrier for the value to return when a short-circuiting traversal finds nothing.
  // Its `def` member is what switches the ply functions into early-exit mode.
  template <class T>
  struct Default
  {
      T def;
  };

  template <class T>
  Default(T &&) -> Default<T>;
  295. template <IteratorConcept A, class Def>
  296. constexpr decltype(auto)
  297. early(A && a, Def && def)
  298. {
  299. return ply(RA_FWD(a), Default { RA_FWD(def) });
  300. }
  301. // --------------------
  302. // STLIterator for CellSmall / CellBig. FIXME make it work for any IteratorConcept.
  303. // --------------------
  304. template <class Iterator>
  305. struct STLIterator
  306. {
  307. using difference_type = dim_t;
  308. using value_type = typename Iterator::value_type;
  309. using shape_type = decltype(ra::shape(std::declval<Iterator>())); // std::array or std::vector
  310. Iterator ii;
  311. shape_type ind;
  312. STLIterator(Iterator const & ii_)
  313. : ii(ii_),
  314. ind([&] {
  315. if constexpr (ANY==rank_s<Iterator>()) {
  316. return shape_type(rank(ii), 0);
  317. } else {
  318. return shape_type {0};
  319. }
  320. }())
  321. {
  322. // [ra12] mark empty range. FIXME make 0==size() more efficient.
  323. if (0==ra::size(ii)) {
  324. ii.c.cp = nullptr;
  325. }
  326. }
  327. constexpr STLIterator &
  328. operator=(STLIterator const & it)
  329. {
  330. ind = it.ind;
  331. ii.Iterator::~Iterator(); // no-op except for View<ANY>. Still...
  332. new (&ii) Iterator(it.ii); // avoid ii = it.ii [ra11]
  333. return *this;
  334. }
  335. bool operator==(std::default_sentinel_t end) const { return !(ii.c.cp); }
  336. decltype(auto) operator*() const { return *ii; }
  337. constexpr void
  338. cube_next(rank_t k)
  339. {
  340. for (; k>=0; --k) {
  341. if (++ind[k]<ii.len(k)) {
  342. ii.adv(k, 1);
  343. return;
  344. } else {
  345. ind[k] = 0;
  346. ii.adv(k, 1-ii.len(k));
  347. }
  348. }
  349. ii.c.cp = nullptr;
  350. }
  351. template <int k>
  352. constexpr void
  353. cube_next()
  354. {
  355. if constexpr (k>=0) {
  356. if (++ind[k]<ii.len(k)) {
  357. ii.adv(k, 1);
  358. } else {
  359. ind[k] = 0;
  360. ii.adv(k, 1-ii.len(k));
  361. cube_next<k-1>();
  362. }
  363. return;
  364. }
  365. ii.c.cp = nullptr;
  366. }
  367. STLIterator & operator++()
  368. {
  369. if constexpr (ANY==rank_s<Iterator>()) {
  370. cube_next(rank(ii)-1);
  371. } else {
  372. cube_next<rank_s<Iterator>()-1>();
  373. }
  374. return *this;
  375. }
  376. // required by std::input_or_output_iterator
  377. STLIterator & operator++(int) { auto old = *this; ++(*this); return old; }
  378. };
  379. // ---------------------------
  380. // i/o
  381. // ---------------------------
  382. // TODO once ply_ravel lets one specify row-major, reuse that.
  383. template <class A>
  384. inline std::ostream &
  385. operator<<(std::ostream & o, FormatArray<A> const & fa)
  386. {
  387. static_assert(!has_len<A>, "len used outside subscript context.");
  388. static_assert(BAD!=size_s<A>(), "Cannot print undefined size expr.");
  389. auto a = ra::start(fa.a); // [ra35]
  390. rank_t const rank = ra::rank(a);
  391. auto sha = shape(a);
  392. if (withshape==fa.shape || (defaultshape==fa.shape && size_s(a)==ANY)) {
  393. o << sha << '\n';
  394. }
  395. for (rank_t k=0; k<rank; ++k) {
  396. if (0==sha[k]) {
  397. return o;
  398. }
  399. }
  400. auto ind = sha; for_each([](auto & s) { s=0; }, ind);
  401. for (;;) {
  402. o << *a;
  403. for (int k=0; ; ++k) {
  404. if (k>=rank) {
  405. return o;
  406. } else if (++ind[rank-1-k]<sha[rank-1-k]) {
  407. a.adv(rank-1-k, 1);
  408. switch (k) {
  409. case 0: o << fa.sep0; break;
  410. case 1: o << fa.sep1; break;
  411. default: std::fill_n(std::ostream_iterator<char const *>(o, ""), k, fa.sep2);
  412. }
  413. break;
  414. } else {
  415. ind[rank-1-k] = 0;
  416. a.adv(rank-1-k, 1-sha[rank-1-k]);
  417. }
  418. }
  419. }
  420. }
  421. // Static size.
  422. template <class C> requires (ANY!=size_s<C>() && !is_scalar<C>)
  423. inline std::istream &
  424. operator>>(std::istream & i, C & c)
  425. {
  426. for (auto & ci: c) { i >> ci; }
  427. return i;
  428. }
  429. // std::vector needs a special constructor.
  430. template <class T, class A>
  431. inline std::istream &
  432. operator>>(std::istream & i, std::vector<T, A> & c)
  433. {
  434. if (dim_t n; i >> n) {
  435. RA_CHECK(n>=0, "Negative length in input [", n, "].");
  436. std::vector<T, A> cc(n);
  437. swap(c, cc);
  438. for (auto & ci: c) { i >> ci; }
  439. }
  440. return i;
  441. }
  442. // Read shape and possibly allocate.
  443. template <class C> requires (ANY==size_s<C>() && !std::is_convertible_v<C, std::string_view>)
  444. inline std::istream &
  445. operator>>(std::istream & i, C & c)
  446. {
  447. if (decltype(shape(c)) s; i >> s) {
  448. RA_CHECK(every(start(s)>=0), "Negative length in input [", noshape, s, "].");
  449. C cc(s, ra::none);
  450. swap(c, cc);
  451. for (auto & ci: c) { i >> ci; }
  452. }
  453. return i;
  454. }
  455. } // namespace ra