optimize.H 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. // (c) Daniel Llorens - 2015
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file optimize.H
  7. /// @brief Naive optimization pass over ETs.
  8. #pragma once
  9. #include "ra/expr.H"
  10. #include "ra/small.H"
  // Switch for the Iota rules guarded by RA_OPTIMIZE_IOTA==1 in this header.
  // Enabled unless the build already defined it: there is no real downside.
  #if !defined(RA_OPTIMIZE_IOTA)
  #  define RA_OPTIMIZE_IOTA 1
  #endif
  // Switch for the small-vector rules guarded by RA_OPTIMIZE_SMALLVECTOR==1.
  // Disabled unless the build already defined it: the benchmark shows it is not
  // a win by default; it probably requires optimizing also +=, etc.
  #if !defined(RA_OPTIMIZE_SMALLVECTOR)
  #  define RA_OPTIMIZE_SMALLVECTOR 0
  #endif
  19. namespace ra {
  /// Fallback of the optimization pass: an expression that no rule matches is
  /// perfectly forwarded back to the caller, unchanged.
  /// Declared constexpr, like the rule overloads in this header, so that passing
  /// an unmatched expression through optimize() is allowed in constant expressions.
  /// @tparam E  deduced type of the expression.
  /// @tparam a  not used by this overload; defaults to 0.
  /// @param  e  expression to pass through.
  /// @return e, forwarded as E && (decltype(auto) keeps the reference category).
  template <class E, int a=0>
  inline constexpr decltype(auto) optimize(E && e)
  {
      return std::forward<E>(e);
  }
  // Named function objects for the binary arithmetic operators. They are named
  // so that Expr<OPNAME, ...> can be matched & transformed later on, and they are
  // used by operator+ etc.
  // operator() perfectly forwards both operands to OP and returns whatever OP
  // returns (decltype(auto)). It is const-qualified so that a const functor can
  // be invoked, and constexpr so that it can take part in constant expressions.
  #define DEFINE_NAMED_BINARY_OP(OP, OPNAME)                                  \
      struct OPNAME                                                           \
      {                                                                       \
          template <class A, class B>                                         \
          constexpr decltype(auto) operator()(A && a, B && b) const           \
          {                                                                   \
              return std::forward<A>(a) OP std::forward<B>(b);                \
          }                                                                   \
      };
  DEFINE_NAMED_BINARY_OP(+, plus)
  DEFINE_NAMED_BINARY_OP(-, minus)
  DEFINE_NAMED_BINARY_OP(*, times)
  DEFINE_NAMED_BINARY_OP(/, slash)
  #undef DEFINE_NAMED_BINARY_OP
  33. // TODO need something to handle the & variants...
  // ITEM(IDX): element IDX of the tuple member t of the object named e in the
  // scope where the macro is expanded.
  #define ITEM(IDX) std::get<(IDX)>(e.t)
  35. #if RA_OPTIMIZE_IOTA==1
  // IS_IOTA(I): true when the decayed I is exactly Iota<typename I::T>.
  // Note that the member type T is looked up on I itself, not on the decayed type.
  #define IS_IOTA(I) (std::is_same<Iota<typename I::T>, std::decay_t<I>>::value)
  37. // TODO iota(int)*real is not opt to iota(real) since a+a+... != n*a.
  38. template <class X> constexpr bool iota_op = ra::is_zero_or_scalar<X> && std::numeric_limits<value_t<X>>::is_integer;
  39. // --------------
  40. // plus
  41. // --------------
  42. template <class I, class J, std::enable_if_t<IS_IOTA(I) && iota_op<J>, int> =0>
  43. inline constexpr auto optimize(Expr<ra::plus, std::tuple<I, J>> && e)
  44. {
  45. return iota(ITEM(0).size_, ITEM(0).org_+ITEM(1), ITEM(0).stride_);
  46. }
  47. template <class I, class J, std::enable_if_t<iota_op<I> && IS_IOTA(J), int> =0>
  48. inline constexpr auto optimize(Expr<ra::plus, std::tuple<I, J>> && e)
  49. {
  50. return iota(ITEM(1).size_, ITEM(0)+ITEM(1).org_, ITEM(1).stride_);
  51. }
  52. template <class I, class J, std::enable_if_t<IS_IOTA(I) && IS_IOTA(J), int> =0>
  53. inline constexpr auto optimize(Expr<ra::plus, std::tuple<I, J>> && e)
  54. {
  55. assert(ITEM(0).size_==ITEM(1).size_ && "size mismatch");
  56. return iota(ITEM(0).size_, ITEM(0).org_+ITEM(1).org_, ITEM(0).stride_+ITEM(1).stride_);
  57. }
  58. // --------------
  59. // minus
  60. // --------------
  61. template <class I, class J, std::enable_if_t<IS_IOTA(I) && iota_op<J>, int> =0>
  62. inline constexpr auto optimize(Expr<ra::minus, std::tuple<I, J>> && e)
  63. {
  64. return iota(ITEM(0).size_, ITEM(0).org_-ITEM(1), ITEM(0).stride_);
  65. }
  66. template <class I, class J, std::enable_if_t<iota_op<I> && IS_IOTA(J), int> =0>
  67. inline constexpr auto optimize(Expr<ra::minus, std::tuple<I, J>> && e)
  68. {
  69. return iota(ITEM(1).size_, ITEM(0)-ITEM(1).org_, -ITEM(1).stride_);
  70. }
  71. template <class I, class J, std::enable_if_t<IS_IOTA(I) && IS_IOTA(J), int> =0>
  72. inline constexpr auto optimize(Expr<ra::minus, std::tuple<I, J>> && e)
  73. {
  74. assert(ITEM(0).size_==ITEM(1).size_ && "size mismatch");
  75. return iota(ITEM(0).size_, ITEM(0).org_-ITEM(1).org_, ITEM(0).stride_-ITEM(1).stride_);
  76. }
  77. // --------------
  78. // times
  79. // --------------
  80. template <class I, class J, std::enable_if_t<IS_IOTA(I) && iota_op<J>, int> =0>
  81. inline constexpr auto optimize(Expr<ra::times, std::tuple<I, J>> && e)
  82. {
  83. return iota(ITEM(0).size_, ITEM(0).org_*ITEM(1), ITEM(0).stride_*ITEM(1));
  84. }
  85. template <class I, class J, std::enable_if_t<iota_op<I> && IS_IOTA(J), int> =0>
  86. inline constexpr auto optimize(Expr<ra::times, std::tuple<I, J>> && e)
  87. {
  88. return iota(ITEM(1).size_, ITEM(0)*ITEM(1).org_, ITEM(0)*ITEM(1).stride_);
  89. }
  90. #undef IS_IOTA
  91. #endif // RA_OPTIMIZE_IOTA
  92. #if RA_OPTIMIZE_SMALLVECTOR==1
  // extvector<T, N>: T carrying the compiler attribute vector_size(N*sizeof(T)).
  template <class T, int N>
  using extvector __attribute__((vector_size(N*sizeof(T)))) = T;
  94. #define OPTIMIZE_SMALLVECTOR_OP(OP, NAME) \
  95. inline auto \
  96. optimize(ra::Expr<NAME, \
  97. std::tuple<decltype(start(std::declval<ra::Small<double, 4>>())), \
  98. decltype(start(std::declval<ra::Small<double, 4>>()))>> && e) \
  99. { \
  100. ra::Small<double, 4> val; \
  101. (extvector<double, 4> &)(val) = (extvector<double, 4> &)(*(ITEM(0).p)) OP (extvector<double, 4> &)(*(ITEM(1).p)); \
  102. return val; \
  103. }
  104. OPTIMIZE_SMALLVECTOR_OP(+, ra::plus)
  105. OPTIMIZE_SMALLVECTOR_OP(*, ra::times)
  106. #undef OPTIMIZE_SMALLVECTOR_OP
  107. #endif // RA_OPTIMIZE_SMALLVECTOR
  108. #undef ITEM
  109. } // namespace ra