wrank.hh 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file wrank.hh
  3. /// @brief Rank conjunction for expression templates.
  4. // (c) Daniel Llorens - 2013-2017, 2019
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #pragma once
  10. #include "ra/match.hh"
  11. namespace ra {
  12. // ---------------------------
  13. // reframe - a variant of transpose that works on any array iterator.
  14. // As in transpose, one names the destination axis for each original axis.
  15. // However, unlike general transpose, axes may not be repeated.
  16. // The main application is the rank conjunction below.
  17. // ---------------------------
  // zerostride<T>::f() yields a zero of type T; Reframe uses it as the stride of a
  // 'dead' (non-moving) axis. For std::tuple strides the zero is built element by
  // element, recursing into nested tuples.
  template <class T>
  struct zerostride
  {
      static constexpr T f() { return static_cast<T>(0); }
  };
  template <class ... T>
  struct zerostride<std::tuple<T ...>>
  {
      static constexpr std::tuple<T ...> f() { return std::tuple<T ...>(zerostride<T>::f() ...); }
  };
  // Dest is a list of destination axes [l0 l1 ... li ... l(rank(A)-1)].
  // The dimensions of the reframed A are numbered as [0 ... k ... max(l)], so the reframed rank is 1+max(l).
  // If li = k for some i, then axis k of the reframed A moves on axis i of the original iterator A.
  // If not, then axis k of the reframed A is 'dead': it doesn't move the iterator, its size is DIM_BAD and its stride is zero.
  // TODO invalid for RANK_ANY (since Dest is compile time). [ra07]
  33. template <class Dest, class A>
  34. struct Reframe
  35. {
  36. A a;
  37. constexpr static int orig(int k) { return mp::int_list_index<Dest>(k); }
  38. constexpr static rank_t rank_s() { return 1+mp::fold<mp::max, mp::int_t<-1>, Dest>::value; }
  39. constexpr static rank_t rank() { return rank_s(); }
  40. constexpr static dim_t size_s(int k)
  41. {
  42. int l = orig(k);
  43. return l>=0 ? std::decay_t<A>::size_s(l) : DIM_BAD;
  44. }
  45. constexpr dim_t size(int k) const
  46. {
  47. int l = orig(k);
  48. return l>=0 ? a.size(l) : DIM_BAD;
  49. }
  50. constexpr void adv(rank_t k, dim_t d)
  51. {
  52. if (int l = orig(k); l>=0) {
  53. a.adv(l, d);
  54. }
  55. }
  56. constexpr auto stride(int k) const
  57. {
  58. int l = orig(k);
  59. return l>=0 ? a.stride(l) : zerostride<decltype(a.stride(l))>::f();
  60. }
  61. constexpr bool keep_stride(dim_t st, int z, int j) const
  62. {
  63. int wz = orig(z);
  64. int wj = orig(j);
  65. return wz>=0 && wj>=0 && a.keep_stride(st, wz, wj);
  66. }
  67. template <class I> constexpr decltype(auto) at(I const & i)
  68. {
  69. return a.at(mp::map_indices<std::array<dim_t, mp::len<Dest>>, Dest>(i));
  70. }
  71. constexpr decltype(auto) flat() { return a.flat(); }
  72. };
  73. // Optimize no-op case.
  74. // TODO If A is cell_iterator, etc. beat Dest directly on that... same for an eventual transpose_expr<>.
  75. template <class Dest, class A> decltype(auto)
  76. reframe(A && a)
  77. {
  78. if constexpr (std::is_same_v<Dest, mp::iota<1+mp::fold<mp::max, mp::int_t<-1>, Dest>::value>>) {
  79. return std::forward<A>(a);
  80. } else {
  81. return Reframe<Dest, A> { std::forward<A>(a) };
  82. }
  83. }
  84. // ---------------------------
  85. // verbs and rank conjunction
  86. // ---------------------------
  // Verb: an operation `op` tagged with `cranks`, a type-level list of cell ranks
  // (one per argument). Verb is an aggregate, so Verb<cranks, Op> { op } builds it;
  // when Op is an lvalue reference type the op is held by reference.
  template <class cranks_, class Op_>
  struct Verb
  {
      using Op = Op_;
      using cranks = cranks_;
      Op op;
  };
  94. RA_IS_DEF(is_verb, (std::is_same_v<A, Verb<typename A::cranks, typename A::Op>>))
  95. template <class cranks, class Op> inline constexpr auto
  96. wrank(cranks cranks_, Op && op)
  97. {
  98. return Verb<cranks, Op> { std::forward<Op>(op) };
  99. }
  100. template <rank_t ... crank, class Op> inline constexpr auto
  101. wrank(Op && op)
  102. {
  103. return Verb<mp::int_list<crank ...>, Op> { std::forward<Op>(op) };
  104. }
  105. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  106. struct Framematch_def;
  107. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  108. using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
  // Binary "which one wins" selector over two integral constants: 0 when A wins,
  // i.e. gt_rank(A::value, B::value) holds, otherwise 1.
  template <class A, class B>
  struct max_i
  {
      static constexpr int value = !gt_rank(A::value, B::value) ? 1 : 0;
  };
  114. // Get a list (per argument) of lists of live axes. The last frame match is handled by standard prefix matching.
  115. template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
  116. struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  117. {
  118. static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "bad args");
  119. // live = number of live axes on this frame, for each argument. // TODO crank negative, inf.
  120. using live = mp::int_list<(rank_s<Ti>() - mp::len<Ri> - crank::value) ...>;
  121. using frameaxes = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri> - crank::value), skip>> ...>;
  122. using FM = Framematch<W, std::tuple<Ti ...>, frameaxes, skip + mp::ref<live, mp::indexof<max_i, live>>::value>;
  123. using R = typename FM::R;
  124. template <class VV> static decltype(auto) op(VV && v) { return FM::op(std::forward<VV>(v).op); } // cf [ra31]
  125. };
  126. // Terminal case where V doesn't have rank (is a raw op()).
  127. template <class V, class ... Ti, class ... Ri, rank_t skip>
  128. struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  129. {
  130. static_assert(sizeof...(Ti)==sizeof...(Ri), "bad args");
  131. // TODO -crank::value when the actual verb rank is used (e.g. to use cell_iterator<A, that_rank> instead of just begin()).
  132. using R = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri>), skip>> ...>;
  133. template <class VV> static decltype(auto) op(VV && v) { return std::forward<VV>(v); }
  134. };
  135. } // namespace ra