ply.hh 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Expression traversal.
  3. // (c) Daniel Llorens - 2013-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. // TODO Make traversal order a parameter, some operations (e.g. output, ravel) require specific orders.
  9. // TODO Better heuristic for traversal order.
  10. // TODO Tiling, etc. (see eval.cc in Blitz++).
  11. // TODO Unit step case?
  12. // TODO std::execution::xxx-policy
  13. // TODO Validate output argument strides.
  14. #pragma once
  15. #include "expr.hh"
  16. namespace ra {
  17. template <class A>
  18. constexpr decltype(auto)
  19. VALUE(A && a)
  20. {
  21. if constexpr (is_scalar<A>) {
  22. return RA_FWD(a); // avoid dangling temp in this case [ra8] (?? maybe unnecessary)
  23. } else if constexpr (is_iterator<A>) {
  24. return *a; // no need to start() for one
  25. } else {
  26. return *(ra::start(RA_FWD(a)));
  27. }
  28. }
  29. // FIXME do we really want to drop const? See use in concrete_type.
  30. template <class A> using value_t = std::decay_t<decltype(VALUE(std::declval<A>()))>;
  31. // ---------------------
  32. // replace Len in expr tree.
  33. // ---------------------
  34. template <>
  35. constexpr bool has_len_def<Len> = true;
  36. template <IteratorConcept ... P>
  37. constexpr bool has_len_def<Pick<std::tuple<P ...>>> = (has_len<P> || ...);
  38. template <class Op, IteratorConcept ... P>
  39. constexpr bool has_len_def<Expr<Op, std::tuple<P ...>>> = (has_len<P> || ...);
  40. template <int w, class N, class O, class S>
  41. constexpr bool has_len_def<Iota<w, N, O, S>> = (has_len<N> || has_len<O> || has_len<S>);
  42. template <class I, class N>
  43. constexpr bool has_len_def<Ptr<I, N>> = has_len<N>;
  44. template <class E_>
  45. struct WithLen
  46. {
  47. // constant/scalar appear in Iota args. dots_t and insert_t appear in subscripts. FIXME restrict to known cases
  48. constexpr static decltype(auto)
  49. f(auto ln, auto && e)
  50. {
  51. return RA_FWD(e);
  52. }
  53. };
  54. template <>
  55. struct WithLen<Len>
  56. {
  57. template <class Ln, class E>
  58. constexpr static decltype(auto)
  59. f(Ln ln, E && e)
  60. {
  61. return Scalar<Ln>(ln);
  62. }
  63. };
  64. template <class Op, IteratorConcept ... P, int ... I> requires (has_len<P> || ...)
  65. struct WithLen<Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>>
  66. {
  67. constexpr static decltype(auto)
  68. f(auto ln, auto && e)
  69. {
  70. return expr(RA_FWD(e).op, WithLen<std::decay_t<P>>::f(ln, std::get<I>(RA_FWD(e).t)) ...);
  71. }
  72. };
  73. template <IteratorConcept ... P, int ... I> requires (has_len<P> || ...)
  74. struct WithLen<Pick<std::tuple<P ...>, mp::int_list<I ...>>>
  75. {
  76. constexpr static decltype(auto)
  77. f(auto ln, auto && e)
  78. {
  79. return pick(WithLen<std::decay_t<P>>::f(ln, std::get<I>(RA_FWD(e).t)) ...);
  80. }
  81. };
  82. template <int w, class N, class O, class S> requires (has_len<N> || has_len<O> || has_len<S>)
  83. struct WithLen<Iota<w, N, O, S>>
  84. {
  85. constexpr static decltype(auto)
  86. f(auto ln, auto && e)
  87. {
  88. // final iota types must be either is_constant or is_scalar.
  89. return iota<w>(VALUE(WithLen<std::decay_t<N>>::f(ln, RA_FWD(e).n)),
  90. VALUE(WithLen<std::decay_t<O>>::f(ln, RA_FWD(e).i)),
  91. VALUE(WithLen<std::decay_t<S>>::f(ln, RA_FWD(e).s)));
  92. }
  93. };
  94. template <class I, class N> requires (has_len<N>)
  95. struct WithLen<Ptr<I, N>>
  96. {
  97. constexpr static decltype(auto)
  98. f(auto ln, auto && e)
  99. {
  100. return ptr(RA_FWD(e).i, VALUE(WithLen<std::decay_t<N>>::f(ln, RA_FWD(e).n)));
  101. }
  102. };
  103. template <class Ln, class E>
  104. constexpr decltype(auto)
  105. with_len(Ln ln, E && e)
  106. {
  107. static_assert(std::is_integral_v<std::decay_t<Ln>> || is_constant<std::decay_t<Ln>>);
  108. return WithLen<std::decay_t<E>>::f(ln, RA_FWD(e));
  109. }
  110. // --------------
  111. // ply, run time order/rank.
  112. // --------------
  113. struct Nop {};
// Traverse the iterator a (rank and lengths read at run time), evaluating *a once per element.
// step() must give 0 for k>=their own rank, to allow frame matching.
// Returns:
//  - without early (no `def` member, e.g. Nop): void, after visiting every element.
//  - with early (has `def`, e.g. Default): *a is used as an optional-like value; the first
//    engaged one ends the traversal and its value() is returned, otherwise early.def.
// Negative rank or negative length calls std::abort().
template <IteratorConcept A, class Early = Nop>
constexpr auto
ply_ravel(A && a, Early && early = Nop {})
{
    rank_t rank = ra::rank(a);
    // must avoid 0-length vlas [ra40].
    if (0>=rank) {
        if (0>rank) [[unlikely]] { std::abort(); }
        // rank 0: exactly one element.
        if constexpr (requires {early.def;}) {
            return (*a).value_or(early.def);
        } else {
            *a;
            return;
        }
    }
    // inside first. FIXME better heuristic - but first need a way to force row-major
    // order[0] is the last axis, i.e. the innermost loop.
    rank_t order[rank];
    for (rank_t i=0; i<rank; ++i) {
        order[i] = rank-1-i;
    }
    dim_t sha[rank], ind[rank] = {};
    // find outermost compact dim.
    // Starting from the innermost axis, merge following axes into one flat run of ss elements for
    // as long as a.keep_step() allows it. Afterwards ocd points at the first axis that was not
    // merged, and rank counts the axes left for the outer odometer.
    rank_t * ocd = order;
    dim_t ss = a.len(*ocd);
    for (--rank, ++ocd; rank>0 && a.keep_step(ss, order[0], *ocd); --rank, ++ocd) {
        ss *= a.len(*ocd);
    }
    // sha[k] is the length of outer axis ocd[k]; the raveled axes are already accounted for in ss.
    // NOTE(review): only the outer axes are checked for 0 here; a 0 among the raveled axes gives
    // ss==0, so no element is visited but the outer odometer still runs.
    for (int k=0; k<rank; ++k) {
        if (0>=(sha[k]=a.len(ocd[k]))) {
            if (0>sha[k]) [[unlikely]] { std::abort(); }
            // empty array: nothing to visit.
            if constexpr (requires {early.def;}) {
                return early.def;
            } else {
                return;
            }
        }
    }
    auto ss0 = a.step(order[0]);
    for (;;) {
        // flat inner run: ss elements with step ss0, then go back to where the run started.
        auto place = a.save();
        for (dim_t s=ss; --s>=0; a.mov(ss0)) {
            if constexpr (requires {early.def;}) {
                if (auto stop = *a) {
                    return stop.value();
                }
            } else {
                *a;
            }
        }
        a.load(place); // FIXME wasted if k=0. Cf test/iota.cc
        // odometer over the outer axes: bump ind[k], carrying into k+1 when axis ocd[k] wraps.
        for (int k=0; ; ++k) {
            if (k>=rank) {
                // every outer index wrapped: traversal complete.
                if constexpr (requires {early.def;}) {
                    return early.def;
                } else {
                    return;
                }
            } else if (++ind[k]<sha[k]) {
                a.adv(ocd[k], 1);
                break;
            } else {
                ind[k] = 0;
                a.adv(ocd[k], 1-sha[k]);
            }
        }
    }
}
  183. // -------------------------
  184. // ply, compile time order/rank.
  185. // -------------------------
// Compile-time loop nest used by ply_fixed.
// Level k (from rank-1 down to urank) loops over axis order[k] and moves a back to its starting
// position afterwards. At the first level below urank, the remaining axes are traversed together
// as one flat run of s elements with step ss0. The caller decides which axes that run covers.
// Returns:
//  - without early (no `def` member): void.
//  - with early: the first *a that tests true, or decltype(*a) built from std::nullopt if none
//    did. Presumably *a is a std::optional in that case; the caller applies value_or().
template <auto order, int k, int urank, class A, class S, class Early>
constexpr auto
subply(A & a, dim_t s, S const & ss0, Early & early)
{
    if constexpr (k < urank) {
        // innermost: flat run of s elements, then return to the start of the run.
        auto place = a.save();
        for (; --s>=0; a.mov(ss0)) {
            if constexpr (requires {early.def;}) {
                if (auto stop = *a) {
                    return stop;
                }
            } else {
                *a;
            }
        }
        a.load(place); // FIXME wasted if k was 0 at the top
    } else {
        // outer level: recurse once per position along axis order[k], then rewind that axis.
        dim_t size = a.len(order[k]); // TODO precompute above
        for (dim_t i=0; i<size; ++i) {
            if constexpr (requires {early.def;}) {
                if (auto stop = subply<order, k-1, urank>(a, s, ss0, early)) {
                    return stop;
                }
            } else {
                subply<order, k-1, urank>(a, s, ss0, early);
            }
            a.adv(order[k], 1);
        }
        a.adv(order[k], -size);
    }
    if constexpr (requires {early.def;}) {
        return static_cast<decltype(*a)>(std::nullopt);
    } else {
        return;
    }
}
  222. // possible pessimization in ply_fixed(). See bench-dot [ra43]
  223. #ifndef RA_STATIC_UNROLL
  224. #define RA_STATIC_UNROLL 0
  225. #endif
  226. template <IteratorConcept A, class Early = Nop>
  227. constexpr decltype(auto)
  228. ply_fixed(A && a, Early && early = Nop {})
  229. {
  230. constexpr rank_t rank = rank_s<A>();
  231. static_assert(0<=rank, "ply_fixed needs static rank");
  232. // inside first. FIXME better heuristic - but first need a way to force row-major
  233. constexpr /* static P2647 gcc13 */ auto order = mp::tuple_values<int, mp::reverse<mp::iota<rank>>>();
  234. if constexpr (0==rank) {
  235. if constexpr (requires {early.def;}) {
  236. return (*a).value_or(early.def);
  237. } else {
  238. *a;
  239. return;
  240. }
  241. } else {
  242. auto ss0 = a.step(order[0]);
  243. // static keep_step implies all else is static.
  244. if constexpr (RA_STATIC_UNROLL && rank>1 && requires (dim_t st, rank_t z, rank_t j) { A::keep_step(st, z, j); }) {
  245. // find outermost compact dim.
  246. constexpr auto sj = [&order]
  247. {
  248. dim_t ss = A::len_s(order[0]);
  249. int j = 1;
  250. for (; j<rank && A::keep_step(ss, order[0], order[j]); ++j) {
  251. ss *= A::len_s(order[j]);
  252. }
  253. return std::make_tuple(ss, j);
  254. } ();
  255. if constexpr (requires {early.def;}) {
  256. return (subply<order, rank-1, std::get<1>(sj)>(a, std::get<0>(sj), ss0, early)).value_or(early.def);
  257. } else {
  258. subply<order, rank-1, std::get<1>(sj)>(a, std::get<0>(sj), ss0, early);
  259. }
  260. } else {
  261. // not worth unrolling.
  262. if constexpr (requires {early.def;}) {
  263. return (subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early)).value_or(early.def);
  264. } else {
  265. subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early);
  266. }
  267. }
  268. }
  269. }
  270. // ---------------------------
  271. // ply, best for each type
  272. // ---------------------------
  273. template <IteratorConcept A, class Early = Nop>
  274. constexpr decltype(auto)
  275. ply(A && a, Early && early = Nop {})
  276. {
  277. static_assert(!has_len<A>, "len used outside subscript context.");
  278. static_assert(0<=rank_s<A>() || ANY==rank_s<A>());
  279. if constexpr (ANY==size_s<A>()) {
  280. return ply_ravel(RA_FWD(a), RA_FWD(early));
  281. } else {
  282. return ply_fixed(RA_FWD(a), RA_FWD(early));
  283. }
  284. }
  285. constexpr void
  286. for_each(auto && op, auto && ... a) { ply(map(RA_FWD(op), RA_FWD(a) ...)); }
  287. // ---------------------------
  288. // ply, short-circuiting
  289. // ---------------------------
  290. template <class T> struct Default { T def; };
  291. template <class T> Default(T &&) -> Default<T>;
  292. constexpr decltype(auto)
  293. early(IteratorConcept auto && a, auto && def) { return ply(RA_FWD(a), Default { RA_FWD(def) }); }
  294. // --------------------
  295. // STLIterator for CellSmall / CellBig. FIXME make it work for any IteratorConcept.
  296. // --------------------
  297. template <class Iterator>
  298. struct STLIterator
  299. {
  300. using difference_type = dim_t;
  301. using value_type = typename Iterator::value_type;
  302. using shape_type = decltype(ra::shape(std::declval<Iterator>())); // std::array or std::vector
  303. Iterator ii;
  304. shape_type ind;
  305. STLIterator(Iterator const & ii_)
  306. : ii(ii_),
  307. ind([&] {
  308. if constexpr (ANY==rank_s<Iterator>()) {
  309. return shape_type(rank(ii), 0);
  310. } else {
  311. return shape_type {0};
  312. }
  313. }())
  314. {
  315. // [ra12] mark empty range. FIXME make 0==size() more efficient.
  316. if (0==ra::size(ii)) {
  317. ii.c.cp = nullptr;
  318. }
  319. }
  320. constexpr STLIterator &
  321. operator=(STLIterator const & it)
  322. {
  323. ind = it.ind;
  324. ii.Iterator::~Iterator(); // no-op except for View<ANY>. Still...
  325. new (&ii) Iterator(it.ii); // avoid ii = it.ii [ra11]
  326. return *this;
  327. }
  328. bool operator==(std::default_sentinel_t end) const { return !(ii.c.cp); }
  329. decltype(auto) operator*() const { return *ii; }
  330. constexpr void
  331. cube_next(rank_t k)
  332. {
  333. for (; k>=0; --k) {
  334. if (++ind[k]<ii.len(k)) {
  335. ii.adv(k, 1);
  336. return;
  337. } else {
  338. ind[k] = 0;
  339. ii.adv(k, 1-ii.len(k));
  340. }
  341. }
  342. ii.c.cp = nullptr;
  343. }
  344. template <int k>
  345. constexpr void
  346. cube_next()
  347. {
  348. if constexpr (k>=0) {
  349. if (++ind[k]<ii.len(k)) {
  350. ii.adv(k, 1);
  351. } else {
  352. ind[k] = 0;
  353. ii.adv(k, 1-ii.len(k));
  354. cube_next<k-1>();
  355. }
  356. return;
  357. }
  358. ii.c.cp = nullptr;
  359. }
  360. STLIterator & operator++()
  361. {
  362. if constexpr (ANY==rank_s<Iterator>()) {
  363. cube_next(rank(ii)-1);
  364. } else {
  365. cube_next<rank_s<Iterator>()-1>();
  366. }
  367. return *this;
  368. }
  369. // required by std::input_or_output_iterator
  370. STLIterator & operator++(int) { auto old = *this; ++(*this); return old; }
  371. };
  372. // ---------------------------
  373. // i/o
  374. // ---------------------------
