operators.H 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. // (c) Daniel Llorens - 2014-2017
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file operators.H
  7. /// @brief Sugar for ra:: expression templates.
  8. #pragma once
  9. #include "ra/wrank.H"
  10. #include "ra/pick.H"
  11. // FIXME Dependence on specific ra:: types should maybe be elsewhere.
  12. #include "ra/small.H"
  13. #include "ra/view-ops.H"
  14. #include "ra/concrete.H"
  15. #include "ra/complex.H"
  16. #include "ra/wedge.H"
  17. #include "ra/global.H"
  // Bounds checking in this header: a global RA_CHECK_BOUNDS, when defined, wins;
  // otherwise RA_CHECK_BOUNDS_RA_OPERATORS may be preset by the user, and it
  // defaults to 1 (checks on).
  #if defined(RA_CHECK_BOUNDS)
  #define RA_CHECK_BOUNDS_RA_OPERATORS RA_CHECK_BOUNDS
  #elif !defined(RA_CHECK_BOUNDS_RA_OPERATORS)
  #define RA_CHECK_BOUNDS_RA_OPERATORS 1
  #endif
  // CHECK_BOUNDS(cond) is an assert() unless checking was set to 0, in which case
  // it expands to nothing.
  #if RA_CHECK_BOUNDS_RA_OPERATORS!=0
  #define CHECK_BOUNDS( cond ) assert( cond )
  #else
  #define CHECK_BOUNDS( cond )
  #endif
  30. #include "ra/optimize.H"
  // OPTIMIZE(expr) routes an expression through optimize() (from ra/optimize.H)
  // unless RA_OPTIMIZE is set to something other than 1; then it is a no-op wrapper.
  #if !defined(RA_OPTIMIZE)
  #define RA_OPTIMIZE 1 // enabled by default
  #endif
  #if RA_OPTIMIZE!=1
  #define OPTIMIZE
  #else
  #define OPTIMIZE optimize
  #endif
  // START(a) dereferences ra::start(a).flat(). It is used below to reach the single
  // element of rank-0 arguments and, inside decltype, to name the element type.
  #define START(a) *(ra::start(a).flat())
  40. namespace ra {
  41. template <int ... Iarg, class A>
  42. inline decltype(auto) transpose(mp::int_list<Iarg ...>, A && a)
  43. {
  44. return transpose<Iarg ...>(std::forward<A>(a));
  45. }
  // I considered three options for making these operators findable by lookup:
  // 1. Define them in a class that the ArrayIterator, Container or Slice types derive from. This was done for an old library I had (vector-ops.H). It gives the narrowest scope, but since those types are used in the definitions (ra::Expr is an ArrayIterator), it requires lots of forwarding and traits:: .
  // 2. Raw ADL doesn't work because some ra:: types use !, !=, etc. for different things (e.g. Flat). Possible solution: never use +, != or == for Flat.
  // 3. enable_if'd ADL, which is what you see here.
  50. // --------------------------------
  51. // Array versions of operators and functions
  52. // --------------------------------
  // We need the zero-rank/scalar overloads because the scalar/scalar operators
  // may be templated (e.g. complex<>), so they won't be found when an implicit
  // conversion from zero-rank -> scalar is also needed. That is, without those
  // overloads, ra::View<complex, 0> * complex will fail.
  // The named ops below depend on the OPNAME functors (plus, minus, times, slash)
  // defined in optimize.H, which uses them to match expression-template patterns.
  58. #define DEF_NAMED_BINARY_OP(OP, OPNAME) \
  59. template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0> \
  60. inline auto operator OP(A && a, B && b) \
  61. { \
  62. return OPTIMIZE(map(OPNAME(), std::forward<A>(a), std::forward<B>(b))); \
  63. } \
  64. template <class A, class B, std::enable_if_t<ra_zero<A, B>, int> =0> \
  65. inline auto operator OP(A && a, B && b) \
  66. { \
  67. return START(a) OP START(b); \
  68. }
  69. DEF_NAMED_BINARY_OP(+, plus)
  70. DEF_NAMED_BINARY_OP(-, minus)
  71. DEF_NAMED_BINARY_OP(*, times)
  72. DEF_NAMED_BINARY_OP(/, slash)
  73. #undef DEF_NAMED_BINARY_OP
  74. #define DEF_BINARY_OP(OP) \
  75. template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0> \
  76. inline auto operator OP(A && a, B && b) \
  77. { \
  78. return map([](auto && a, auto && b) { return a OP b; }, \
  79. std::forward<A>(a), std::forward<B>(b)); \
  80. } \
  81. template <class A, class B, std::enable_if_t<ra_zero<A, B>, int> =0> \
  82. inline auto operator OP(A && a, B && b) \
  83. { \
  84. return START(a) OP START(b); \
  85. }
  86. FOR_EACH(DEF_BINARY_OP, &&, ||, >, <, >=, <=, ==, !=)
  87. #undef DEF_BINARY_OP
  88. #define DEF_UNARY_OP(OP) \
  89. template <class A, std::enable_if_t<is_ra_pos_rank<A>, int> = 0> \
  90. inline auto operator OP(A && a) \
  91. { \
  92. return map([](auto && a) { return OP a; }, std::forward<A>(a)); \
  93. }
  94. FOR_EACH(DEF_UNARY_OP, !, +, -) // TODO Make + into nop.
  95. #undef DEF_UNARY_OP
  96. // When OP(a) isn't found from ra::, the deduction from rank(0) -> scalar doesn't work.
  97. // TODO Cf [ref:examples/useret.C:0].
  98. #define DEF_NAME_OP(OP) \
  99. using ::OP; \
  100. template <class ... A, std::enable_if_t<ra_pos_and_any<A ...>, int> =0> \
  101. inline auto OP(A && ... a) \
  102. { \
  103. return map([](auto && ... a) { return OP(a ...); }, std::forward<A>(a) ...); \
  104. } \
  105. template <class ... A, std::enable_if_t<ra_zero<A ...>, int> =0> \
  106. inline auto OP(A && ... a) \
  107. { \
  108. return OP(START(a) ...); \
  109. }
  110. FOR_EACH(DEF_NAME_OP, rel_error, pow, xI, conj, sqr, sqrm, sqrt, cos, sin)
  111. FOR_EACH(DEF_NAME_OP, exp, expm1, log, log1p, log10, isfinite, isnan, isinf)
  112. FOR_EACH(DEF_NAME_OP, max, min, abs, odd, asin, acos, atan, atan2)
  113. FOR_EACH(DEF_NAME_OP, cosh, sinh, tanh, clamp)
  114. #undef DEF_NAME_OP
  115. #define DEF_NAME_OP(OP) \
  116. using ::OP; \
  117. template <class ... A, std::enable_if_t<ra_pos_and_any<A ...>, int> =0> \
  118. inline auto OP(A && ... a) \
  119. { \
  120. return map([](auto && ... a) -> decltype(auto) { return OP(a ...); }, std::forward<A>(a) ...); \
  121. } \
  122. template <class ... A, std::enable_if_t<ra_zero<A ...>, int> =0> \
  123. inline decltype(auto) OP(A && ... a) \
  124. { \
  125. return OP(START(a) ...); \
  126. }
  127. FOR_EACH(DEF_NAME_OP, real_part, imag_part)
  128. #undef DEF_NAME_OP
  // cast<T>(a): lazy elementwise conversion; each element x of a becomes T(x).
  template <class T, class A>
  inline auto cast(A && a)
  {
      return map([](auto && x) { return T(x); },
                 std::forward<A>(a));
  }
  // pack<T>(a ...): lazily builds one T per element position, brace-initialized from
  // the corresponding elements of every argument: T { x ... }.
  // TODO could be useful to deduce T as tuple of value_types (&).
  template <class T, class ... A>
  inline auto pack(A && ... a)
  {
      return map([](auto && ... x) { return T { x ... }; },
                 std::forward<A>(a) ...);
  }
  // at(a, i): lazy expression whose elements are a.at(j) for each element j of i,
  // returned as given by a.at (decltype(auto)).
  // NOTE: a is captured by reference, so the result must not outlive a.
  // FIXME Inelegant story wrt plain array / nested array :-/
  template <class A, class I>
  inline auto at(A && a, I && i)
  {
      // Forward i, as the other map() wrappers in this file do. An rvalue index
      // expression is then moved into the result rather than referenced after the
      // caller's temporary has been destroyed.
      return map([&a](auto && j) -> decltype(auto) { return a.at(j); },
                 std::forward<I>(i));
  }
  146. template <class W, class T, class F, std::enable_if_t<is_ra_pos_rank<W> || is_ra_pos_rank<T> || is_ra_pos_rank<F>, int> = 0>
  147. inline auto where(W && w, T && t, F && f)
  148. {
  149. return pick(cast<bool>(start(std::forward<W>(w))), start(std::forward<F>(f)), start(std::forward<T>(t)));
  150. }
  151. // --------------------------------
  152. // Some whole-array reductions.
  153. // TODO First rank reductions? Variable rank reductions?
  154. // --------------------------------
  // any(a): whether some element of a converts to true. Each element x is mapped to
  // the pair (x, x) for early(), which presumably stops at the first pair whose first
  // member is true and returns its second; false is the fallback result.
  template <class A>
  inline bool any(A && a)
  {
      return early(map([](bool x) { return std::make_tuple(x, x); },
                       std::forward<A>(a)),
                   false);
  }
  // every(a): whether all elements of a convert to true. Each element x is mapped to
  // the pair (!x, x) for early(), so presumably the scan stops at the first false
  // element and yields false; true is the fallback result.
  template <class A>
  inline bool every(A && a)
  {
      return early(map([](bool x) { return std::make_tuple(!x, x); },
                       std::forward<A>(a)),
                   true);
  }
  165. // FIXME variable rank? see J 'index of' (x i. y), etc.
  166. template <class A>
  167. inline auto index(A && a)
  168. {
  169. return early(map([](auto && a, auto && i) { return std::make_tuple(bool(a), i); },
  170. std::forward<A>(a), ra::iota(start(a).size(0))),
  171. ra::dim_t(-1));
  172. }
  // [ma108]
  // lexicographical_compare(a, b): whether a orders before b. Pairs of equal elements
  // are skipped ((false, true) for early()); the first unequal pair stops the scan
  // with the result x < y. If every pair is equal, the fallback false is returned.
  template <class A, class B>
  inline bool lexicographical_compare(A && a, B && b)
  {
      return early(map([](auto && x, auto && y)
                       { return x==y ? std::make_tuple(false, true) : std::make_tuple(true, x<y); },
                       a, b),
                   false);
  }
  182. using std::min;
  183. template <class A>
  184. inline auto amin(A && a)
  185. {
  186. using T = value_t<A>;
  187. T c = std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::max();
  188. for_each([&c](auto && a) { c = min(c, a); }, a);
  189. return c;
  190. }
  191. using std::max;
  192. template <class A>
  193. inline auto amax(A && a)
  194. {
  195. using T = value_t<A>;
  196. T c = std::numeric_limits<T>::has_infinity ? -std::numeric_limits<T>::infinity() : std::numeric_limits<T>::lowest();
  197. for_each([&c](auto && a) { c = max(c, a); }, a);
  198. return c;
  199. }
  200. template <class A>
  201. inline auto sum(A && a)
  202. {
  203. value_t<A> c {};
  204. for_each([&c](auto && a) { c += a; }, a);
  205. return c;
  206. }
  207. template <class A>
  208. inline auto prod(A && a)
  209. {
  210. value_t<A> c(1.);
  211. for_each([&c](auto && a) { c *= a; }, a);
  212. return c;
  213. }
  // reduce_sqrm(a): sum over the elements of a of sqrm(element).
  template <class A>
  inline auto reduce_sqrm(A && a)
  {
      return sum(sqrm(a));
  }
  // norm2(a): square root of reduce_sqrm(a) (presumably the Euclidean norm).
  template <class A>
  inline auto norm2(A && a)
  {
      return sqrt(reduce_sqrm(a));
  }
  // dot(a, b): sum of elementwise products, accumulated with fma(x, y, acc). The
  // accumulator has the decayed type of an element product and starts from 0.
  template <class A, class B>
  inline auto dot(A && a, B && b)
  {
      using R = std::decay_t<decltype(START(a) * START(b))>;
      R acc(0.);
      for_each([&acc](auto && x, auto && y) { acc = fma(x, y, acc); }, a, b);
      return acc;
  }
  // cdot(a, b): like dot, but accumulated with fma_conj (presumably conj(x)*y + acc).
  // The accumulator has the decayed type of conj(element of a) * element of b and
  // starts from 0.
  template <class A, class B>
  inline auto cdot(A && a, B && b)
  {
      using R = std::decay_t<decltype(conj(START(a)) * START(b))>;
      R acc(0.);
      for_each([&acc](auto && x, auto && y) { acc = fma_conj(x, y, acc); }, a, b);
      return acc;
  }
  230. // --------------------
  231. // Wedge product
  232. // TODO Handle the simplifications dot_plus, yields_scalar, etc. just as vec::wedge does.
  233. // --------------------
  234. template <class A>
  235. struct torank1
  236. {
  237. using type = mp::If_<is_scalar<A>, Small<std::decay_t<A>, 1>, A>;
  238. };
  239. template <class Wedge, class Va, class Vb, class Enable=void>
  240. struct fromrank1
  241. {
  242. using valtype = typename Wedge::template valtype<Va, Vb>;
  243. using type = mp::If_<Wedge::Nr==1, valtype, Small<valtype, Wedge::Nr> >;
  244. };
  245. #define DECL_WEDGE(condition) \
  246. template <int D, int Oa, int Ob, class Va, class Vb, \
  247. std::enable_if_t<!(is_scalar<Va> && is_scalar<Vb>), int> =0> \
  248. decltype(auto) wedge(Va const & a, Vb const & b)
  249. DECL_WEDGE(general_case)
  250. {
  251. Small<std::decay_t<decltype(START(a))>, decltype(start(a))::size_s()> aa = a;
  252. Small<std::decay_t<decltype(START(b))>, decltype(start(b))::size_s()> bb = b;
  253. using Ua = decltype(aa);
  254. using Ub = decltype(bb);
  255. typename fromrank1<fun::Wedge<D, Oa, Ob>, Ua, Ub>::type r;
  256. auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r);
  257. auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa);
  258. auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
  259. fun::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  260. return r;
  261. }
  262. #undef DECL_WEDGE
  263. #define DECL_WEDGE(condition) \
  264. template <int D, int Oa, int Ob, class Va, class Vb, class Vr, \
  265. std::enable_if_t<!(is_scalar<Va> && is_scalar<Vb>), int> =0> \
  266. void wedge(Va const & a, Vb const & b, Vr & r)
  267. DECL_WEDGE(general_case)
  268. {
  269. Small<std::decay_t<decltype(START(a))>, decltype(start(a))::size_s()> aa = a;
  270. Small<std::decay_t<decltype(START(b))>, decltype(start(b))::size_s()> bb = b;
  271. using Ua = decltype(aa);
  272. using Ub = decltype(bb);
  273. auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r);
  274. auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa);
  275. auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
  276. fun::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  277. }
  278. #undef DECL_WEDGE
  279. template <class A, class B, std::enable_if_t<std::decay_t<A>::size_s()==2 && std::decay_t<B>::size_s()==2, int> =0>
  280. inline auto cross(A const & a_, B const & b_)
  281. {
  282. Small<std::decay_t<decltype(START(a_))>, 2> a = a_;
  283. Small<std::decay_t<decltype(START(b_))>, 2> b = b_;
  284. Small<std::decay_t<decltype(START(a_) * START(b_))>, 1> r;
  285. fun::Wedge<2, 1, 1>::product(a, b, r);
  286. return r[0];
  287. }
  288. template <class A, class B, std::enable_if_t<std::decay_t<A>::size_s()==3 && std::decay_t<B>::size_s()==3, int> =0>
  289. inline auto cross(A const & a_, B const & b_)
  290. {
  291. Small<std::decay_t<decltype(START(a_))>, 3> a = a_;
  292. Small<std::decay_t<decltype(START(b_))>, 3> b = b_;
  293. Small<std::decay_t<decltype(START(a_) * START(b_))>, 3> r;
  294. fun::Wedge<3, 1, 1>::product(a, b, r);
  295. return r;
  296. }
  297. template <class V>
  298. inline auto perp(V const & v)
  299. {
  300. static_assert(v.size()==2, "dimension error");
  301. return Small<std::decay_t<decltype(START(v))>, 2> {v[1], -v[0]};
  302. }
  303. template <class V, class U, std::enable_if_t<is_scalar<U>, int> =0>
  304. inline auto perp(V const & v, U const & n)
  305. {
  306. static_assert(v.size()==2, "dimension error");
  307. return Small<std::decay_t<decltype(START(v) * n)>, 2> {v[1]*n, -v[0]*n};
  308. }
  309. template <class V, class U, std::enable_if_t<!is_scalar<U>, int> =0>
  310. inline auto perp(V const & v, U const & n)
  311. {
  312. static_assert(v.size()==3, "dimension error");
  313. return cross(v, n);
  314. }
  315. // --------------------
  316. // Other whole-array ops.
  317. // --------------------
  318. template <class A, std::enable_if_t<is_slice<A>, int> =0>
  319. inline auto normv(A const & a)
  320. {
  321. return concrete(a/norm2(a));
  322. }
  323. template <class A, std::enable_if_t<!is_slice<A> && is_ra<A>, int> =0>
  324. inline auto normv(A const & a)
  325. {
  326. auto b = concrete(a);
  327. b /= norm2(b);
  328. return b;
  329. }
  330. // FIXME benchmark w/o allocation and do Small/Big versions if it's worth it.
  331. template <class A, class B, class C>
  332. inline void
  333. gemm(A const & a, B const & b, C & c)
  334. {
  335. for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { c += a*b; })), c, a, b);
  336. }
  337. // FIXME branch gemm on Ryn::size_s(), but that's bugged.
  338. #define MMTYPE decltype(from(times(), a(ra::all, 0), b(0, ra::all)))
  339. // default for row-major x row-major. See bench-gemm.C for variants.
  340. template <class S, class T>
  341. inline auto
  342. gemm(ra::View<S, 2> const & a, ra::View<T, 2> const & b)
  343. {
  344. int const M = a.size(0);
  345. int const N = b.size(1);
  346. int const K = a.size(1);
  347. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  348. auto c = with_shape<MMTYPE>({M, N}, ra::unspecified);
  349. for (int k=0; k<K; ++k) {
  350. c += from(times(), a(ra::all, k), b(k, ra::all));
  351. }
  352. return c;
  353. }
  354. // we still want the Small version to be different.
  355. template <class A, class B>
  356. inline ra::Small<std::decay_t<decltype(START(std::declval<A>()) * START(std::declval<B>()))>, A::size(0), B::size(1)>
  357. gemm(A const & a, B const & b)
  358. {
  359. constexpr int M = a.size(0);
  360. constexpr int N = b.size(1);
  361. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  362. auto c = with_shape<MMTYPE>({M, N}, ra::unspecified);
  363. for (int i=0; i<M; ++i) {
  364. for (int j=0; j<N; ++j) {
  365. c(i, j) = dot(a(i), b(ra::all, j));
  366. }
  367. }
  368. return c;
  369. }
  370. #undef MMTYPE
  371. template <class A, class B>
  372. inline auto
  373. gevm(A const & a, B const & b)
  374. {
  375. int const M = b.size(0);
  376. int const N = b.size(1);
  377. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  378. auto c = with_shape<decltype(a[0]*b(0, ra::all))>({N}, 0);
  379. for (int i=0; i<M; ++i) {
  380. c += a[i]*b(i);
  381. }
  382. return c;
  383. }
  384. template <class A, class B>
  385. inline auto
  386. gemv(A const & a, B const & b)
  387. {
  388. int const M = a.size(0);
  389. int const N = a.size(1);
  390. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  391. auto c = with_shape<decltype(a(ra::all, 0)*b[0])>({M}, 0);
  392. for (int j=0; j<N; ++j) {
  393. c += a(ra::all, j) * b[j];
  394. }
  395. return c;
  396. }
  397. #undef START
  398. // TODO adjust type for var rank/var size/fixed size. Check temporary.H.
  399. template <class X> auto make_array(X && x) { return ra::Big<ra::value_t<X>, decltype(start(x))::rank_s()>(x); }
  400. } // namespace ra
  // Drop this header's bounds-check configuration macros so they don't leak to includers.
  // NOTE(review): OPTIMIZE is defined near the top of this file but never #undef'd;
  // confirm whether other headers rely on it before removing it too.
  #undef RA_CHECK_BOUNDS_RA_OPERATORS
  #undef CHECK_BOUNDS