pick.H 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. // (c) Daniel Llorens - 2016-2017
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file pick.H
  7. /// @brief Expression template that picks one of several arguments.
  8. // This class is needed because Expr evaluates its arguments before calling its
  9. // operator. Note that pick() is a normal function, so its arguments will always
  10. // be evaluated; it is the individual array elements that will not.
  11. #pragma once
  12. #include "ra/ply.H"
  // Bounds checking switch for this header. A global RA_CHECK_BOUNDS setting takes
  // precedence; otherwise RA_CHECK_BOUNDS_RA_PICK may be preset by the includer and
  // defaults to 1 (checks enabled).
  #if defined(RA_CHECK_BOUNDS)
  #  define RA_CHECK_BOUNDS_RA_PICK RA_CHECK_BOUNDS
  #elif !defined(RA_CHECK_BOUNDS_RA_PICK)
  #  define RA_CHECK_BOUNDS_RA_PICK 1
  #endif
  // CHECK_BOUNDS(cond) asserts cond when checking is enabled and expands to nothing
  // (cond is not evaluated) when RA_CHECK_BOUNDS_RA_PICK is 0.
  #if RA_CHECK_BOUNDS_RA_PICK!=0
  #  define CHECK_BOUNDS( cond ) assert( cond )
  #else
  #  define CHECK_BOUNDS( cond )
  #endif
  25. namespace ra {
  26. // -----------------
  27. // return type of pick expression, otherwise compiler complains of ambiguity.
  28. // TODO & is crude, maybe is_assignable?
  29. // -----------------
  30. template <class T, class Enable=void>
  31. struct pick_type
  32. {
  33. using type = mp::Apply_<std::common_type, T>;
  34. };
  35. // lvalue
  36. template <class P0, class ... P>
  37. struct pick_type<std::tuple<P0 &, P ...>,
  38. std::enable_if_t<!std::is_const<P0>::value
  39. && (std::is_same<P0 &, P>::value && ...)>>
  40. {
  41. using type = P0 &;
  42. };
  43. // const lvalue
  44. template <class P0, class ... P>
  45. struct pick_type<std::tuple<P0 &, P & ...>,
  46. std::enable_if_t<(std::is_same<std::decay_t<P0>, std::decay_t<P>>::value && ...)
  47. && (std::is_const<P0>::value || (std::is_const<P>::value || ...))>>
  48. {
  49. using type = P0 const &;
  50. };
  51. // -----------------
  52. // runtime to compile time conversion for Pick::at() and PickFlat::operator*()
  53. // -----------------
  54. template <class T, class J> struct pick_at_type;
  55. template <class ... P, class J> struct pick_at_type<std::tuple<P ...>, J>
  56. {
  57. using type = typename pick_type<std::tuple<decltype(std::declval<P>().at(std::declval<J>())) ...>>::type;
  58. };
  59. template <class T, class J> using pick_at_t = typename pick_at_type<mp::Drop1_<std::decay_t<T>>, J>::type;
  60. template <size_t I, class T, class J> inline constexpr
  61. auto pick_at(size_t p0, T && t, J const & j)
  62. -> std::enable_if_t<(I+2==std::tuple_size<std::decay_t<T>>::value), pick_at_t<T, J>>
  63. {
  64. CHECK_BOUNDS(p0==I);
  65. return std::get<I+1>(t).at(j);
  66. }
  67. template <size_t I, class T, class J> inline constexpr
  68. auto pick_at(size_t p0, T && t, J const & j)
  69. -> std::enable_if_t<(I+2<std::tuple_size<std::decay_t<T>>::value), pick_at_t<T, J>>
  70. {
  71. if (p0==I) {
  72. return std::get<I+1>(t).at(j);
  73. } else {
  74. return pick_at<I+1>(p0, t, j);
  75. }
  76. }
  77. template <class T> struct pick_star_type;
  78. template <class ... P> struct pick_star_type<std::tuple<P ...>>
  79. {
  80. using type = typename pick_type<std::tuple<decltype(*std::declval<P>()) ...>>::type;
  81. };
  82. template <class T> using pick_star_t = typename pick_star_type<mp::Drop1_<std::decay_t<T>>>::type;
  83. template <size_t I, class T> inline constexpr
  84. auto pick_star(size_t p0, T && t)
  85. -> std::enable_if_t<(I+2==std::tuple_size<std::decay_t<T>>::value), pick_star_t<T>>
  86. {
  87. CHECK_BOUNDS(p0==I);
  88. return *(std::get<I+1>(t));
  89. }
  90. template <size_t I, class T> inline constexpr
  91. auto pick_star(size_t p0, T && t)
  92. -> std::enable_if_t<(I+2<std::tuple_size<std::decay_t<T>>::value), pick_star_t<T>>
  93. {
  94. if (p0==I) {
  95. return *(std::get<I+1>(t));
  96. } else {
  97. return pick_star<I+1>(p0, t);
  98. }
  99. }
  100. // -----------------
  101. // follows closely Flat, Expr
  102. // -----------------
  103. // Manipulate ET through flat (raw pointer-like) iterators P ...
  104. template <class T, class I=std::make_integer_sequence<int, mp::len<T>>>
  105. struct PickFlat;
  106. template <class P0, class ... P, int ... I>
  107. struct PickFlat<std::tuple<P0, P ...>, std::integer_sequence<int, I ...>>
  108. {
  109. std::tuple<P0, P ...> t;
  110. template <class S> void operator+=(S const & s) { ((std::get<I>(t) += std::get<I>(s)), ...); }
  111. decltype(auto) operator*() { return pick_star<0>(*std::get<0>(t), t); }
  112. };
  113. template <class P0, class ... P> inline constexpr
  114. auto pick_flat(P0 && p0, P && ... p)
  115. {
  116. return PickFlat<std::tuple<P0, P ...>> { std::tuple<P0, P ...> { std::forward<P0>(p0), std::forward<P>(p) ... } };
  117. }
  118. // forward decl in atom.H
  119. template <class P0, class ... P, int ... I>
  120. struct Pick<std::tuple<P0, P ...>, std::integer_sequence<int, I ...>>
  121. {
  122. // A-th argument decides rank and shape.
  123. constexpr static int A = largest_rank<P0, P ...>::value;
  124. using PA = std::decay_t<mp::Ref_<std::tuple<P0, P ...>, A>>;
  125. using NotA = mp::ComplementList_<mp::int_list<A>, std::tuple<mp::int_t<I> ...>>;
  126. std::tuple<P0, P ...> t;
  127. // If driver is RANK_ANY, driver selection should wait til run time, unless we can tell that RANK_ANY would be selected anyway.
  128. constexpr static bool VALID_DRIVER = PA::size_s()!=DIM_BAD ; //&& (PA::rank_s()!=RANK_ANY || sizeof...(P)==1);
  129. template <int iarg>
  130. std::enable_if_t<(iarg==mp::len<NotA>), bool>
  131. check(int const driver_rank) const { return true; }
  132. template <int iarg>
  133. std::enable_if_t<(iarg<mp::len<NotA>), bool>
  134. check(int const driver_rank) const
  135. {
  136. rank_t ranki = std::get<mp::Ref_<NotA, iarg>::value>(t).rank();
  137. // Provide safety where RANK_ANY was selected as driver in a leap of faith. TODO Dynamic driver selection.
  138. assert(ranki<=driver_rank && "driver not max rank (could be RANK_ANY)");
  139. for (int k=0; k!=ranki; ++k) {
  140. dim_t sk0 = std::get<A>(t).size(k);
  141. if (sk0!=DIM_BAD) { // may be == in subexpressions
  142. dim_t sk = std::get<mp::Ref_<NotA, iarg>::value>(t).size(k);
  143. assert((sk==sk0 || sk==DIM_BAD) && "mismatched dimensions");
  144. }
  145. }
  146. return check<iarg+1>(driver_rank);
  147. }
  148. // see test-compatibility.C [a1] for forward() here.
  149. constexpr Pick(P0 p0_, P ... p_): t(std::forward<P0>(p0_), std::forward<P>(p_) ...)
  150. {
  151. // TODO Try to static_assert. E.g., size_s() vs size_s() can static_assert if we try real3==real2.
  152. // TODO Should check only the driver: do this on ply.
  153. CHECK_BOUNDS(check<0>(rank()));
  154. }
  155. template <class J>
  156. constexpr decltype(auto) at(J const & j)
  157. {
  158. return pick_at<0>(std::get<0>(t).at(j), t, j);
  159. }
  160. constexpr void adv(rank_t k, dim_t d)
  161. {
  162. (std::get<I>(t).adv(k, d), ...);
  163. }
  164. constexpr bool keep_stride(dim_t step, int z, int j) const
  165. {
  166. return (std::get<I>(t).keep_stride(step, z, j) && ...);
  167. }
  168. constexpr auto stride(int i) const
  169. {
  170. return std::make_tuple(std::get<I>(t).stride(i) ...);
  171. }
  172. constexpr auto flat()
  173. {
  174. return pick_flat(std::get<I>(t).flat() ...);
  175. }
  176. constexpr auto flat() const { return flat(); }
  177. // there's one size (by A), but each arg has its own strides.
  178. // Note: do not require driver. This is needed by check for all leaves.
  179. constexpr dim_t size(int i) const { return std::get<A>(t).size(i); }
  180. constexpr static dim_t size_s(int i) { return PA::size_s(i); }
  181. constexpr static dim_t size_s() { return PA::size_s(); }
  182. constexpr rank_t rank() const { return std::get<A>(t).rank(); }
  183. constexpr static rank_t rank_s() { return PA::rank_s(); }
  184. constexpr decltype(auto) shape() const
  185. {
  186. static_assert(VALID_DRIVER, "can't drive this xpr");
  187. return std::get<A>(t).shape();
  188. }
  189. // needed for xpr with rank_s()==RANK_ANY, which don't decay to scalar when used as operator arguments.
  190. operator decltype(*(pick_flat(std::get<I>(t).flat() ...)))()
  191. {
  192. static_assert(rank_s()==0 || rank_s()==RANK_ANY || (rank_s()==1 && size_s()==1), // for coord types
  193. "bad rank in conversion to scalar");
  194. assert(rank()==0 || (rank_s()==1 && size_s()==1)); // for coord types; so fixed only
  195. return *flat();
  196. }
  197. // forward to make sure value y is not misused as ref. Cf. test-ra-8.C
  198. #define DEF_ASSIGNOPS(OP) template <class X> void operator OP(X && x) \
  199. { for_each([](auto && y, auto && x) { std::forward<decltype(y)>(y) OP x; }, *this, x); }
  200. FOR_EACH(DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  201. #undef DEF_ASSIGNOPS
  202. };
  203. #undef DEF_ASSIGNOPS_
  204. template <class P0, class ... P> inline constexpr auto
  205. pick_in(P0 && p0, P && ... p)
  206. {
  207. return Pick<std::tuple<P0, P ...>> { std::forward<P0>(p0), std::forward<P>(p) ... };
  208. }
  // User-facing entry point: pass each argument through start() and hand the
  // results to pick_in(). p0 is the selector, p ... are the candidates.
  template <class Sel, class ... Cand> inline constexpr auto
  pick(Sel && p0, Cand && ... p)
  {
      return pick_in(start(std::forward<Sel>(p0)),
                     start(std::forward<Cand>(p)) ...);
  }
  214. } // namespace ra
  // Do not leak this header's bounds-check macros into code that includes it.
  #undef RA_CHECK_BOUNDS_RA_PICK
  #undef CHECK_BOUNDS