atom.hh 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file atom.hh
  3. /// @brief Terminal nodes for expression templates, and use-as-xpr wrapper.
  4. // (c) Daniel Llorens - 2011-2016, 2019
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #pragma once
  10. #include "ra/type.hh"
  11. namespace ra {
  12. template <class V> inline constexpr auto size(V const & v);
  13. template <class V> inline constexpr decltype(auto) shape(V const & v);
  14. template <class C> struct Scalar;
  15. // Separate from Scalar so that operator+=, etc. has the array meaning there.
  16. template <class C>
  17. struct ScalarFlat: public Scalar<C>
  18. {
  19. constexpr void operator+=(dim_t d) const {}
  20. constexpr C & operator*() { return this->c; }
  21. constexpr C const & operator*() const { return this->c; } // [ra39]
  22. };
  23. // Wrap constant for traversal. We still want f(C) to be a specialization in most cases.
  24. template <class C_>
  25. struct Scalar
  26. {
  27. using C = C_;
  28. C c;
  29. constexpr static rank_t rank_s() { return 0; }
  30. constexpr static rank_t rank() { return 0; }
  31. constexpr static dim_t size_s(int k) { return DIM_BAD; }
  32. constexpr static dim_t size(int k) { return DIM_BAD; } // used in shape checks with dyn rank.
  33. template <class I> constexpr C & at(I const & i) { return c; }
  34. constexpr static void adv(rank_t k, dim_t d) {}
  35. constexpr static dim_t stride(int k) { return 0; }
  36. constexpr static bool keep_stride(dim_t st, int z, int j) { return true; }
  37. constexpr decltype(auto) flat() { return static_cast<ScalarFlat<C> &>(*this); }
  38. constexpr decltype(auto) flat() const { return static_cast<ScalarFlat<C> const &>(*this); } // [ra39]
  39. FOR_EACH(RA_DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  40. };
  41. template <class C> inline constexpr auto scalar(C && c) { return Scalar<C> { std::forward<C>(c) }; }
  42. // Wrap foreign vectors.
  43. // FIXME This can handle temporaries and make_a().begin() can't, look out for that.
  44. // FIXME Do we need this class? holding rvalue is the only thing it does over View, and it doesn't handle rank!=1.
  45. template <class V>
  46. requires (requires { ra_traits<V>::size_s; } &&
  47. requires (V v) { { v.begin() } -> std::random_access_iterator; })
  48. struct Vector
  49. {
  50. V v;
  51. decltype(v.begin()) p__;
  52. constexpr dim_t size(int k) const { RA_CHECK(k==0, " k ", k); return ra_traits<V>::size(v); }
  53. constexpr static dim_t size_s(int k) { RA_CHECK(k==0, " k ", k); return ra_traits<V>::size_s(); }
  54. constexpr static rank_t rank() { return 1; }
  55. constexpr static rank_t rank_s() { return 1; };
  56. // see test/ra-9.cc [ra1] for forward() here.
  57. constexpr Vector(V && v_): v(std::forward<V>(v_)), p__(v.begin()) {}
  58. // see [ra35] in test/ra-9.cc. FIXME How about I just hold a ref for any kind of V, like container -> iter.
  59. constexpr Vector(Vector<std::remove_reference_t<V>> const & a): v(std::move(a.v)), p__(v.begin()) { static_assert(!std::is_reference_v<V>); };
  60. constexpr Vector(Vector<std::remove_reference_t<V>> && a): v(std::move(a.v)), p__(v.begin()) { static_assert(!std::is_reference_v<V>); };
  61. template <class I>
  62. decltype(auto) at(I const & i)
  63. {
  64. RA_CHECK(inside(i[0], ra::size(v)), " i ", i[0], " size ", ra::size(v));
  65. return p__[i[0]];
  66. }
  67. constexpr void adv(rank_t k, dim_t d)
  68. {
  69. // k>0 happens on frame-matching when the axes k>0 can't be unrolled [ra03]
  70. // k==0 && d!=1 happens on turning back at end of ply.
  71. // we need this only on outer products and such, or in FIXME operator<<; which could be fixed I think.
  72. RA_CHECK(d==1 || d<=0, " k ", k, " d ", d, " (Vector)");
  73. p__ += (k==0) * d;
  74. }
  75. constexpr static dim_t stride(int k) { return k==0 ? 1 : 0; }
  76. constexpr static bool keep_stride(dim_t st, int z, int j) { return (z==0) == (j==0); }
  77. constexpr auto flat() const { return p__; }
  78. FOR_EACH(RA_DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  79. };
  80. template <class V> inline constexpr auto vector(V && v) { return Vector<V>(std::forward<V>(v)); }
  81. template <std::random_access_iterator P>
  82. struct Ptr
  83. {
  84. P p__;
  85. constexpr static dim_t size(int k) { RA_CHECK(k==0, " k ", k); return DIM_BAD; }
  86. constexpr static dim_t size_s(int k) { RA_CHECK(k==0, " k ", k); return DIM_BAD; }
  87. constexpr static rank_t rank() { return 1; }
  88. constexpr static rank_t rank_s() { return 1; };
  89. template <class I>
  90. constexpr decltype(auto) at(I && i)
  91. {
  92. return p__[i[0]];
  93. }
  94. constexpr void adv(rank_t k, dim_t d)
  95. {
  96. RA_CHECK(d==1 || d<=0, " k ", k, " d ", d, " (Ptr)");
  97. std::advance(p__, (k==0) * d);
  98. }
  99. constexpr static dim_t stride(int k) { return k==0 ? 1 : 0; }
  100. constexpr static bool keep_stride(dim_t st, int z, int j) { return (z==0) == (j==0); }
  101. constexpr auto flat() const { return p__; }
  102. FOR_EACH(RA_DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  103. };
  104. template <class I> inline auto ptr(I i) { return Ptr<I> { i }; }
  105. // Same as Ptr, just with a size. For stuff like initializer_list that has size but no storage.
  106. template <std::random_access_iterator P>
  107. struct Span
  108. {
  109. P p__;
  110. dim_t n__;
  111. constexpr dim_t size(int k) const { RA_CHECK(k==0, " k ", k); return n__; }
  112. constexpr static dim_t size_s(int k) { RA_CHECK(k==0, " k ", k); return DIM_ANY; }
  113. constexpr static rank_t rank() { return 1; }
  114. constexpr static rank_t rank_s() { return 1; };
  115. template <class I>
  116. decltype(auto) at(I const & i)
  117. {
  118. RA_CHECK(inside(i[0], n__), " i ", i[0], " size ", n__);
  119. return p__[i[0]];
  120. }
  121. constexpr void adv(rank_t k, dim_t d)
  122. {
  123. RA_CHECK(d==1 || d<=0, " k ", k, " d ", d, " (Span)");
  124. std::advance(p__, (k==0) * d);
  125. }
  126. constexpr static dim_t stride(int k) { return k==0 ? 1 : 0; }
  127. constexpr static bool keep_stride(dim_t st, int z, int j) { return (z==0) == (j==0); }
  128. constexpr auto flat() const { return p__; }
  129. FOR_EACH(RA_DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  130. };
  131. template <class I> inline auto ptr(I i, dim_t n) { return Span<I> { i, n }; }
  132. template <int w_, class value_type=ra::dim_t>
  133. struct TensorIndexFlat
  134. {
  135. dim_t i;
  136. constexpr void operator+=(dim_t const s) { i += s; }
  137. constexpr value_type operator*() { return i; }
  138. };
  139. template <int w, class value_type=ra::dim_t>
  140. struct TensorIndex
  141. {
  142. dim_t i = 0;
  143. static_assert(w>=0, "bad TensorIndex");
  144. constexpr static rank_t rank_s() { return w+1; }
  145. constexpr static rank_t rank() { return w+1; }
  146. constexpr static dim_t size_s(int k) { return DIM_BAD; }
  147. constexpr static dim_t size(int k) { return DIM_BAD; } // used in shape checks with dyn rank.
  148. template <class I> constexpr value_type at(I const & ii) const { return value_type(ii[w]); }
  149. constexpr void adv(rank_t k, dim_t d) { RA_CHECK(d<=1, " d ", d); i += (k==w) * d; }
  150. constexpr static dim_t const stride(int k) { return (k==w); }
  151. constexpr static bool keep_stride(dim_t st, int z, int j) { return st*stride(z)==stride(j); }
  152. constexpr auto flat() const { return TensorIndexFlat<w, value_type> {i}; }
  153. };
  154. #define DEF_TENSORINDEX(i) TensorIndex<i> const JOIN(_, i) {};
  155. FOR_EACH(DEF_TENSORINDEX, 0, 1, 2, 3, 4);
  156. #undef DEF_TENSORINDEX
  157. template <class T>
  158. struct IotaFlat
  159. {
  160. T i_;
  161. T const stride_;
  162. T const & operator*() const { return i_; } // TODO if not for this, I could use plain T. Maybe ra::eval_expr...
  163. void operator+=(dim_t d) { i_ += T(d)*stride_; }
  164. };
  165. template <class T_>
  166. struct Iota
  167. {
  168. using T = T_;
  169. dim_t const size_;
  170. T i_;
  171. T const stride_;
  172. constexpr Iota(dim_t size, T org=0, T stride=1): size_(size), i_(org), stride_(stride)
  173. {
  174. RA_CHECK(size>=0, "Iota size ", size);
  175. }
  176. constexpr dim_t size(int k) const { RA_CHECK(k==0, " k ", k); return size_; }
  177. constexpr static dim_t size_s(int k) { RA_CHECK(k==0, " k ", k); return DIM_ANY; }
  178. constexpr rank_t rank() const { return 1; }
  179. constexpr static rank_t rank_s() { return 1; };
  180. template <class I>
  181. constexpr decltype(auto) at(I const & i)
  182. {
  183. return i_ + T(i[0])*stride_;
  184. }
  185. constexpr void adv(rank_t k, dim_t d)
  186. {
  187. i_ += T((k==0) * d) * stride_; // cf Vector::adv
  188. }
  189. constexpr static dim_t stride(rank_t i) { return i==0 ? 1 : 0; }
  190. constexpr static bool keep_stride(dim_t st, int z, int j) { return (z==0) == (j==0); }
  191. constexpr auto flat() const { return IotaFlat<T> { i_, stride_ }; }
  192. decltype(auto) operator+=(T const & b) { i_ += b; return *this; };
  193. decltype(auto) operator-=(T const & b) { i_ -= b; return *this; };
  194. };
  195. template <class O=dim_t, class S=O> inline constexpr auto
  196. iota(dim_t size, O org=0, S stride=1)
  197. {
  198. using T = std::common_type_t<O, S>;
  199. return Iota<T> { size, T(org), T(stride) };
  200. }
  201. template <class I> struct is_beatable_def
  202. {
  203. constexpr static bool value = std::is_integral_v<I>;
  204. constexpr static int skip_src = 1;
  205. constexpr static int skip = 0;
  206. constexpr static bool static_p = value; // can the beating be resolved statically?
  207. };
  208. template <class II> struct is_beatable_def<Iota<II>>
  209. {
  210. constexpr static bool value = std::numeric_limits<II>::is_integer;
  211. constexpr static int skip_src = 1;
  212. constexpr static int skip = 1;
  213. constexpr static bool static_p = false; // it cannot for Iota
  214. };
  215. // FIXME have a 'filler' version (e.g. with default n = -1) or maybe a distinct type.
  216. template <int n> struct is_beatable_def<dots_t<n>>
  217. {
  218. static_assert(n>=0, "bad count for dots_n");
  219. constexpr static bool value = (n>=0);
  220. constexpr static int skip_src = n;
  221. constexpr static int skip = n;
  222. constexpr static bool static_p = true;
  223. };
  224. template <int n> struct is_beatable_def<insert_t<n>>
  225. {
  226. static_assert(n>=0, "bad count for dots_n");
  227. constexpr static bool value = (n>=0);
  228. constexpr static int skip_src = 0;
  229. constexpr static int skip = n;
  230. constexpr static bool static_p = true;
  231. };
  232. template <class I> using is_beatable = is_beatable_def<std::decay_t<I>>;
  233. template <class Op, class T, class K=mp::iota<mp::len<T>>> struct Expr;
  234. template <class T, class K=mp::iota<mp::len<T>>> struct Pick;
  235. template <class FM, class Op, class T, class K=mp::iota<mp::len<T>>> struct Ryn;
  236. template <class Live, class A> struct Reframe;
  237. // --------------
  238. // Coerce potential RaIterator
  239. // --------------
  240. template <class T>
  241. inline constexpr void start(T && t)
  242. {
  243. static_assert(!std::same_as<T, T>, "Type cannot be start()ed.");
  244. }
  245. RA_IS_DEF(is_iota, (std::same_as<A, Iota<typename A::T>>))
  246. RA_IS_DEF(is_ra_scalar, (std::same_as<A, Scalar<typename A::C>>))
  247. RA_IS_DEF(is_ra_vector, (std::same_as<A, Vector<typename A::V>>))
  248. template <class T> requires (is_foreign_vector<T>)
  249. inline constexpr auto
  250. start(T && t)
  251. {
  252. return ra::vector(std::forward<T>(t));
  253. }
  254. template <class T> requires (is_scalar<T>)
  255. inline constexpr auto
  256. start(T && t)
  257. {
  258. return ra::scalar(std::forward<T>(t));
  259. }
  260. // See [ra35] and Vector constructors above. RaIterators need to be restarted in case on every use (eg ra::cross()).
  261. template <class T> requires (is_iterator<T> && !is_ra_scalar<T> && !is_ra_vector<T>)
  262. inline constexpr auto
  263. start(T && t)
  264. {
  265. return std::forward<T>(t);
  266. }
  267. // Copy the iterator but not the data. This follows the behavior of iter(View); Vector is just an interface adaptor [ra35].
  268. template <class T> requires (is_ra_vector<T>)
  269. inline constexpr auto
  270. start(T && t)
  271. {
  272. return vector(t.v);
  273. }
  274. // For Scalar we forward since the iterator is pure interface.
  275. template <class T> requires (is_ra_scalar<T>)
  276. inline constexpr decltype(auto)
  277. start(T && t)
  278. {
  279. return std::forward<T>(t);
  280. }
  281. // Neither cell_iterator nor cell_iterator_small will retain rvalues [ra4].
  282. template <class T> requires (is_slice<T>)
  283. inline constexpr auto
  284. start(T && t)
  285. {
  286. return iter<0>(std::forward<T>(t));
  287. }
  288. template <class T>
  289. inline constexpr auto
  290. start(std::initializer_list<T> v)
  291. {
  292. return ptr(v.begin(), v.size());
  293. }
  294. // forward declare for match.hh; implemented in small.hh.
  295. template <class T> requires (is_builtin_array<T>)
  296. inline constexpr auto
  297. start(T && t);
  298. // FIXME one of these is ET-generic and the other is slice only, so make up your mind.
  299. // FIXME do we really want to drop const? See use in concrete_type.
  300. template <class A> using value_t = std::decay_t<decltype(*(ra::start(std::declval<A>()).flat()))>;
  301. template <class V> inline constexpr dim_t
  302. rank_s()
  303. {
  304. if constexpr (requires { ra_traits<V>::rank_s(); }) {
  305. return ra_traits<V>::rank_s();
  306. } else if constexpr (requires { std::decay_t<V>::rank_s(); }) {
  307. return std::decay_t<V>::rank_s();
  308. } else {
  309. return 0;
  310. }
  311. }
  312. template <class V> inline constexpr rank_t
  313. rank_s(V const &)
  314. {
  315. return rank_s<V>();
  316. }
  317. template <class V_> inline constexpr dim_t
  318. size_s()
  319. {
  320. using V = std::decay_t<V_>;
  321. if constexpr (requires { ra_traits<V>::size_s(); }) {
  322. return ra_traits<V>::size_s();
  323. } else {
  324. if constexpr (V::rank_s()==RANK_ANY) {
  325. return DIM_ANY;
  326. } else {
  327. ra::dim_t s = 1;
  328. for (int i=0; i!=V::rank_s(); ++i) {
  329. if (dim_t ss=V::size_s(i); ss>=0) {
  330. s *= ss;
  331. } else {
  332. return ss; // either DIM_ANY or DIM_BAD
  333. }
  334. }
  335. return s;
  336. }
  337. }
  338. }
  339. template <class V> constexpr dim_t
  340. size_s(V const &)
  341. {
  342. return size_s<V>();
  343. }
  344. template <class V> inline constexpr rank_t
  345. rank(V const & v)
  346. {
  347. if constexpr (requires { ra_traits<V>::rank(v); }) {
  348. return ra_traits<V>::rank(v);
  349. } else {
  350. return v.rank();
  351. }
  352. }
  353. template <class V> inline constexpr auto
  354. size(V const & v)
  355. {
  356. if constexpr (requires { ra_traits<V>::size(v); }) {
  357. return ra_traits<V>::size(v);
  358. } else {
  359. return prod(map([&v](auto && k) { return v.size(k); }, ra::iota(rank(v))));
  360. }
  361. }
  362. // To be used sparingly; prefer implicit matching.
  363. template <class V> inline constexpr decltype(auto)
  364. shape(V const & v)
  365. {
  366. if constexpr (requires { ra_traits<V>::shape(v); }) {
  367. return ra_traits<V>::shape(v);
  368. // FIXME version for static shape. Would prefer to return the map directly (except maybe for static shapes)
  369. } else if constexpr (constexpr rank_t rs=rank_s<V>(); rs>=0) {
  370. return ra::Small<dim_t, rs>(map([&v](int k) { return v.size(k); }, ra::iota(rs)));
  371. } else {
  372. static_assert(RANK_ANY==rs);
  373. rank_t r = v.rank();
  374. std::vector<dim_t> s(r);
  375. for_each([&v, &s](int k) { s[k] = v.size(k); }, ra::iota(r));
  376. return s;
  377. }
  378. }
  379. // To handle arrays of static/dynamic size.
  380. template <class A> void
  381. resize(A & a, dim_t k)
  382. {
  383. if constexpr (DIM_ANY==size_s<A>()) {
  384. a.resize(k);
  385. } else {
  386. RA_CHECK(k==dim_t(a.size_s(0)));
  387. }
  388. }
  389. } // namespace ra