expr.H 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. // (c) Daniel Llorens - 2011-2014, 2016-2017
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file expr.H
  7. /// @brief Operator nodes for expression templates.
  8. #pragma once
  9. #include "ra/ply.H"
  10. #ifdef RA_CHECK_BOUNDS
  11. #define RA_CHECK_BOUNDS_RA_EXPR RA_CHECK_BOUNDS
  12. #else
  13. #ifndef RA_CHECK_BOUNDS_RA_EXPR
  14. #define RA_CHECK_BOUNDS_RA_EXPR 1
  15. #endif
  16. #endif
  17. #if RA_CHECK_BOUNDS_RA_EXPR==0
  18. #define CHECK_BOUNDS( cond )
  19. #else
  20. #define CHECK_BOUNDS( cond ) assert( cond )
  21. #endif
  22. namespace ra {
  23. // Manipulate ET through flat (raw pointer-like) iterators P ...
  24. template <class Op, class T, class I=std::make_integer_sequence<int, mp::len<T>>>
  25. struct Flat;
  26. template <class Op, class ... P, int ... I>
  27. struct Flat<Op, std::tuple<P ...>, std::integer_sequence<int, I ...>>
  28. {
  29. Op & op;
  30. std::tuple<P ...> t;
  31. template <class S> void operator+=(S const & s) { ((std::get<I>(t) += std::get<I>(s)), ...); }
  32. decltype(auto) operator*() { return op(*std::get<I>(t) ...); } // TODO std::apply(op, t)
  33. };
  34. template <class Op, class ... P> inline constexpr
  35. auto flat(Op & op, P && ... p)
  36. {
  37. return Flat<Op, std::tuple<P ...>> { op, std::tuple<P ...> { std::forward<P>(p) ... } };
  38. }
  39. // forward decl in atom.H
  40. // TODO others:
  41. // * 'static': Like expr, but the operator is compile time (e.g. ::apply()).
  42. // * 'dynamic': Like expr, but driver is selected at run time (e.g. if args are RANK_ANY).
  43. template <class Op, class ... P, int ... I>
  44. struct Expr<Op, std::tuple<P ...>, std::integer_sequence<int, I ...>>
  45. {
  46. // A-th argument decides rank and shape.
  47. constexpr static int A = largest_rank<P ...>::value;
  48. using PA = std::decay_t<mp::Ref_<std::tuple<P ...>, A>>;
  49. using NotA = mp::ComplementList_<mp::int_list<A>, std::tuple<mp::int_t<I> ...>>;
  50. Op op;
  51. std::tuple<P ...> t;
  52. // If driver is RANK_ANY, driver selection should wait to run time, unless we can tell that RANK_ANY would be selected anyway.
  53. constexpr static bool VALID_DRIVER = PA::size_s()!=DIM_BAD ; //&& (PA::rank_s()!=RANK_ANY || sizeof...(P)==1);
  54. template <int iarg>
  55. std::enable_if_t<(iarg==mp::len<NotA>), bool>
  56. check(int const driver_rank) const { return true; }
  57. template <int iarg>
  58. std::enable_if_t<(iarg<mp::len<NotA>), bool>
  59. check(int const driver_rank) const
  60. {
  61. rank_t ranki = std::get<mp::Ref_<NotA, iarg>::value>(t).rank();
  62. // Provide safety where RANK_ANY was selected as driver in a leap of faith. TODO Dynamic driver selection.
  63. assert(ranki<=driver_rank && "driver not max rank (could be RANK_ANY)");
  64. for (int k=0; k!=ranki; ++k) {
  65. dim_t sk0 = std::get<A>(t).size(k);
  66. if (sk0!=DIM_BAD) { // may be == in subexpressions
  67. dim_t sk = std::get<mp::Ref_<NotA, iarg>::value>(t).size(k);
  68. assert((sk==sk0 || sk==DIM_BAD) && "mismatched dimensions");
  69. }
  70. }
  71. return check<iarg+1>(driver_rank);
  72. }
  73. // see test-compatibility.C [a1] for forward() here.
  74. constexpr Expr(Op op_, P ... p_): op(std::forward<Op>(op_)), t(std::forward<P>(p_) ...)
  75. {
  76. // TODO Try to static_assert. E.g., size_s() vs size_s() can static_assert if we try real3==real2.
  77. // TODO Should check only the driver: do this on ply.
  78. CHECK_BOUNDS(check<0>(rank()));
  79. }
  80. template <class J>
  81. constexpr decltype(auto) at(J const & i)
  82. {
  83. return op(std::get<I>(t).at(i) ...);
  84. }
  85. constexpr void adv(rank_t k, dim_t d)
  86. {
  87. (std::get<I>(t).adv(k, d), ...);
  88. }
  89. constexpr bool keep_stride(dim_t step, int z, int j) const
  90. {
  91. return (std::get<I>(t).keep_stride(step, z, j) && ...);
  92. }
  93. constexpr auto stride(int i) const
  94. {
  95. return std::make_tuple(std::get<I>(t).stride(i) ...);
  96. }
  97. constexpr decltype(auto) flat()
  98. {
  99. return ra::flat(op, std::get<I>(t).flat() ...);
  100. }
  101. constexpr decltype(auto) flat() const { return flat(); }
  102. // there's one size (by A), but each arg has its own strides.
  103. // Note: do not require driver. This is needed by check for all leaves.
  104. constexpr dim_t size(int i) const { return std::get<A>(t).size(i); }
  105. constexpr static dim_t size_s() { return PA::size_s(); }
  106. constexpr static dim_t size_s(int i) { return PA::size_s(i); }
  107. constexpr rank_t rank() const { return std::get<A>(t).rank(); }
  108. constexpr static rank_t rank_s() { return PA::rank_s(); }
  109. constexpr decltype(auto) shape() const
  110. {
  111. static_assert(VALID_DRIVER, "can't drive this xpr");
  112. return std::get<A>(t).shape();
  113. }
  114. // needed for xpr with rank_s()==RANK_ANY, which don't decay to scalar when used as operator arguments.
  115. operator decltype(*(ra::flat(op, std::get<I>(t).flat() ...)))()
  116. {
  117. static_assert(rank_s()==0 || rank_s()==RANK_ANY || (rank_s()==1 && size_s()==1), // for coord types
  118. "bad rank in conversion to scalar");
  119. assert(rank()==0 || (rank_s()==1 && size_s()==1)); // for coord types; so fixed only
  120. return *flat();
  121. }
  122. // forward to make sure value y is not misused as ref. Cf. test-ra-8.C
  123. #define DEF_ASSIGNOPS(OP) template <class X> void operator OP(X && x) \
  124. { for_each([](auto && y, auto && x) { std::forward<decltype(y)>(y) OP x; }, *this, x); }
  125. FOR_EACH(DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  126. #undef DEF_ASSIGNOPS
  127. };
  128. template <class Op, class ... P> inline constexpr auto
  129. expr(Op && op, P && ... p)
  130. {
  131. return Expr<Op, std::tuple<P ...>> { std::forward<Op>(op), std::forward<P>(p) ... };
  132. }
  133. // Wrappers over expr & ply.
  134. template <class F, class ... A> inline constexpr auto
  135. map(F && f, A && ... a)
  136. {
  137. return expr(std::forward<F>(f), start(std::forward<A>(a)) ...);
  138. }
  139. template <class F, class ... A> inline constexpr void
  140. for_each(F && f, A && ... a)
  141. {
  142. ply(expr(std::forward<F>(f), start(std::forward<A>(a)) ...));
  143. }
  144. } // namespace ra
// Drop this header's private bounds-checking macros so they don't leak into includers.
// Note this also removes an RA_CHECK_BOUNDS_RA_EXPR the user may have predefined.
#undef CHECK_BOUNDS
#undef RA_CHECK_BOUNDS_RA_EXPR