operators.hh 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file operators.hh
  3. /// @brief Sugar for ra:: expression templates.
  4. // (c) Daniel Llorens - 2014-2019
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #pragma once
  10. // FIXME Dependence on specific ra:: types should maybe be elsewhere.
  11. #include "ra/global.hh"
  12. #include "ra/complex.hh"
  13. #include "ra/wrank.hh"
  14. #include "ra/pick.hh"
  15. #include "ra/view-ops.hh"
  16. #include "ra/optimize.hh"
  17. #ifndef RA_DO_OPT
  18. #define RA_DO_OPT 1 // enabled by default
  19. #endif
  20. #if RA_DO_OPT==1
  21. #define RA_OPT optimize
  22. #else
  23. #define RA_OPT
  24. #endif
  25. namespace ra {
  26. // These ra::start are needed b/c rank 0 converts to and from scalar, so ? can't pick the right (-> scalar) conversion.
  27. template <class T, class F>
  28. requires (ra::is_zero_or_scalar<T> && ra::is_zero_or_scalar<F>)
  29. inline constexpr decltype(auto) where(bool const w, T && t, F && f)
  30. {
  31. return w ? *(ra::start(t).flat()) : *(ra::start(f).flat());
  32. }
  33. template <int D, int Oa, int Ob, class A, class B>
  34. requires (ra::is_scalar<A> && ra::is_scalar<B>)
  35. inline constexpr auto wedge(A const & a, B const & b) { return a*b; }
  36. template <int ... Iarg, class A> inline constexpr
  37. decltype(auto) transpose(mp::int_list<Iarg ...>, A && a)
  38. {
  39. return transpose<Iarg ...>(std::forward<A>(a));
  40. }
  41. namespace {
  42. template <class A> inline constexpr
  43. decltype(auto) FLAT(A && a)
  44. {
  45. return *(ra::start(std::forward<A>(a)).flat());
  46. }
  47. } // namespace
  48. // ---------------------------
  49. // from, after APL, like (from) in guile-ploy
  50. // TODO integrate with is_beatable shortcuts, operator() in the various array types.
  51. // ---------------------------
  52. template <class I>
  53. struct index_rank_
  54. {
  55. using type = mp::int_t<rank_s<I>()>;
  56. static_assert(type::value!=RANK_ANY, "dynamic rank unsupported");
  57. static_assert(size_s<I>()!=DIM_BAD, "undelimited extent subscript unsupported");
  58. };
  59. template <class I> using index_rank = typename index_rank_<I>::type;
  60. template <class II, int drop, class Op> inline constexpr
  61. decltype(auto) from_partial(Op && op)
  62. {
  63. if constexpr (drop==mp::len<II>) {
  64. return std::forward<Op>(op);
  65. } else {
  66. return wrank(mp::append<mp::makelist<drop, mp::int_t<0>>, mp::drop<II, drop>> {},
  67. from_partial<II, drop+1>(std::forward<Op>(op)));
  68. }
  69. }
  70. // TODO we should be able to do better by slicing at each dimension, etc. But verb<> only supports rank-0 for the innermost op.
  71. template <class A, class ... I> inline constexpr
  72. auto from(A && a, I && ... i)
  73. {
  74. if constexpr (0==sizeof...(i)) {
  75. return a();
  76. } else if constexpr (1==sizeof...(i)) {
  77. // support dynamic rank for 1 arg only (see test in test/from.cc).
  78. return expr(std::forward<A>(a), start(std::forward<I>(i) ...));
  79. } else {
  80. using II = mp::map<index_rank, mp::tuple<decltype(start(std::forward<I>(i))) ...>>;
  81. return expr(from_partial<II, 1>(std::forward<A>(a)), start(std::forward<I>(i)) ...);
  82. }
  83. }
  84. // I considered three options for lookup.
  85. // 1. define these in a class that RaIterator or Container or Slice types derive from. This was done for an old library I had (vector-ops.hh). It results in the smallest scope, but since those types are used in the definition (ra::Expr is an RaIterator), it requires lots of forwarding and traits:: .
  86. // 2. raw ADL doesn't work because some ra:: types use ! != etc for different things (e.g. Flat). Possible solution: don't ever use + != == for Flat.
  87. // 3. requires-constrained ADL is what you see here.
  88. // --------------------------------
  89. // Array versions of operators and functions
  90. // --------------------------------
  91. // We need the zero/scalar specializations because the scalar/scalar operators
// may be templated (e.g. complex<>), so they won't be found when an implicit
  93. // conversion from zero->scalar is also needed. That is, without those
  94. // specializations, ra::View<complex, 0> * complex will fail.
  95. // These depend on OPNAME defined in optimize.hh and used there to match ET patterns.
  96. #define DEF_NAMED_BINARY_OP(OP, OPNAME) \
  97. template <class A, class B> \
  98. requires (ra_pos_and_any<A, B>) \
  99. inline constexpr auto operator OP(A && a, B && b) \
  100. { \
  101. return RA_OPT(map(OPNAME(), std::forward<A>(a), std::forward<B>(b))); \
  102. } \
  103. template <class A, class B> \
  104. requires (ra_zero<A, B>) \
  105. inline constexpr auto operator OP(A && a, B && b) \
  106. { \
  107. return FLAT(a) OP FLAT(b); \
  108. }
  109. DEF_NAMED_BINARY_OP(+, plus)
  110. DEF_NAMED_BINARY_OP(-, minus)
  111. DEF_NAMED_BINARY_OP(*, times)
  112. DEF_NAMED_BINARY_OP(/, slash)
  113. #undef DEF_NAMED_BINARY_OP
  114. #define DEF_BINARY_OP(OP) \
  115. template <class A, class B> \
  116. requires (ra_pos_and_any<A, B>) \
  117. inline auto operator OP(A && a, B && b) \
  118. { \
  119. return map([](auto && a, auto && b) { return a OP b; }, \
  120. std::forward<A>(a), std::forward<B>(b)); \
  121. } \
  122. template <class A, class B> \
  123. requires (ra_zero<A, B>) \
  124. inline auto operator OP(A && a, B && b) \
  125. { \
  126. return FLAT(a) OP FLAT(b); \
  127. }
  128. FOR_EACH(DEF_BINARY_OP, >, <, >=, <=, <=>, ==, !=, |, &, ^)
  129. #undef DEF_BINARY_OP
  130. #define DEF_UNARY_OP(OP) \
  131. template <class A> \
  132. requires (ra_pos_and_any<A>) \
  133. inline auto operator OP(A && a) \
  134. { \
  135. return map([](auto && a) { return OP a; }, std::forward<A>(a)); \
  136. }
  137. FOR_EACH(DEF_UNARY_OP, !, +, -) // TODO Make + into nop.
  138. #undef DEF_UNARY_OP
  139. // When OP(a) isn't found from ra::, the deduction from rank(0) -> scalar doesn't work.
  140. // TODO Cf [ref:examples/useret.cc:0].
  141. #define DEF_NAME_OP(OP) \
  142. using ::OP; \
  143. template <class ... A> \
  144. requires (ra_pos_and_any<A ...>) \
  145. inline auto OP(A && ... a) \
  146. { \
  147. return map([](auto && ... a) { return OP(a ...); }, std::forward<A>(a) ...); \
  148. } \
  149. template <class ... A> \
  150. requires (ra_zero<A ...>) \
  151. inline auto OP(A && ... a) \
  152. { \
  153. return OP(FLAT(a) ...); \
  154. }
  155. FOR_EACH(DEF_NAME_OP, rel_error, pow, xI, conj, sqr, sqrm, sqrt, cos, sin)
  156. FOR_EACH(DEF_NAME_OP, exp, expm1, log, log1p, log10, isfinite, isnan, isinf)
  157. FOR_EACH(DEF_NAME_OP, max, min, abs, odd, asin, acos, atan, atan2, clamp)
  158. FOR_EACH(DEF_NAME_OP, cosh, sinh, tanh, arg)
  159. #undef DEF_NAME_OP
  160. #define DEF_NAME_OP(OP) \
  161. using ::OP; \
  162. template <class ... A> \
  163. requires (ra_pos_and_any<A ...>) \
  164. inline auto OP(A && ... a) \
  165. { \
  166. return map([](auto && ... a) -> decltype(auto) { return OP(a ...); }, std::forward<A>(a) ...); \
  167. } \
  168. template <class ... A> \
  169. requires (ra_zero<A ...>) \
  170. inline decltype(auto) OP(A && ... a) \
  171. { \
  172. return OP(FLAT(a) ...); \
  173. }
  174. FOR_EACH(DEF_NAME_OP, real_part, imag_part)
  175. #undef DEF_NAME_OP
  176. template <class T, class A>
  177. inline auto cast(A && a)
  178. {
  179. return map([](auto && a) { return T(a); }, std::forward<A>(a));
  180. }
  181. // TODO could be useful to deduce T as tuple of value_types (&).
  182. template <class T, class ... A>
  183. inline auto pack(A && ... a)
  184. {
  185. return map([](auto && ... a) { return T { a ... }; }, std::forward<A>(a) ...);
  186. }
  187. // FIXME Inelegant story wrt plain array / nested array :-/
  188. template <class A, class I>
  189. inline auto at(A && a, I && i)
  190. {
  191. return map([&a](auto && i) -> decltype(auto) { return a.at(i); }, i);
  192. }
  193. template <class W, class T, class F>
  194. requires (ra_pos_and_any<W, T, F>)
  195. inline auto where(W && w, T && t, F && f)
  196. {
  197. return pick(cast<bool>(start(std::forward<W>(w))), start(std::forward<F>(f)), start(std::forward<T>(t)));
  198. }
  199. template <class A, class B>
  200. requires (ra_pos_and_any<A, B>)
  201. inline auto operator &&(A && a, B && b)
  202. {
  203. return where(std::forward<A>(a), cast<bool>(std::forward<B>(b)), false);
  204. }
  205. template <class A, class B>
  206. requires (ra_pos_and_any<A, B>)
  207. inline auto operator ||(A && a, B && b)
  208. {
  209. return where(std::forward<A>(a), true, cast<bool>(std::forward<B>(b)));
  210. }
  211. #define DEF_SHORTCIRCUIT_BINARY_OP(OP) \
  212. template <class A, class B> \
  213. requires (ra_zero<A, B>) \
  214. inline auto operator OP(A && a, B && b) \
  215. { \
  216. return FLAT(a) OP FLAT(b); \
  217. }
  218. FOR_EACH(DEF_SHORTCIRCUIT_BINARY_OP, &&, ||);
  219. #undef DEF_SHORTCIRCUIT_BINARY_OP
  220. // --------------------------------
  221. // Some whole-array reductions.
  222. // TODO First rank reductions? Variable rank reductions?
  223. // --------------------------------
  224. template <class A> inline bool
  225. any(A && a)
  226. {
  227. return early(map([](bool x) { return std::make_tuple(x, x); }, std::forward<A>(a)), false);
  228. }
  229. template <class A> inline bool
  230. every(A && a)
  231. {
  232. return early(map([](bool x) { return std::make_tuple(!x, x); }, std::forward<A>(a)), true);
  233. }
  234. // FIXME variable rank? see J 'index of' (x i. y), etc.
  235. template <class A>
  236. inline auto index(A && a)
  237. {
  238. return early(map([](auto && a, auto && i) { return std::make_tuple(bool(a), i); },
  239. std::forward<A>(a), ra::iota(start(a).size(0))),
  240. ra::dim_t(-1));
  241. }
  242. // [ma108]
  243. template <class A, class B>
  244. inline bool lexicographical_compare(A && a, B && b)
  245. {
  246. return early(map([](auto && a, auto && b)
  247. { return a==b ? std::make_tuple(false, true) : std::make_tuple(true, a<b); },
  248. a, b),
  249. false);
  250. }
  251. // FIXME only works with numeric types.
  252. using std::min;
  253. template <class A>
  254. inline auto amin(A && a)
  255. {
  256. using T = value_t<A>;
  257. T c = std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::max();
  258. for_each([&c](auto && a) { if (a<c) { c = a; } }, a);
  259. return c;
  260. }
  261. using std::max;
  262. template <class A>
  263. inline auto amax(A && a)
  264. {
  265. using T = value_t<A>;
  266. T c = std::numeric_limits<T>::has_infinity ? -std::numeric_limits<T>::infinity() : std::numeric_limits<T>::lowest();
  267. for_each([&c](auto && a) { if (c<a) { c = a; } }, a);
  268. return c;
  269. }
  270. // FIXME encapsulate this kind of reference-reduction.
  271. // FIXME expr/ply mechanism doesn't allow partial iteration (adv then continue).
  272. template <class A, class Less = std::less<value_t<A>>>
  273. inline decltype(auto) refmin(A && a, Less && less = std::less<value_t<A>>())
  274. {
  275. RA_CHECK(a.size()>0);
  276. decltype(auto) s = ra::start(a);
  277. auto p = &(*s.flat());
  278. for_each([&less, &p](auto & a) { if (less(a, *p)) { p = &a; } }, s);
  279. return *p;
  280. }
  281. template <class A, class Less = std::less<value_t<A>>>
  282. inline decltype(auto) refmax(A && a, Less && less = std::less<value_t<A>>())
  283. {
  284. RA_CHECK(a.size()>0);
  285. decltype(auto) s = ra::start(a);
  286. auto p = &(*s.flat());
  287. for_each([&less, &p](auto & a) { if (less(*p, a)) { p = &a; } }, s);
  288. return *p;
  289. }
  290. template <class A>
  291. inline constexpr auto sum(A && a)
  292. {
  293. value_t<A> c {};
  294. for_each([&c](auto && a) { c += a; }, a);
  295. return c;
  296. }
  297. template <class A>
  298. inline constexpr auto prod(A && a)
  299. {
  300. value_t<A> c(1.);
  301. for_each([&c](auto && a) { c *= a; }, a);
  302. return c;
  303. }
  304. template <class A> inline auto reduce_sqrm(A && a) { return sum(sqrm(a)); }
  305. template <class A> inline auto norm2(A && a) { return std::sqrt(reduce_sqrm(a)); }
  306. template <class A, class B>
  307. inline auto dot(A && a, B && b)
  308. {
  309. std::decay_t<decltype(FLAT(a) * FLAT(b))> c(0.);
  310. for_each([&c](auto && a, auto && b) { c = fma(a, b, c); }, a, b);
  311. return c;
  312. }
  313. template <class A, class B>
  314. inline auto cdot(A && a, B && b)
  315. {
  316. std::decay_t<decltype(conj(FLAT(a)) * FLAT(b))> c(0.);
  317. for_each([&c](auto && a, auto && b) { c = fma_conj(a, b, c); }, a, b);
  318. return c;
  319. }
  320. // --------------------
  321. // Wedge product
  322. // TODO Handle the simplifications dot_plus, yields_scalar, etc. just as vec::wedge does.
  323. // --------------------
  324. template <class A>
  325. struct torank1
  326. {
  327. using type = std::conditional_t<is_scalar<A>, Small<std::decay_t<A>, 1>, A>;
  328. };
  329. template <class Wedge, class Va, class Vb>
  330. struct fromrank1
  331. {
  332. using valtype = typename Wedge::template valtype<Va, Vb>;
  333. using type = std::conditional_t<Wedge::Nr==1, valtype, Small<valtype, Wedge::Nr>>;
  334. };
  335. #define DECL_WEDGE(condition) \
  336. template <int D, int Oa, int Ob, class Va, class Vb> \
  337. requires (!(is_scalar<Va> && is_scalar<Vb>)) \
  338. decltype(auto) wedge(Va const & a, Vb const & b)
  339. DECL_WEDGE(general_case)
  340. {
  341. Small<std::decay_t<value_t<Va>>, size_s<Va>()> aa = a;
  342. Small<std::decay_t<value_t<Vb>>, size_s<Vb>()> bb = b;
  343. using Ua = decltype(aa);
  344. using Ub = decltype(bb);
  345. typename fromrank1<fun::Wedge<D, Oa, Ob>, Ua, Ub>::type r;
  346. auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r);
  347. auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa);
  348. auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
  349. fun::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  350. return r;
  351. }
  352. #undef DECL_WEDGE
  353. #define DECL_WEDGE(condition) \
  354. template <int D, int Oa, int Ob, class Va, class Vb, class Vr> \
  355. requires (!(is_scalar<Va> && is_scalar<Vb>)) \
  356. void wedge(Va const & a, Vb const & b, Vr & r)
  357. DECL_WEDGE(general_case)
  358. {
  359. Small<std::decay_t<value_t<Va>>, size_s<Va>()> aa = a;
  360. Small<std::decay_t<value_t<Vb>>, size_s<Vb>()> bb = b;
  361. using Ua = decltype(aa);
  362. using Ub = decltype(bb);
  363. auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r);
  364. auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa);
  365. auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
  366. fun::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  367. }
  368. #undef DECL_WEDGE
  369. template <class A, class B>
  370. requires (size_s<A>()==2 && size_s<B>()==2)
  371. inline auto cross(A const & a_, B const & b_)
  372. {
  373. Small<std::decay_t<decltype(FLAT(a_))>, 2> a = a_;
  374. Small<std::decay_t<decltype(FLAT(b_))>, 2> b = b_;
  375. Small<std::decay_t<decltype(FLAT(a_) * FLAT(b_))>, 1> r;
  376. fun::Wedge<2, 1, 1>::product(a, b, r);
  377. return r[0];
  378. }
  379. template <class A, class B>
  380. requires (size_s<A>()==3 && size_s<B>()==3)
  381. inline auto cross(A const & a_, B const & b_)
  382. {
  383. Small<std::decay_t<decltype(FLAT(a_))>, 3> a = a_;
  384. Small<std::decay_t<decltype(FLAT(b_))>, 3> b = b_;
  385. Small<std::decay_t<decltype(FLAT(a_) * FLAT(b_))>, 3> r;
  386. fun::Wedge<3, 1, 1>::product(a, b, r);
  387. return r;
  388. }
  389. template <class V>
  390. inline auto perp(V const & v)
  391. {
  392. static_assert(v.size()==2, "dimension error");
  393. return Small<std::decay_t<decltype(FLAT(v))>, 2> {v[1], -v[0]};
  394. }
  395. template <class V, class U>
  396. inline auto perp(V const & v, U const & n)
  397. {
  398. if constexpr (is_scalar<U>) {
  399. static_assert(v.size()==2, "dimension error");
  400. return Small<std::decay_t<decltype(FLAT(v) * n)>, 2> {v[1]*n, -v[0]*n};
  401. } else {
  402. static_assert(v.size()==3, "dimension error");
  403. return cross(v, n);
  404. }
  405. }
  406. // --------------------
  407. // Other whole-array ops.
  408. // --------------------
  409. template <class A>
  410. requires (is_slice<A>)
  411. inline auto normv(A const & a)
  412. {
  413. return concrete(a/norm2(a));
  414. }
  415. template <class A>
  416. requires (!is_slice<A> && is_ra<A>)
  417. inline auto normv(A const & a)
  418. {
  419. auto b = concrete(a);
  420. b /= norm2(b);
  421. return b;
  422. }
  423. // FIXME benchmark w/o allocation and do Small/Big versions if it's worth it.
  424. template <class A, class B, class C>
  425. inline void
  426. gemm(A const & a, B const & b, C & c)
  427. {
  428. for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { c += a*b; })), c, a, b);
  429. }
  430. // FIXME branch gemm on Ryn::size_s(), but that's bugged.
  431. #define MMTYPE decltype(from(times(), a(ra::all, 0), b(0, ra::all)))
  432. // default for row-major x row-major. See bench-gemm.cc for variants.
  433. template <class S, class T>
  434. inline auto
  435. gemm(ra::View<S, 2> const & a, ra::View<T, 2> const & b)
  436. {
  437. int const M = a.size(0);
  438. int const N = b.size(1);
  439. int const K = a.size(1);
  440. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  441. auto c = with_shape<MMTYPE>({M, N}, decltype(a(0, 0)*b(0, 0))());
  442. for (int k=0; k<K; ++k) {
  443. c += from(times(), a(ra::all, k), b(k, ra::all));
  444. }
  445. return c;
  446. }
  447. // we still want the Small version to be different.
  448. template <class A, class B>
  449. inline ra::Small<std::decay_t<decltype(FLAT(std::declval<A>()) * FLAT(std::declval<B>()))>, A::size(0), B::size(1)>
  450. gemm(A const & a, B const & b)
  451. {
  452. constexpr int M = a.size(0);
  453. constexpr int N = b.size(1);
  454. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  455. auto c = with_shape<MMTYPE>({M, N}, ra::none);
  456. for (int i=0; i<M; ++i) {
  457. for (int j=0; j<N; ++j) {
  458. c(i, j) = dot(a(i), b(ra::all, j));
  459. }
  460. }
  461. return c;
  462. }
  463. #undef MMTYPE
  464. template <class A, class B>
  465. inline auto
  466. gevm(A const & a, B const & b)
  467. {
  468. int const M = b.size(0);
  469. int const N = b.size(1);
  470. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  471. auto c = with_shape<decltype(a[0]*b(0, ra::all))>({N}, 0);
  472. for (int i=0; i<M; ++i) {
  473. c += a[i]*b(i);
  474. }
  475. return c;
  476. }
  477. // FIXME a must be a view, so it doesn't work with e.g. gemv(conj(a), b).
  478. template <class A, class B>
  479. inline auto
  480. gemv(A const & a, B const & b)
  481. {
  482. int const M = a.size(0);
  483. int const N = a.size(1);
  484. // no with_same_shape b/c cannot index 0 for type if A/B are empty
  485. auto c = with_shape<decltype(a(ra::all, 0)*b[0])>({M}, 0);
  486. for (int j=0; j<N; ++j) {
  487. c += a(ra::all, j) * b[j];
  488. }
  489. return c;
  490. }
  491. } // namespace ra
  492. #undef RA_OPT