optimize.hh 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file optimize.hh
  3. /// @brief Naive optimization pass over ETs.
  4. // (c) Daniel Llorens - 2015-2018
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #pragma once
  10. #include "ra/small.hh"
  11. // Optimizing arithmetic on iota expressions has no real downside, so it is enabled by default.
  12. #ifndef RA_DO_OPT_IOTA
  13. #define RA_DO_OPT_IOTA 1
  14. #endif
  15. // Benchmarks show this does not pay off, so it is disabled by default; it probably requires also optimizing +=, etc.
  16. #ifndef RA_DO_OPT_SMALLVECTOR
  17. #define RA_DO_OPT_SMALLVECTOR 0
  18. #endif
  19. namespace ra {
  /// Generic overload of optimize(): hands the expression back untouched.
  /// The argument is perfectly forwarded, so an lvalue comes back as an lvalue
  /// reference and an rvalue as an rvalue reference.
  template <class E>
  inline constexpr decltype(auto)
  optimize(E && e)
  {
      // static_cast<E &&> is exactly what std::forward<E> does for a forwarding reference.
      return static_cast<E &&>(e);
  }
  // Named function objects for the built-in binary operators. They are named so that
  // Expr<OPNAME, ...> can be matched & transformed later on, and they are used by
  // operator+ etc.
  //
  // operator() forwards both operands to the built-in/overloaded operator OP and returns
  // whatever that expression yields (decltype(auto) keeps references intact).
  // It is const so that a const functor can be invoked, and constexpr so that it can take
  // part in constant evaluation.
  #define DEFINE_NAMED_BINARY_OP(OP, OPNAME)                                  \
  struct OPNAME                                                             \
  {                                                                         \
      template <class A, class B>                                           \
      constexpr decltype(auto) operator()(A && a, B && b) const             \
      {                                                                     \
          return std::forward<A>(a) OP std::forward<B>(b);                  \
      }                                                                     \
  };
  DEFINE_NAMED_BINARY_OP(+, plus)
  DEFINE_NAMED_BINARY_OP(-, minus)
  DEFINE_NAMED_BINARY_OP(*, times)
  DEFINE_NAMED_BINARY_OP(/, slash)
  #undef DEFINE_NAMED_BINARY_OP
  33. // TODO need something to handle the & variants...
  // ITEM(i) is shorthand for std::get<(i)>(e.t): element i of the member `t` of a variable
  // named `e`, which must be in scope wherever the macro is expanded. The macro is local to
  // this header and is removed again by the #undef ITEM further down.
  34. #define ITEM(i) std::get<(i)>(e.t)
  35. #if RA_DO_OPT_IOTA==1
  36. // TODO iota(int)*real is not optimized to iota(real), since in floating point a+a+... (n terms) need not equal n*a.
  37. template <class X> constexpr bool iota_op = ra::is_zero_or_scalar<X> && std::numeric_limits<value_t<X>>::is_integer;
  38. // --------------
  39. // plus
  40. // --------------
  41. template <class I, class J>
  42. requires (is_iota<I> && iota_op<J>)
  43. inline constexpr auto optimize(Expr<ra::plus, std::tuple<I, J>> && e)
  44. {
  45. return iota(ITEM(0).size_, ITEM(0).i_+ITEM(1), ITEM(0).stride_);
  46. }
  47. template <class I, class J>
  48. requires (iota_op<I> && is_iota<J>)
  49. inline constexpr auto optimize(Expr<ra::plus, std::tuple<I, J>> && e)
  50. {
  51. return iota(ITEM(1).size_, ITEM(0)+ITEM(1).i_, ITEM(1).stride_);
  52. }
  53. template <class I, class J>
  54. requires (is_iota<I> && is_iota<J>)
  55. inline constexpr auto optimize(Expr<ra::plus, std::tuple<I, J>> && e)
  56. {
  57. RA_CHECK(ITEM(0).size_==ITEM(1).size_ && "size mismatch");
  58. return iota(ITEM(0).size_, ITEM(0).i_+ITEM(1).i_, ITEM(0).stride_+ITEM(1).stride_);
  59. }
  60. // --------------
  61. // minus
  62. // --------------
  63. template <class I, class J>
  64. requires (is_iota<I> && iota_op<J>)
  65. inline constexpr auto optimize(Expr<ra::minus, std::tuple<I, J>> && e)
  66. {
  67. return iota(ITEM(0).size_, ITEM(0).i_-ITEM(1), ITEM(0).stride_);
  68. }
  69. template <class I, class J>
  70. requires (iota_op<I> && is_iota<J>)
  71. inline constexpr auto optimize(Expr<ra::minus, std::tuple<I, J>> && e)
  72. {
  73. return iota(ITEM(1).size_, ITEM(0)-ITEM(1).i_, -ITEM(1).stride_);
  74. }
  75. template <class I, class J>
  76. requires (is_iota<I> && is_iota<J>)
  77. inline constexpr auto optimize(Expr<ra::minus, std::tuple<I, J>> && e)
  78. {
  79. RA_CHECK(ITEM(0).size_==ITEM(1).size_ && "size mismatch");
  80. return iota(ITEM(0).size_, ITEM(0).i_-ITEM(1).i_, ITEM(0).stride_-ITEM(1).stride_);
  81. }
  82. // --------------
  83. // times
  84. // --------------
  85. template <class I, class J>
  86. requires (is_iota<I> && iota_op<J>)
  87. inline constexpr auto optimize(Expr<ra::times, std::tuple<I, J>> && e)
  88. {
  89. return iota(ITEM(0).size_, ITEM(0).i_*ITEM(1), ITEM(0).stride_*ITEM(1));
  90. }
  91. template <class I, class J>
  92. requires (iota_op<I> && is_iota<J>)
  93. inline constexpr auto optimize(Expr<ra::times, std::tuple<I, J>> && e)
  94. {
  95. return iota(ITEM(1).size_, ITEM(0)*ITEM(1).i_, ITEM(0)*ITEM(1).stride_);
  96. }
  97. #endif // RA_DO_OPT_IOTA
  98. #if RA_DO_OPT_SMALLVECTOR==1
  99. namespace {
  100. #if defined (__clang__)
  101. template <class T, int N> using extvector __attribute__((ext_vector_type(N))) = T;
  102. #else
  103. template <class T, int N> using extvector __attribute__((vector_size(N*sizeof(T)))) = T;
  104. #endif
  105. // FIXME find a way to peel qualifiers from parameter type of start(), to ignore SmallBase<SmallArray> vs SmallBase<SmallView> or const vs nonconst.
  106. template <class A, class T, dim_t N> constexpr bool match_smallvector =
  107. std::is_same_v<std::decay_t<A>, typename ra::Small<T, N>::template iterator<0>>
  108. || std::is_same_v<std::decay_t<A>, typename ra::Small<T, N>::template const_iterator<0>>;
  109. static_assert(match_smallvector<ra::cell_iterator_small<ra::SmallBase<ra::SmallView, double, mp::int_list<4>, mp::int_list<1>>, 0>,
  110. double, 4>);
  111. }; //namespace
  112. #define RA_OPT_SMALLVECTOR_OP(OP, NAME, T, N) \
  113. template <class A, class B> \
  114. requires (match_smallvector<A, T, N> && match_smallvector<B, T, N>) \
  115. inline auto \
  116. optimize(ra::Expr<NAME, std::tuple<A, B>> && e) \
  117. { \
  118. alignas (alignof(extvector<T, N>)) ra::Small<T, N> val; \
  119. *(extvector<T, N> *)(&val) = *(extvector<T, N> *)((ITEM(0).c.p)) OP *(extvector<T, N> *)((ITEM(1).c.p)); \
  120. return val; \
  121. }
  122. #define RA_OPT_SMALLVECTOR_OP_FUNS(T, N) \
  123. RA_OPT_SMALLVECTOR_OP(+, ra::plus, T, N) \
  124. RA_OPT_SMALLVECTOR_OP(-, ra::minus, T, N) \
  125. RA_OPT_SMALLVECTOR_OP(/, ra::slash, T, N) \
  126. RA_OPT_SMALLVECTOR_OP(*, ra::times, T, N)
  127. #define RA_OPT_SMALLVECTOR_OP_SIZES(T) \
  128. RA_OPT_SMALLVECTOR_OP_FUNS(T, 2) \
  129. RA_OPT_SMALLVECTOR_OP_FUNS(T, 4) \
  130. RA_OPT_SMALLVECTOR_OP_FUNS(T, 8)
  131. RA_OPT_SMALLVECTOR_OP_SIZES(double)
  132. RA_OPT_SMALLVECTOR_OP_SIZES(float)
  133. #undef RA_OPT_SMALLVECTOR_OP_SIZES
  134. #undef RA_OPT_SMALLVECTOR_OP_FUNS
  135. #undef RA_OPT_SMALLVECTOR_OP_OP
  136. #endif // RA_DO_OPT_SMALLVECTOR
  137. #undef ITEM
  138. } // namespace ra