// TODO once ply_ravel lets one specify row-major, reuse that.
// Print the array expression fa.a in row-major order (last axis fastest).
// The shape is printed first, on its own line, when fa.shape is withshape, or when it is
// defaultshape and the size is not known at compile time (size_s(a)==ANY).
// Between consecutive elements this prints:
//  - fa.sep0 when only the last index advanced,
//  - fa.sep1 when the second-to-last index advanced,
//  - k copies of fa.sep2 when the index k axes from the end advanced (k>=2).
// If any length is 0, nothing after the optional shape is printed.
template <class A>
inline std::ostream &
operator<<(std::ostream & o, FormatArray<A> const & fa)
{
    static_assert(!has_len<A>, "len used outside subscript context.");
    static_assert(BAD!=size_s<A>(), "Cannot print undefined size expr.");
    auto a = ra::start(fa.a); // [ra35]
    rank_t const rank = ra::rank(a);
    auto sha = shape(a);
    if (withshape==fa.shape || (defaultshape==fa.shape && size_s(a)==ANY)) {
        o << sha << '\n';
    }
    for (rank_t k=0; k<rank; ++k) {
        if (0==sha[k]) {
            return o;
        }
    }
    // multi-index, starting at all zeros.
    auto ind = sha; for_each([](auto & s) { s=0; }, ind);
    for (;;) {
        o << *a;
        // odometer: k counts axes from the end. Advance axis rank-1-k; if it wraps, rewind it and
        // carry into the next axis out.
        for (int k=0; ; ++k) {
            if (k>=rank) {
                return o; // every index wrapped: all elements printed.
            } else if (++ind[rank-1-k]<sha[rank-1-k]) {
                a.adv(rank-1-k, 1);
                switch (k) {
                case 0: o << fa.sep0; break;
                case 1: o << fa.sep1; break;
                default: std::fill_n(std::ostream_iterator<char const *>(o, ""), k, fa.sep2);
                }
                break;
            } else {
                ind[rank-1-k] = 0;
                a.adv(rank-1-k, 1-sha[rank-1-k]);
            }
        }
    }
}
  414. // Possibly read shape, possibly allocate.
  415. template <class C> requires (ANY!=size_s<C>() && !is_scalar<C>)
  416. inline std::istream &
  417. operator>>(std::istream & i, C & c)
  418. {
  419. for (auto & ci: c) { i >> ci; }
  420. return i;
  421. }
  422. template <class T, class A>
  423. inline std::istream &
  424. operator>>(std::istream & i, std::vector<T, A> & c)
  425. {
  426. if (dim_t n; i >> n) {
  427. RA_CHECK(n>=0, "Negative length in input [", n, "].");
  428. std::vector<T, A> cc(n);
  429. swap(c, cc);
  430. for (auto & ci: c) { i >> ci; }
  431. }
  432. return i;
  433. }
  434. template <class C> requires (ANY==size_s<C>() && !std::is_convertible_v<C, std::string_view>)
  435. inline std::istream &
  436. operator>>(std::istream & i, C & c)
  437. {
  438. if (decltype(shape(c)) s; i >> s) {
  439. RA_CHECK(every(start(s)>=0), "Negative length in input [", noshape, s, "].");
  440. C cc(s, ra::none);
  441. swap(c, cc);
  442. for (auto & ci: c) { i >> ci; }
  443. }
  444. return i;
  445. }
  446. } // namespace ra