operators.H 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file operators.H
  3. /// @brief Sugar for ra:: expression templates.
  4. // (c) Daniel Llorens - 2014-2017
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #pragma once
  10. // FIXME Dependence on specific ra:: types should maybe be elsewhere.
  11. #include "ra/global.H"
  12. #include "ra/complex.H"
  13. #include "ra/wrank.H"
  14. #include "ra/pick.H"
  15. #include "ra/view-ops.H"
  16. #if defined(RA_CHECK_BOUNDS) && RA_CHECK_BOUNDS==0
  17. #define CHECK_BOUNDS( cond )
  18. #else
  19. #define CHECK_BOUNDS( cond ) RA_ASSERT( cond, 0 )
  20. #endif
  21. #include "ra/optimize.H"
  22. #ifndef RA_OPTIMIZE
  23. #define RA_OPTIMIZE 1 // enabled by default
  24. #endif
  25. #if RA_OPTIMIZE==1
  26. #define OPTIMIZE optimize
  27. #else
  28. #define OPTIMIZE
  29. #endif
  30. namespace ra {
  31. template <int ... Iarg, class A> inline constexpr
  32. decltype(auto) transpose(mp::int_list<Iarg ...>, A && a)
  33. {
  34. return transpose<Iarg ...>(std::forward<A>(a));
  35. }
  36. namespace {
  37. template <class A> inline constexpr
  38. decltype(auto) FLAT(A && a)
  39. {
  40. return *(ra::start(std::forward<A>(a)).flat());
  41. }
  42. } // namespace
// I considered three options for operator lookup:
// 1. Define the operators in a class that the ArrayIterator, Container and Slice types derive from.
//    I did this for an older library (vector-ops.H). It gives the narrowest scope, but since those
//    types are themselves used in the definitions (ra::Expr is an ArrayIterator), it requires lots
//    of forwarding and traits::.
// 2. Plain ADL doesn't work because some ra:: types use !, !=, etc. for other purposes (e.g. Flat).
//    Possible solution: never use +, != or == on Flat.
// 3. ADL constrained with enable_if, which is what this file does.
  47. // --------------------------------
  48. // Array versions of operators and functions
  49. // --------------------------------
// We need the rank-zero/scalar overloads because the scalar/scalar operators
// may be templated (e.g. complex<>). Templated operators won't be found when an implicit
// conversion from a rank-zero array to a scalar is also needed. That is, without those
// overloads, ra::View<complex, 0> * complex would fail to compile.
// These depend on the OPNAME functors defined in optimize.H, which uses them to match
// expression-template patterns.
  55. #define DEF_NAMED_BINARY_OP(OP, OPNAME) \
  56. template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0> \
  57. inline constexpr auto operator OP(A && a, B && b) \
  58. { \
  59. return OPTIMIZE(map(OPNAME(), std::forward<A>(a), std::forward<B>(b))); \
  60. } \
  61. template <class A, class B, std::enable_if_t<ra_zero<A, B>, int> =0> \
  62. inline constexpr auto operator OP(A && a, B && b) \
  63. { \
  64. return FLAT(a) OP FLAT(b); \
  65. }
  66. DEF_NAMED_BINARY_OP(+, plus)
  67. DEF_NAMED_BINARY_OP(-, minus)
  68. DEF_NAMED_BINARY_OP(*, times)
  69. DEF_NAMED_BINARY_OP(/, slash)
  70. #undef DEF_NAMED_BINARY_OP
  71. #define DEF_BINARY_OP(OP) \
  72. template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0> \
  73. inline auto operator OP(A && a, B && b) \
  74. { \
  75. return map([](auto && a, auto && b) { return a OP b; }, \
  76. std::forward<A>(a), std::forward<B>(b)); \
  77. } \
  78. template <class A, class B, std::enable_if_t<ra_zero<A, B>, int> =0> \
  79. inline auto operator OP(A && a, B && b) \
  80. { \
  81. return FLAT(a) OP FLAT(b); \
  82. }
  83. FOR_EACH(DEF_BINARY_OP, >, <, >=, <=, ==, !=, |, &, ^)
  84. #undef DEF_BINARY_OP
  85. #define DEF_UNARY_OP(OP) \
  86. template <class A, std::enable_if_t<is_ra_pos_rank<A>, int> =0> \
  87. inline auto operator OP(A && a) \
  88. { \
  89. return map([](auto && a) { return OP a; }, std::forward<A>(a)); \
  90. }
  91. FOR_EACH(DEF_UNARY_OP, !, +, -) // TODO Make + into nop.
  92. #undef DEF_UNARY_OP
  93. // When OP(a) isn't found from ra::, the deduction from rank(0) -> scalar doesn't work.
  94. // TODO Cf [ref:examples/useret.C:0].
  95. #define DEF_NAME_OP(OP) \
  96. using ::OP; \
  97. template <class ... A, std::enable_if_t<ra_pos_and_any<A ...>, int> =0> \
  98. inline auto OP(A && ... a) \
  99. { \
  100. return map([](auto && ... a) { return OP(a ...); }, std::forward<A>(a) ...); \
  101. } \
  102. template <class ... A, std::enable_if_t<ra_zero<A ...>, int> =0> \
  103. inline auto OP(A && ... a) \
  104. { \
  105. return OP(FLAT(a) ...); \
  106. }
  107. FOR_EACH(DEF_NAME_OP, rel_error, pow, xI, conj, sqr, sqrm, sqrt, cos, sin)
  108. FOR_EACH(DEF_NAME_OP, exp, expm1, log, log1p, log10, isfinite, isnan, isinf)
  109. FOR_EACH(DEF_NAME_OP, max, min, abs, odd, asin, acos, atan, atan2, clamp)
  110. FOR_EACH(DEF_NAME_OP, cosh, sinh, tanh, arg)
  111. #undef DEF_NAME_OP
  112. #define DEF_NAME_OP(OP) \
  113. using ::OP; \
  114. template <class ... A, std::enable_if_t<ra_pos_and_any<A ...>, int> =0> \
  115. inline auto OP(A && ... a) \
  116. { \
  117. return map([](auto && ... a) -> decltype(auto) { return OP(a ...); }, std::forward<A>(a) ...); \
  118. } \
  119. template <class ... A, std::enable_if_t<ra_zero<A ...>, int> =0> \
  120. inline decltype(auto) OP(A && ... a) \
  121. { \
  122. return OP(FLAT(a) ...); \
  123. }
  124. FOR_EACH(DEF_NAME_OP, real_part, imag_part)
  125. #undef DEF_NAME_OP
  126. template <class T, class A>
  127. inline auto cast(A && a)
  128. {
  129. return map([](auto && a) { return T(a); }, std::forward<A>(a));
  130. }
  131. // TODO could be useful to deduce T as tuple of value_types (&).
  132. template <class T, class ... A>
  133. inline auto pack(A && ... a)
  134. {
  135. return map([](auto && ... a) { return T { a ... }; }, std::forward<A>(a) ...);
  136. }
  137. // FIXME Inelegant story wrt plain array / nested array :-/
  138. template <class A, class I>
  139. inline auto at(A && a, I && i)
  140. {
  141. return map([&a](auto && i) -> decltype(auto) { return a.at(i); }, i);
  142. }
  143. template <class W, class T, class F,
  144. std::enable_if_t<is_ra_pos_rank<W> || is_ra_pos_rank<T> || is_ra_pos_rank<F>, int> =0>
  145. inline auto where(W && w, T && t, F && f)
  146. {
  147. return pick(cast<bool>(start(std::forward<W>(w))), start(std::forward<F>(f)), start(std::forward<T>(t)));
  148. }
  149. template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0>
  150. inline auto operator &&(A && a, B && b)
  151. {
  152. return where(std::forward<A>(a), cast<bool>(std::forward<B>(b)), false);
  153. }
  154. template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0>
  155. inline auto operator ||(A && a, B && b)
  156. {
  157. return where(std::forward<A>(a), true, cast<bool>(std::forward<B>(b)));
  158. }
  159. #define DEF_SHORTCIRCUIT_BINARY_OP(OP) \
  160. template <class A, class B, std::enable_if_t<ra_zero<A, B>, int> =0> \
  161. inline auto operator OP(A && a, B && b) \
  162. { \
  163. return FLAT(a) OP FLAT(b); \
  164. }
  165. FOR_EACH(DEF_SHORTCIRCUIT_BINARY_OP, &&, ||);
  166. #undef DEF_SHORTCIRCUIT_BINARY_OP
  167. // --------------------------------
  168. // Some whole-array reductions.
  169. // TODO First rank reductions? Variable rank reductions?
  170. // --------------------------------
  171. template <class A> inline bool
  172. any(A && a)
  173. {
  174. return early(map([](bool x) { return std::make_tuple(x, x); }, std::forward<A>(a)), false);
  175. }
  176. template <class A> inline bool
  177. every(A && a)
  178. {
  179. return early(map([](bool x) { return std::make_tuple(!x, x); }, std::forward<A>(a)), true);
  180. }
  181. // FIXME variable rank? see J 'index of' (x i. y), etc.
  182. template <class A>
  183. inline auto index(A && a)
  184. {
  185. return early(map([](auto && a, auto && i) { return std::make_tuple(bool(a), i); },
  186. std::forward<A>(a), ra::iota(start(a).size(0))),
  187. ra::dim_t(-1));
  188. }
  189. // [ma108]
  190. template <class A, class B>
  191. inline bool lexicographical_compare(A && a, B && b)
  192. {
  193. return early(map([](auto && a, auto && b)
  194. { return a==b ? std::make_tuple(false, true) : std::make_tuple(true, a<b); },
  195. a, b),
  196. false);
  197. }
  198. // FIXME only works with numeric types.
  199. using std::min;
  200. template <class A>
  201. inline auto amin(A && a)
  202. {
  203. using T = value_t<A>;
  204. T c = std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::max();
  205. for_each([&c](auto && a) { if (a<c) { c = a; } }, a);
  206. return c;
  207. }
  208. using std::max;
  209. template <class A>
  210. inline auto amax(A && a)
  211. {
  212. using T = value_t<A>;
  213. T c = std::numeric_limits<T>::has_infinity ? -std::numeric_limits<T>::infinity() : std::numeric_limits<T>::lowest();
  214. for_each([&c](auto && a) { if (c<a) { c = a; } }, a);
  215. return c;
  216. }
  217. // FIXME encapsulate this kind of reference-reduction.
  218. // FIXME expr/ply mechanism doesn't allow partial iteration (adv then continue).
  219. template <class A, class Less = std::less<value_t<A>>>
  220. inline decltype(auto) refmin(A && a, Less && less = std::less<value_t<A>>())
  221. {
  222. CHECK_BOUNDS(a.size()>0);
  223. decltype(auto) s = ra::start(a);
  224. auto p = &(*s.flat());
  225. for_each([&less, &p](auto & a) { if (less(a, *p)) { p = &a; } }, s);
  226. return *p;
  227. }
  228. template <class A, class Less = std::less<value_t<A>>>
  229. inline decltype(auto) refmax(A && a, Less && less = std::less<value_t<A>>())
  230. {
  231. CHECK_BOUNDS(a.size()>0);
  232. decltype(auto) s = ra::start(a);
  233. auto p = &(*s.flat());
  234. for_each([&less, &p](auto & a) { if (less(*p, a)) { p = &a; } }, s);
  235. return *p;
  236. }
  237. template <class A>
  238. inline constexpr auto sum(A && a)
  239. {
  240. value_t<A> c {};
  241. for_each([&c](auto && a) { c += a; }, a);
  242. return c;
  243. }
  244. template <class A>
  245. inline constexpr auto prod(A && a)
  246. {
  247. value_t<A> c(1.);
  248. for_each([&c](auto && a) { c *= a; }, a);
  249. return c;
  250. }
  251. template <class A> inline auto reduce_sqrm(A && a) { return sum(sqrm(a)); }
  252. template <class A> inline auto norm2(A && a) { return std::sqrt(reduce_sqrm(a)); }
  253. template <class A, class B>
  254. inline auto dot(A && a, B && b)
  255. {
  256. std::decay_t<decltype(FLAT(a) * FLAT(b))> c(0.);
  257. for_each([&c](auto && a, auto && b) { c = fma(a, b, c); }, a, b);
  258. return c;
  259. }
  260. template <class A, class B>
  261. inline auto cdot(A && a, B && b)
  262. {
  263. std::decay_t<decltype(conj(FLAT(a)) * FLAT(b))> c(0.);
  264. for_each([&c](auto && a, auto && b) { c = fma_conj(a, b, c); }, a, b);
  265. return c;
  266. }
  267. // --------------------
  268. // Wedge product
  269. // TODO Handle the simplifications dot_plus, yields_scalar, etc. just as vec::wedge does.
  270. // --------------------
  271. template <class A>
  272. struct torank1
  273. {
  274. using type = std::conditional_t<is_scalar<A>, Small<std::decay_t<A>, 1>, A>;
  275. };
  276. template <class Wedge, class Va, class Vb, class Enable=void>
  277. struct fromrank1
  278. {
  279. using valtype = typename Wedge::template valtype<Va, Vb>;
  280. using type = std::conditional_t<Wedge::Nr==1, valtype, Small<valtype, Wedge::Nr>>;
  281. };
  282. #define DECL_WEDGE(condition) \
  283. template <int D, int Oa, int Ob, class Va, class Vb, \
  284. std::enable_if_t<!(is_scalar<Va> && is_scalar<Vb>), int> =0> \
  285. decltype(auto) wedge(Va const & a, Vb const & b)
  286. DECL_WEDGE(general_case)
  287. {
  288. Small<std::decay_t<decltype(FLAT(a))>, size_s(a)> aa = a;
  289. Small<std::decay_t<decltype(FLAT(b))>, size_s(b)> bb = b;
  290. using Ua = decltype(aa);
  291. using Ub = decltype(bb);
  292. typename fromrank1<fun::Wedge<D, Oa, Ob>, Ua, Ub>::type r;
  293. auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r);
  294. auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa);
  295. auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
  296. fun::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  297. return r;
  298. }
  299. #undef DECL_WEDGE
  300. #define DECL_WEDGE(condition) \
  301. template <int D, int Oa, int Ob, class Va, class Vb, class Vr, \
  302. std::enable_if_t<!(is_scalar<Va> && is_scalar<Vb>), int> =0> \
  303. void wedge(Va const & a, Vb const & b, Vr & r)
  304. DECL_WEDGE(general_case)
  305. {
  306. Small<std::decay_t<decltype(FLAT(a))>, size_s(a)> aa = a;
  307. Small<std::decay_t<decltype(FLAT(b))>, size_s(b)> bb = b;
  308. using Ua = decltype(aa);
  309. using Ub = decltype(bb);
  310. auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r);
  311. auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa);
  312. auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
  313. fun::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  314. }
  315. #undef DECL_WEDGE
  316. template <class A, class B, std::enable_if_t<size_s<A>()==2 && size_s<B>()==2, int> =0>
  317. inline auto cross(A const & a_, B const & b_)
  318. {
  319. Small<std::decay_t<decltype(FLAT(a_))>, 2> a = a_;
  320. Small<std::decay_t<decltype(FLAT(b_))>, 2> b = b_;
  321. Small<std::decay_t<decltype(FLAT(a_) * FLAT(b_))>, 1> r;
  322. fun::Wedge<2, 1, 1>::product(a, b, r);
  323. return r[0];
  324. }
  325. template <class A, class B, std::enable_if_t<size_s<A>()==3 && size_s<B>()==3, int> =0>
  326. inline auto cross(A const & a_, B const & b_)
  327. {
  328. Small<std::decay_t<decltype(FLAT(a_))>, 3> a = a_;
  329. Small<std::decay_t<decltype(FLAT(b_))>, 3> b = b_;
  330. Small<std::decay_t<decltype(FLAT(a_) * FLAT(b_))>, 3> r;
  331. fun::Wedge<3, 1, 1>::product(a, b, r);
  332. return r;
  333. }
  334. template <class V>
  335. inline auto perp(V const & v)
  336. {
  337. static_assert(v.size()==2, "dimension error");
  338. return Small<std::decay_t<decltype(FLAT(v))>, 2> {v[1], -v[0]};
  339. }
  340. template <class V, class U, std::enable_if_t<is_scalar<U>, int> =0>
  341. inline auto perp(V const & v, U const & n)
  342. {
  343. static_assert(v.size()==2, "dimension error");
  344. return Small<std::decay_t<decltype(FLAT(v) * n)>, 2> {v[1]*n, -v[0]*n};
  345. }
  346. template <class V, class U, std::enable_if_t<!is_scalar<U>, int> =0>
  347. inline auto perp(V const & v, U const & n)
  348. {
  349. static_assert(v.size()==3, "dimension error");
  350. return cross(v, n);
  351. }
  352. // --------------------
  353. // Other whole-array ops.
  354. // --------------------
  355. template <class A, std::enable_if_t<is_slice<A>, int> =0>
  356. inline auto normv(A const & a)
  357. {
  358. return concrete(a/norm2(a));
  359. }
  360. template <class A, std::enable_if_t<!is_slice<A> && is_ra<A>, int> =0>
  361. inline auto normv(A const & a)
  362. {
  363. auto b = concrete(a);
  364. b /= norm2(b);
  365. return b;
  366. }
  367. // FIXME benchmark w/o allocation and do Small/Big versions if it's worth it.
  368. template <class A, class B, class C>
  369. inline void
  370. gemm(A const & a, B const & b, C & c)
  371. {
  372. for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { c += a*b; })), c, a, b);
  373. }
  374. // FIXME branch gemm on Ryn::size_s(), but that's bugged.
  375. #define MMTYPE decltype(from(times(), a(ra::all, 0), b(0, ra::all)))
  376. // default for row-major x row-major. See bench-gemm.C for variants.
  377. template <class S, class T>
  378. inline auto
  379. gemm(ra::View<S, 2> const & a, ra::View<T, 2> const & b)
  380. {
  381. int const M = a.size(0);
  382. int const N = b.size(1);
  383. int const K = a.size(1);
  384. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  385. auto c = with_shape<MMTYPE>({M, N}, decltype(a(0, 0)*b(0, 0))());
  386. for (int k=0; k<K; ++k) {
  387. c += from(times(), a(ra::all, k), b(k, ra::all));
  388. }
  389. return c;
  390. }
  391. // we still want the Small version to be different.
  392. template <class A, class B>
  393. inline ra::Small<std::decay_t<decltype(FLAT(std::declval<A>()) * FLAT(std::declval<B>()))>, A::size(0), B::size(1)>
  394. gemm(A const & a, B const & b)
  395. {
  396. constexpr int M = a.size(0);
  397. constexpr int N = b.size(1);
  398. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  399. auto c = with_shape<MMTYPE>({M, N}, ra::none);
  400. for (int i=0; i<M; ++i) {
  401. for (int j=0; j<N; ++j) {
  402. c(i, j) = dot(a(i), b(ra::all, j));
  403. }
  404. }
  405. return c;
  406. }
  407. #undef MMTYPE
  408. template <class A, class B>
  409. inline auto
  410. gevm(A const & a, B const & b)
  411. {
  412. int const M = b.size(0);
  413. int const N = b.size(1);
  414. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  415. auto c = with_shape<decltype(a[0]*b(0, ra::all))>({N}, 0);
  416. for (int i=0; i<M; ++i) {
  417. c += a[i]*b(i);
  418. }
  419. return c;
  420. }
  421. // FIXME a must be a view, so it doesn't work with e.g. gemv(conj(a), b).
  422. template <class A, class B>
  423. inline auto
  424. gemv(A const & a, B const & b)
  425. {
  426. int const M = a.size(0);
  427. int const N = a.size(1);
  428. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  429. auto c = with_shape<decltype(a(ra::all, 0)*b[0])>({M}, 0);
  430. for (int j=0; j<N; ++j) {
  431. c += a(ra::all, j) * b[j];
  432. }
  433. return c;
  434. }
  435. } // namespace ra
  436. #undef OPTIMIZE
  437. #undef CHECK_BOUNDS