atom.H 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. // (c) Daniel Llorens - 2011-2016
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file atom.H
  7. /// @brief Terminal nodes for expression templates.
  8. #pragma once
  9. #include "ra/traits.H"
  10. #include "ra/opcheck.H"
  // Bounds checks for this header. RA_CHECK_BOUNDS, when defined, takes precedence over the
  // header-specific switch RA_CHECK_BOUNDS_RA_ATOM; with neither defined, checks are enabled.
  // Both macros are #undef'd again at the end of the header.
  #if defined(RA_CHECK_BOUNDS)
  #define RA_CHECK_BOUNDS_RA_ATOM RA_CHECK_BOUNDS
  #elif !defined(RA_CHECK_BOUNDS_RA_ATOM)
  #define RA_CHECK_BOUNDS_RA_ATOM 1
  #endif
  // CHECK_BOUNDS(cond) asserts cond, or expands to nothing when the switch is 0.
  #if RA_CHECK_BOUNDS_RA_ATOM != 0
  #define CHECK_BOUNDS( cond ) assert( cond )
  #else
  #define CHECK_BOUNDS( cond )
  #endif
  23. namespace ra {
  24. // value_type may be needed to avoid conversion issues.
  25. template <int w_, class value_type=ra::dim_t>
  26. struct TensorIndex
  27. {
  28. static_assert(w_>=0, "bad TensorIndex");
  29. constexpr static int w = w_;
  30. constexpr static dim_t size(int k) { return DIM_BAD; } // used in shape checks with dyn. rank.
  31. constexpr static dim_t size_s() { return DIM_BAD; }
  32. constexpr static rank_t rank_s() { return w+1; }
  33. constexpr static rank_t rank() { return w+1; }
  34. template <class I> constexpr value_type at(I const & i) const { return value_type(i[w]); }
  35. constexpr void adv(rank_t k, dim_t d) {}
  36. constexpr dim_t stride(int i) const { assert(w<0); return 0; } // used by Expr::stride_t.
  37. constexpr value_type * flat() const { assert(w<0); return nullptr; } // used by Expr::atom_type type signature.
  38. };
  39. #define DEF_TENSORINDEX(i) constexpr TensorIndex<i, int> JOIN(_, i) {};
  40. FOR_EACH(DEF_TENSORINDEX, 0, 1, 2, 3, 4);
  41. #undef DEF_TENSORINDEX
  42. template <class C> struct Scalar;
  43. // Separate from Scalar so that operator+=, etc. has the array meaning there.
  44. // We can reuse the Scalar object b/c operator+= is a no-op.
  45. template <class C>
  46. struct ScalarFlat: public Scalar<C>
  47. {
  48. constexpr void operator+=(dim_t d) const {}
  49. constexpr C const & operator*() const { return this->c; }
  50. constexpr C & operator*() { return this->c; }
  51. };
  52. // Wrap constant for traversal. We still want f(C) to be a specialization in most cases.
  53. template <class C_>
  54. struct Scalar
  55. {
  56. using C = C_;
  57. C c;
  58. // Used in shape checks with dynamic rank. (Never called because rank is 0).
  59. constexpr static dim_t size(int k) { assert(0); return 0; }
  60. constexpr static dim_t size_s() { return 1; }
  61. constexpr static rank_t rank() { return 0; }
  62. constexpr static rank_t rank_s() { return 0; }
  63. using shape_type = std::array<dim_t, 0>;
  64. constexpr static shape_type shape() { return shape_type {}; }
  65. // cf ScalarFlat::operator*
  66. template <class I> constexpr C const & at(I const & i) const { return c; }
  67. template <class I> constexpr C & at(I const & i) { return c; }
  68. constexpr static void adv(rank_t k, dim_t d) {}
  69. constexpr static dim_t stride(int i) { return 0; }
  70. constexpr static bool keep_stride(dim_t step, int z, int j) { return true; }
  71. constexpr decltype(auto) flat() const { return static_cast<ScalarFlat<C> const &>(*this); }
  72. constexpr decltype(auto) flat() { return static_cast<ScalarFlat<C> &>(*this); }
  73. #define DEF_ASSIGNOPS(OP) template <class X> void operator OP(X && x) \
  74. { for_each([](auto && y, auto && x) { std::forward<decltype(y)>(y) OP x; }, *this, x); }
  75. FOR_EACH(DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  76. #undef DEF_ASSIGNOPS
  77. };
  78. // For the use of std::forward<>, see eg http://www.justsoftwaresolutions.co.uk/cplusplus/rvalue_references_and_perfect_forwarding.html
  79. template <class C> inline constexpr auto
  80. scalar(C && c) { return Scalar<C> { std::forward<C>(c) }; }
  81. // Wrap something with {size, begin} as 1-D vector. Sort of reduced ra_iterator.
  82. // ra::ra_traits_def<V> must be defined with ::size, ::size_s.
  83. // FIXME This can handle temporaries and make_a().begin() can't, look out for that.
  84. // FIXME Do we need this class? holding rvalue is the only thing it does over View, and it doesn't handle rank!=1.
  85. template <class V>
  86. struct Vector
  87. {
  88. using traits = ra_traits<V>;
  89. V v;
  90. decltype(v.begin()) p__;
  91. static_assert(!std::is_reference<decltype(p__)>::value, "bad iterator type");
  92. constexpr dim_t size() const { return traits::size(v); }
  93. constexpr dim_t size(int i) const { CHECK_BOUNDS(i==0); return traits::size(v); }
  94. constexpr static dim_t size_s() { return traits::size_s(); }
  95. constexpr static rank_t rank() { return 1; }
  96. constexpr static rank_t rank_s() { return 1; };
  97. using shape_type = std::array<dim_t, 1>;
  98. constexpr auto shape() const { return shape_type { { dim_t(traits::size(v)) } }; }
  99. // see test-compatibility.C [a1] for forward() here.
  100. Vector(V && v_): v(std::forward<V>(v_)), p__(v.begin()) {}
  101. template <class I>
  102. decltype(auto) at(I const & i)
  103. {
  104. CHECK_BOUNDS(inside(i[0], this->size()));
  105. return p__[i[0]];
  106. }
  107. constexpr void adv(rank_t k, dim_t d)
  108. {
  109. // k>0 happens on frame-matching when the axes k>0 can't be unrolled; see [trc-01] in test-compatibility.C.
  110. // k==0 && d!=1 happens on turning back at end of ply; TODO we need this only on outer products and such.
  111. CHECK_BOUNDS(d==1 || d<0);
  112. p__ += (k==0) * d;
  113. }
  114. constexpr static dim_t stride(int i) { return i==0 ? 1 : 0; }
  115. // reduced from cell_iterator::keep_stride()
  116. constexpr static bool keep_stride(dim_t step, int z, int j) { return (z==0) == (j==0); }
  117. constexpr auto flat() { return p__; }
  118. constexpr auto flat() const { return p__; }
  119. #define DEF_ASSIGNOPS(OP) template <class X> void operator OP(X && x) \
  120. { for_each([](auto && y, auto && x) { std::forward<decltype(y)>(y) OP x; }, *this, x); }
  121. FOR_EACH(DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  122. #undef DEF_ASSIGNOPS
  123. };
  124. template <class V> inline constexpr auto
  125. vector(V && v) { return Vector<V> { std::forward<V>(v) }; }
  126. // Like Vector, but on the iterator itself, so no size, P only needs to have +=, *, [].
  127. // ra::ra_traits_def<P> doesn't need to be defined.
  128. template <class P>
  129. struct Ptr
  130. {
  131. P p__;
  132. constexpr static dim_t size() { return DIM_BAD; }
  133. constexpr static dim_t size(int i) { CHECK_BOUNDS(i==0); return DIM_BAD; }
  134. constexpr static dim_t size_s() { return DIM_BAD; }
  135. constexpr static rank_t rank() { return 1; }
  136. constexpr static rank_t rank_s() { return 1; };
  137. using shape_type = std::array<dim_t, 1>;
  138. constexpr static auto shape() { return shape_type { { dim_t(DIM_BAD) } }; }
  139. template <class I>
  140. constexpr decltype(auto) at(I const & i)
  141. {
  142. return p__[i[0]];
  143. }
  144. constexpr void adv(rank_t k, dim_t d)
  145. {
  146. CHECK_BOUNDS(d==1 || d<0); // cf Vector::adv
  147. p__ += (k==0) * d;
  148. }
  149. constexpr static dim_t stride(int i) { return i==0 ? 1 : 0; }
  150. // reduced from cell_iterator::keep_stride()
  151. constexpr static bool keep_stride(dim_t step, int z, int j) { return (z==0) == (j==0); }
  152. constexpr auto flat() { return p__; }
  153. constexpr auto flat() const { return p__; }
  154. #define DEF_ASSIGNOPS(OP) template <class X> void operator OP(X && x) \
  155. { for_each([](auto && y, auto && x) { std::forward<decltype(y)>(y) OP x; }, *this, x); }
  156. FOR_EACH(DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  157. #undef DEF_ASSIGNOPS
  158. };
  159. template <class T> inline constexpr auto ptr(T * p) { return Ptr<T *> { p }; }
  160. template <class T>
  161. struct IotaFlat
  162. {
  163. T i_;
  164. T const stride_;
  165. T const & operator*() const { return i_; } // TODO if not for this, I could use plain T. Maybe ra::eval_expr...
  166. void operator+=(dim_t d) { i_ += T(d)*stride_; }
  167. };
  168. template <class T_>
  169. struct Iota
  170. {
  171. using T = T_;
  172. dim_t const size_;
  173. T const org_;
  174. T const stride_;
  175. T i_;
  176. constexpr Iota(dim_t size, T org=0, T stride=1): size_(size), org_(org), stride_(stride), i_(org)
  177. {
  178. CHECK_BOUNDS(size>=0);
  179. }
  180. constexpr dim_t size() const { return size_; } // this is a Slice function...
  181. constexpr dim_t size(int i) const { CHECK_BOUNDS(i==0); return size_; }
  182. constexpr static dim_t size_s() { return DIM_ANY; }
  183. constexpr rank_t rank() const { return 1; }
  184. constexpr static rank_t rank_s() { return 1; };
  185. using shape_type = std::array<dim_t, 1>;
  186. constexpr auto shape() const { return shape_type { { size_ } }; }
  187. template <class I>
  188. constexpr decltype(auto) at(I const & i)
  189. {
  190. return org_ + T(i[0])*stride_;
  191. }
  192. constexpr void adv(rank_t k, dim_t d)
  193. {
  194. i_ += T((k==0) * d) * stride_; // cf Vector::adv
  195. }
  196. constexpr static dim_t stride(rank_t i) { return i==0 ? 1 : 0; }
  197. // reduced from cell_iterator::keep_stride()
  198. constexpr static bool keep_stride(dim_t step, int z, int j) { return (z==0) == (j==0); }
  199. constexpr auto flat() const { return IotaFlat<T> { i_, stride_ }; }
  200. };
  201. template <class O=dim_t, class S=O> inline constexpr auto
  202. iota(dim_t size, O org=0, S stride=1)
  203. {
  204. using T = std::common_type_t<O, S>;
  205. return Iota<T> { size, T(org), T(stride) };
  206. }
  207. template <class I> struct is_beatable_def
  208. {
  209. constexpr static bool value = std::is_integral<I>::value;
  210. constexpr static int skip_src = 1;
  211. constexpr static int skip = 0;
  212. constexpr static bool static_p = value; // can the beating can be resolved statically?
  213. };
  214. template <class II> struct is_beatable_def<Iota<II>>
  215. {
  216. constexpr static bool value = std::numeric_limits<II>::is_integer;
  217. constexpr static int skip_src = 1;
  218. constexpr static int skip = 1;
  219. constexpr static bool static_p = false; // it cannot for Iota
  220. };
  221. template <int n> struct is_beatable_def<dots_t<n>>
  222. {
  223. static_assert(n>=0, "bad count for dots_n");
  224. constexpr static bool value = (n>=0);
  225. constexpr static int skip_src = n;
  226. constexpr static int skip = n;
  227. constexpr static bool static_p = true;
  228. };
  229. template <int n> struct is_beatable_def<newaxis_t<n>>
  230. {
  231. static_assert(n>=0, "bad count for dots_n");
  232. constexpr static bool value = (n>=0);
  233. constexpr static int skip_src = 0;
  234. constexpr static int skip = n;
  235. constexpr static bool static_p = true;
  236. };
  237. template <class I> using is_beatable = is_beatable_def<std::decay_t<I>>;
  238. template <class X> constexpr bool has_tensorindex_def = false;
  239. template <class T> constexpr bool has_tensorindex = has_tensorindex_def<std::decay_t<T>>;
  240. template <int w, class value_type> constexpr bool has_tensorindex_def<TensorIndex<w, value_type>> = true;
  241. template <class Op, class T, class K=std::make_integer_sequence<int, mp::len<T>> > struct Expr;
  242. template <class T, class K=std::make_integer_sequence<int, mp::len<T>> > struct Pick;
  243. template <class FM, class Op, class T, class K=std::make_integer_sequence<int, mp::len<T>>> struct Ryn;
  244. template <class LiveAxes, int depth, class A> struct ApplyFrames;
  245. template <class Op, class ... Ti, class K>
  246. constexpr bool has_tensorindex_def<Expr<Op, std::tuple<Ti ...>, K>> = (has_tensorindex<Ti> || ...);
  247. template <class ... Ti, class K>
  248. constexpr bool has_tensorindex_def<Pick<std::tuple<Ti ...>, K>> = (has_tensorindex<Ti> || ...);
  249. template <class LiveAxes, int depth, class A>
  250. constexpr bool has_tensorindex_def<ApplyFrames<LiveAxes, depth, A>> = has_tensorindex<A>;
  251. template <class FM, class Op, class ... Ti, class K>
  252. constexpr bool has_tensorindex_def<Ryn<FM, Op, std::tuple<Ti ...>, K>> = (has_tensorindex<Ti> || ...);
  253. } // namespace ra
  // Remove this header's bounds-check macros so they don't leak into including code.
  // NOTE(review): this also removes an RA_CHECK_BOUNDS_RA_ATOM the includer defined before including.
  #undef RA_CHECK_BOUNDS_RA_ATOM
  #undef CHECK_BOUNDS