ra.hh 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Operator overloads for expression templates, and root header.
  3. // (c) Daniel Llorens - 2014-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include "big.hh"
  10. #include "optimize.hh"
  11. #include <cmath>
  12. #include <complex>
  13. #ifndef RA_DO_OPT
  14. #define RA_DO_OPT 1 // enabled by default
  15. #endif
  16. #if RA_DO_OPT==1
  17. #define RA_OPT optimize
  18. #else
  19. #define RA_OPT
  20. #endif
  21. // Enable ADL with explicit template args. See http://stackoverflow.com/questions/9838862.
  22. template <class A> constexpr void transpose(ra::noarg);
  23. template <int A> constexpr void iter(ra::noarg);
  24. // ---------------------------
  25. // scalar overloads
  26. // ---------------------------
  27. // abs() needs no qualifying for ra:: types (ADL), shouldn't need it on pods either. FIXME maybe let user decide.
  28. // std::max/min are special, see DEF_NAME in ra.hh.
  29. using std::max, std::min, std::abs, std::fma, std::sqrt, std::pow, std::exp, std::swap,
  30. std::isfinite, std::isinf, std::isnan, std::clamp, std::lerp, std::conj, std::expm1;
// Identity conj() for real floating point types, next to the std::conj brought in above.
// NOTE(review): the conj line ends in a line continuation, so the FOR_EACH line is part of
// the macro body instead of an expansion of it. As written, FOR_FLOAT is defined and then
// undefined without being expanded here. Confirm whether this is intended before changing it;
// namespace ra below defines its own conj for the same types and also has a using ::conj.
#define FOR_FLOAT(T) \
constexpr T conj(T x) { return x; } \
FOR_EACH(FOR_FLOAT, float, double)
#undef FOR_FLOAT
  35. #define FOR_FLOAT(R, C) \
  36. constexpr C \
  37. fma(C const & a, C const & b, C const & c) \
  38. { \
  39. return C(fma(a.real(), b.real(), fma(-a.imag(), b.imag(), c.real())), \
  40. fma(a.real(), b.imag(), fma(a.imag(), b.real(), c.imag()))); \
  41. } \
  42. constexpr bool isfinite(C z) { return isfinite(z.real()) && isfinite(z.imag()); } \
  43. constexpr bool isnan(C z) { return isnan(z.real()) || isnan(z.imag()); } \
  44. constexpr bool isinf(C z) { return (isinf(z.real()) || isinf(z.imag())) && !isnan(z); }
  45. FOR_FLOAT(float, std::complex<float>)
  46. FOR_FLOAT(double, std::complex<double>)
  47. #undef FOR_FLOAT
  48. namespace ra {
  49. // As an array op; special definitions for rank 0.
  50. template <class T> constexpr bool ra_is_real = std::numeric_limits<T>::is_integer || std::is_floating_point_v<T>;
  51. template <class T> requires (ra_is_real<T>) constexpr T amax(T const & x) { return x; }
  52. template <class T> requires (ra_is_real<T>) constexpr T amin(T const & x) { return x; }
  53. template <class T> requires (ra_is_real<T>) constexpr T sqr(T const & x) { return x*x; }
  54. #define FOR_FLOAT(T) \
  55. constexpr T arg(T x) { return T(0); } \
  56. constexpr T conj(T x) { return x; } \
  57. constexpr T mul_conj(T x, T y) { return x*y; } \
  58. constexpr T sqrm(T x) { return sqr(x); } \
  59. constexpr T sqrm(T x, T y) { return sqr(x-y); } \
  60. constexpr T dot(T x, T y) { return x*y; } \
  61. constexpr T fma_conj(T a, T b, T c) { return fma(a, b, c); } \
  62. constexpr T norm2(T x) { return std::abs(x); } \
  63. constexpr T norm2(T x, T y) { return std::abs(x-y); } \
  64. constexpr T rel_error(T a, T b) { auto den = (abs(a)+abs(b)); return den==0 ? 0. : 2.*norm2(a, b)/den; } \
  65. constexpr T const & real_part(T const & x) { return x; } \
  66. constexpr T & real_part(T & x) { return x; } \
  67. constexpr T imag_part(T x) { return T(0); }
  68. FOR_EACH(FOR_FLOAT, float, double)
  69. #undef FOR_FLOAT
  70. // FIXME few still inline should eventually be constexpr.
  71. #define FOR_FLOAT(R, C) \
  72. inline R arg(C x) { return std::arg(x); } \
  73. constexpr C sqr(C x) { return x*x; } \
  74. constexpr C dot(C x, C y) { return x*y; } \
  75. constexpr C xI(R x) { return C(0, x); } \
  76. constexpr C xI(C z) { return C(-z.imag(), z.real()); } \
  77. constexpr R real_part(C const & z) { return z.real(); } \
  78. constexpr R imag_part(C const & z) { return z.imag(); } \
  79. inline R & real_part(C & z) { return reinterpret_cast<R *>(&z)[0]; } \
  80. inline R & imag_part(C & z) { return reinterpret_cast<R *>(&z)[1]; } \
  81. constexpr R sqrm(C x) { return sqr(x.real())+sqr(x.imag()); } \
  82. constexpr R sqrm(C x, C y) { return sqr(x.real()-y.real())+sqr(x.imag()-y.imag()); } \
  83. constexpr R norm2(C x) { return hypot(x.real(), x.imag()); } \
  84. constexpr R norm2(C x, C y) { return sqrt(sqrm(x, y)); } \
  85. inline R rel_error(C a, C b) { auto den = (abs(a)+abs(b)); return den==0 ? 0. : 2.*norm2(a, b)/den; } \
  86. /* conj(a) * b + c */ \
  87. constexpr C \
  88. fma_conj(C const & a, C const & b, C const & c) \
  89. { \
  90. return C(fma(a.real(), b.real(), fma(a.imag(), b.imag(), c.real())), \
  91. fma(a.real(), b.imag(), fma(-a.imag(), b.real(), c.imag()))); \
  92. } \
  93. /* conj(a) * b */ \
  94. constexpr C \
  95. mul_conj(C const & a, C const & b) \
  96. { \
  97. return C(a.real()*b.real()+a.imag()*b.imag(), \
  98. a.real()*b.imag()-a.imag()*b.real()); \
  99. }
  100. FOR_FLOAT(float, std::complex<float>)
  101. FOR_FLOAT(double, std::complex<double>)
  102. #undef FOR_FLOAT
  103. template <class T> constexpr bool is_scalar_def<std::complex<T>> = true;
  104. template <int ... Iarg, class A>
  105. constexpr decltype(auto)
  106. transpose(mp::int_list<Iarg ...>, A && a)
  107. {
  108. return transpose<Iarg ...>(RA_FWD(a));
  109. }
  110. constexpr bool odd(unsigned int N) { return N & 1; }
  111. // ---------------------------
  112. // outer product
  113. // ---------------------------
  114. template <class II, int drop, class Op>
  115. constexpr decltype(auto)
  116. from_partial(Op && op)
  117. {
  118. if constexpr (drop==mp::len<II>) {
  119. return RA_FWD(op);
  120. } else {
  121. return wrank(mp::append<mp::makelist<drop, ic_t<0>>, mp::drop<II, drop>> {},
  122. from_partial<II, drop+1>(RA_FWD(op)));
  123. }
  124. }
  125. // TODO should be able to do better by slicing at each dimension, etc. But verb<>'s innermost op must be rank 0.
  126. template <class A, class ... I>
  127. constexpr decltype(auto)
  128. from(A && a, I && ... i)
  129. {
  130. if constexpr (0==sizeof...(i)) {
  131. return RA_FWD(a)();
  132. } else if constexpr (1==sizeof...(i)) {
  133. // support dynamic rank for 1 arg only (see test in test/from.cc).
  134. return map(RA_FWD(a), RA_FWD(i) ...);
  135. } else {
  136. return map(from_partial<mp::tuple<ic_t<rank_s<I>()> ...>, 1>(RA_FWD(a)), RA_FWD(i) ...);
  137. }
  138. }
  139. // --------------------------------
  140. // Array versions of operators and functions
  141. // --------------------------------
// We need zero/scalar specializations because the scalar/scalar operators may be templated (e.g. complex<>), so they won't be found when an implicit conversion to scalar is also needed, and e.g. ra::View<complex, 0> * complex would fail.
  143. // The function objects are matched in optimize.hh.
  144. #define DEF_NAMED_BINARY_OP(OP, OPNAME) \
  145. template <class A, class B> requires (tomap<A, B>) constexpr auto \
  146. operator OP(A && a, B && b) \
  147. { return RA_OPT(map(OPNAME(), RA_FWD(a), RA_FWD(b))); } \
  148. template <class A, class B> requires (toreduce<A, B>) constexpr auto \
  149. operator OP(A && a, B && b) \
  150. { return VALUE(RA_FWD(a)) OP VALUE(RA_FWD(b)); }
  151. DEF_NAMED_BINARY_OP(+, std::plus<>) DEF_NAMED_BINARY_OP(-, std::minus<>)
  152. DEF_NAMED_BINARY_OP(*, std::multiplies<>) DEF_NAMED_BINARY_OP(/, std::divides<>)
  153. DEF_NAMED_BINARY_OP(==, std::equal_to<>) DEF_NAMED_BINARY_OP(>, std::greater<>)
  154. DEF_NAMED_BINARY_OP(<, std::less<>) DEF_NAMED_BINARY_OP(>=, std::greater_equal<>)
  155. DEF_NAMED_BINARY_OP(<=, std::less_equal<>) DEF_NAMED_BINARY_OP(!=, std::not_equal_to<>)
  156. DEF_NAMED_BINARY_OP(|, std::bit_or<>) DEF_NAMED_BINARY_OP(&, std::bit_and<>)
  157. DEF_NAMED_BINARY_OP(^, std::bit_xor<>) DEF_NAMED_BINARY_OP(<=>, std::compare_three_way)
  158. #undef DEF_NAMED_BINARY_OP
  159. // FIXME address sanitizer complains in bench-optimize.cc if we use std::identity. Maybe false positive
  160. struct unaryplus
  161. {
  162. template <class T> constexpr /* static P1169 in gcc13 */ auto
  163. operator()(T && t) const noexcept { return RA_FWD(t); }
  164. };
  165. #define DEF_NAMED_UNARY_OP(OP, OPNAME) \
  166. template <class A> requires (tomap<A>) constexpr auto \
  167. operator OP(A && a) \
  168. { return map(OPNAME(), RA_FWD(a)); } \
  169. template <class A> requires (toreduce<A>) constexpr auto \
  170. operator OP(A && a) \
  171. { return OP VALUE(RA_FWD(a)); }
  172. DEF_NAMED_UNARY_OP(+, unaryplus)
  173. DEF_NAMED_UNARY_OP(-, std::negate<>)
  174. DEF_NAMED_UNARY_OP(!, std::logical_not<>)
  175. #undef DEF_NAMED_UNARY_OP
  176. // if OP(a) isn't found in ra::, deduction rank(0) -> scalar doesn't work. TODO Cf useret.cc, reexported.cc
  177. #define DEF_NAME(OP) \
  178. template <class ... A> requires (tomap<A ...>) constexpr auto \
  179. OP(A && ... a) \
  180. { return map([](auto && ... a) -> decltype(auto) { return OP(RA_FWD(a) ...); }, RA_FWD(a) ...); } \
  181. template <class ... A> requires (toreduce<A ...>) constexpr decltype(auto) \
  182. OP(A && ... a) \
  183. { return OP(VALUE(RA_FWD(a)) ...); }
  184. #define DEF_FWD(QUALIFIED_OP, OP) \
  185. template <class ... A> requires (!tomap<A ...> && !toreduce<A ...>) constexpr decltype(auto) \
  186. OP(A && ... a) \
  187. { return QUALIFIED_OP(RA_FWD(a) ...); } \
  188. DEF_NAME(OP)
  189. #define DEF_USING(QUALIFIED_OP, OP) \
  190. using QUALIFIED_OP; \
  191. DEF_NAME(OP)
  192. FOR_EACH(DEF_NAME, odd, arg, sqr, sqrm, real_part, imag_part, xI, rel_error)
  193. // can't DEF_USING bc std::max will gobble ra:: objects if passed by const & (!)
  194. // FIXME define own global max/min overloads for basic types. std::max seems too much of a special case to be usinged.
  195. #define DEF_GLOBAL(f) DEF_FWD(::f, f)
  196. FOR_EACH(DEF_GLOBAL, max, min)
  197. #undef DEF_GLOBAL
  198. // don't use DEF_FWD for these bc we want to allow ADL, e.g. for exp(dual).
  199. #define DEF_GLOBAL(f) DEF_USING(::f, f)
  200. FOR_EACH(DEF_GLOBAL, pow, conj, sqrt, exp, expm1, log, log1p, log10, isfinite, isnan, isinf, atan2)
  201. FOR_EACH(DEF_GLOBAL, abs, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, clamp, lerp)
  202. #undef DEF_GLOBAL
  203. #undef DEF_USING
  204. #undef DEF_FWD
  205. #undef DEF_NAME
  206. template <class T, class A>
  207. constexpr auto
  208. cast(A && a)
  209. {
  210. return map([](auto && b) -> decltype(auto) { return T(b); }, RA_FWD(a));
  211. }
  212. // TODO std::forward_as_tuple?
  213. template <class T, class ... A>
  214. constexpr auto
  215. pack(A && ... a)
  216. {
  217. return map([](auto && ... a) { return T { a ... }; }, RA_FWD(a) ...);
  218. }
  219. // FIXME needs nested array for I
  220. template <class A, class I>
  221. constexpr auto
  222. at(A && a, I && i)
  223. {
  224. return map([a = std::tuple<A>(RA_FWD(a))] (auto && i) -> decltype(auto) { return std::get<0>(a).at(i); },
  225. RA_FWD(i));
  226. }
  227. // --------------------------------
  228. // selection / shortcutting
  229. // --------------------------------
  230. // ra::start are needed bc rank 0 converts to and from scalar, so ? can't pick the right (-> scalar) conversion.
  231. template <class T, class F> requires (toreduce<T, F>)
  232. constexpr decltype(auto)
  233. where(bool const w, T && t, F && f)
  234. {
  235. return w ? VALUE(t) : VALUE(f);
  236. }
  237. template <class W, class T, class F> requires (tomap<W, T, F>)
  238. constexpr auto
  239. where(W && w, T && t, F && f)
  240. {
  241. return pick(cast<bool>(RA_FWD(w)), RA_FWD(f), RA_FWD(t));
  242. }
  243. // catch all for non-ra types.
  244. template <class T, class F> requires (!(tomap<T, F>) && !(toreduce<T, F>))
  245. constexpr decltype(auto)
  246. where(bool const w, T && t, F && f)
  247. {
  248. return w ? t : f;
  249. }
  250. template <class A, class B> requires (tomap<A, B>)
  251. constexpr auto
  252. operator &&(A && a, B && b)
  253. {
  254. return where(RA_FWD(a), cast<bool>(RA_FWD(b)), false);
  255. }
  256. template <class A, class B> requires (tomap<A, B>)
  257. constexpr auto
  258. operator ||(A && a, B && b)
  259. {
  260. return where(RA_FWD(a), true, cast<bool>(RA_FWD(b)));
  261. }
  262. #define DEF_SHORTCIRCUIT_BINARY_OP(OP) \
  263. template <class A, class B> requires (toreduce<A, B>) \
  264. constexpr auto operator OP(A && a, B && b) \
  265. { \
  266. return VALUE(a) OP VALUE(b); \
  267. }
  268. FOR_EACH(DEF_SHORTCIRCUIT_BINARY_OP, &&, ||);
  269. #undef DEF_SHORTCIRCUIT_BINARY_OP
  270. // --------------------------------
  271. // Some whole-array reductions. TODO First rank reductions? Variable rank reductions?
  272. // FIXME C++23 and_then/or_else/etc
  273. // --------------------------------
  274. constexpr bool
  275. any(auto && a)
  276. {
  277. return early(map([](bool x) { return x ? std::make_optional(true) : std::nullopt; }, RA_FWD(a)), false);
  278. }
  279. constexpr bool
  280. every(auto && a)
  281. {
  282. return early(map([](bool x) { return !x ? std::make_optional(false) : std::nullopt; }, RA_FWD(a)), true);
  283. }
  284. // FIXME variable rank? see J 'index of' (x i. y), etc.
  285. constexpr auto
  286. index(auto && a)
  287. {
  288. return early(map([](auto && a, auto && i) { return bool(a) ? std::make_optional(i) : std::nullopt; },
  289. RA_FWD(a), ra::iota(ra::start(a).len(0))),
  290. ra::dim_t(-1));
  291. }
  292. // [ma108]
  293. constexpr bool
  294. lexicographical_compare(auto && a, auto && b)
  295. {
  296. return early(map([](auto && a, auto && b) { return a==b ? std::nullopt : std::make_optional(a<b); },
  297. RA_FWD(a), RA_FWD(b)),
  298. false);
  299. }
  300. template <class A>
  301. constexpr auto
  302. amin(A && a)
  303. {
  304. using std::min;
  305. using T = value_t<A>;
  306. T c = std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::max();
  307. for_each([&c](auto && a) { if (a<c) { c = a; } }, a);
  308. return c;
  309. }
  310. template <class A>
  311. constexpr auto
  312. amax(A && a)
  313. {
  314. using std::max;
  315. using T = value_t<A>;
  316. T c = std::numeric_limits<T>::has_infinity ? -std::numeric_limits<T>::infinity() : std::numeric_limits<T>::lowest();
  317. for_each([&c](auto && a) { if (c<a) { c = a; } }, a);
  318. return c;
  319. }
  320. // FIXME encapsulate this kind of reference-reduction.
  321. // FIXME expr/ply mechanism doesn't allow partial iteration (adv then continue).
  322. template <class A, class Less = std::less<value_t<A>>>
  323. constexpr decltype(auto)
  324. refmin(A && a, Less && less = std::less<value_t<A>>())
  325. {
  326. RA_CHECK(a.size()>0);
  327. decltype(auto) s = ra::start(a);
  328. auto p = &(*s);
  329. for_each([&less, &p](auto & a) { if (less(a, *p)) { p = &a; } }, s);
  330. return *p;
  331. }
  332. template <class A, class Less = std::less<value_t<A>>>
  333. constexpr decltype(auto)
  334. refmax(A && a, Less && less = std::less<value_t<A>>())
  335. {
  336. RA_CHECK(a.size()>0);
  337. decltype(auto) s = ra::start(a);
  338. auto p = &(*s);
  339. for_each([&less, &p](auto & a) { if (less(*p, a)) { p = &a; } }, s);
  340. return *p;
  341. }
  342. template <class A>
  343. constexpr auto
  344. sum(A && a)
  345. {
  346. auto c = concrete_type<value_t<A>>(0);
  347. for_each([&c](auto && a) { c += a; }, a);
  348. return c;
  349. }
  350. template <class A>
  351. constexpr auto
  352. prod(A && a)
  353. {
  354. auto c = concrete_type<value_t<A>>(1);
  355. for_each([&c](auto && a) { c *= a; }, a);
  356. return c;
  357. }
  358. constexpr auto reduce_sqrm(auto && a) { return sum(sqrm(a)); }
  359. constexpr auto norm2(auto && a) { return std::sqrt(reduce_sqrm(a)); }
  360. constexpr auto
  361. dot(auto && a, auto && b)
  362. {
  363. std::decay_t<decltype(VALUE(a) * VALUE(b))> c(0.);
  364. for_each([&c](auto && a, auto && b)
  365. {
  366. #ifdef FP_FAST_FMA
  367. c = fma(a, b, c);
  368. #else
  369. c += a*b;
  370. #endif
  371. }, a, b);
  372. return c;
  373. }
  374. constexpr auto
  375. cdot(auto && a, auto && b)
  376. {
  377. std::decay_t<decltype(conj(VALUE(a)) * VALUE(b))> c(0.);
  378. for_each([&c](auto && a, auto && b)
  379. {
  380. #ifdef FP_FAST_FMA
  381. c = fma_conj(a, b, c);
  382. #else
  383. c += conj(a)*b;
  384. #endif
  385. }, a, b);
  386. return c;
  387. }
  388. // --------------------
  389. // Other whole-array ops.
  390. // --------------------
  391. constexpr auto
  392. normv(auto const & a)
  393. {
  394. auto b = concrete(a);
  395. b /= norm2(b);
  396. return b;
  397. }
  398. // FIXME benchmark w/o allocation and do Small/Big versions if it's worth it.
  399. constexpr void
  400. gemm(auto const & a, auto const & b, auto & c)
  401. {
  402. for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { c += a*b; })), c, a, b);
  403. }
  404. #define MMTYPE decltype(from(std::multiplies<>(), a(all, 0), b(0)))
  405. // default for row-major x row-major. See bench-gemm.cc for variants.
  406. template <class S, class T>
  407. constexpr auto
  408. gemm(ra::View<S, 2> const & a, ra::View<T, 2> const & b)
  409. {
  410. dim_t M=a.len(0), N=b.len(1), K=a.len(1);
  411. // no with_same_shape bc cannot index 0 for type if A/B are empty
  412. auto c = with_shape<MMTYPE>({M, N}, decltype(std::declval<S>()*std::declval<T>())());
  413. for (int k=0; k<K; ++k) {
  414. c += from(std::multiplies<>(), a(all, k), b(k));
  415. }
  416. return c;
  417. }
  418. // we still want the Small version to be different.
  419. template <class A, class B>
  420. constexpr ra::Small<std::decay_t<decltype(VALUE(std::declval<A>()) * VALUE(std::declval<B>()))>, A::len(0), B::len(1)>
  421. gemm(A const & a, B const & b)
  422. {
  423. dim_t M=a.len(0), N=b.len(1);
  424. // no with_same_shape bc cannot index 0 for type if A/B are empty
  425. auto c = with_shape<MMTYPE>({M, N}, ra::none);
  426. for (int i=0; i<M; ++i) {
  427. for (int j=0; j<N; ++j) {
  428. c(i, j) = dot(a(i), b(all, j));
  429. }
  430. }
  431. return c;
  432. }
  433. #undef MMTYPE
  434. constexpr auto
  435. gevm(auto const & a, auto const & b)
  436. {
  437. dim_t M=b.len(0), N=b.len(1);
  438. // no with_same_shape bc cannot index 0 for type if A/B are empty
  439. auto c = with_shape<decltype(a[0]*b(0))>({N}, 0);
  440. for (int i=0; i<M; ++i) {
  441. c += a[i]*b(i);
  442. }
  443. return c;
  444. }
  445. // FIXME a must be a view, so it doesn't work with e.g. gemv(conj(a), b).
  446. constexpr auto
  447. gemv(auto const & a, auto const & b)
  448. {
  449. dim_t M=a.len(0), N=a.len(1);
  450. // no with_same_shape bc cannot index 0 for type if A/B are empty
  451. auto c = with_shape<decltype(a(all, 0)*b[0])>({M}, 0);
  452. for (int j=0; j<N; ++j) {
  453. c += a(all, j) * b[j];
  454. }
  455. return c;
  456. }
  457. // --------------------
  458. // Wedge product and cross product
  459. // --------------------
  460. namespace mp {
  461. template <class P, class Plist>
  462. struct FindCombination
  463. {
  464. template <class A> using match = bool_c<0 != PermutationSign<P, A>::value>;
  465. using type = IndexIf<Plist, match>;
  466. constexpr static int where = type::value;
  467. constexpr static int sign = (where>=0) ? PermutationSign<P, typename type::type>::value : 0;
  468. };
  469. // Combination antiC complementary to C wrt [0, 1, ... Dim-1], permuted so [C, antiC] has the sign of [0, 1, ... Dim-1].
  470. template <class C, int D>
  471. struct AntiCombination
  472. {
  473. using EC = complement<C, D>;
  474. static_assert((len<EC>)>=2, "can't correct this complement");
  475. constexpr static int sign = PermutationSign<append<C, EC>, iota<D>>::value;
  476. // Produce permutation of opposite sign if sign<0.
  477. using type = mp::cons<std::tuple_element_t<(sign<0) ? 1 : 0, EC>,
  478. mp::cons<std::tuple_element_t<(sign<0) ? 0 : 1, EC>,
  479. mp::drop<EC, 2>>>;
  480. };
  481. template <class C, int D> struct MapAntiCombination;
  482. template <int D, class ... C>
  483. struct MapAntiCombination<std::tuple<C ...>, D>
  484. {
  485. using type = std::tuple<typename AntiCombination<C, D>::type ...>;
  486. };
  487. template <int D, int O>
  488. struct ChooseComponents
  489. {
  490. static_assert(D>=O, "Bad dimension or form order.");
  491. using type = mp::combinations<iota<D>, O>;
  492. };
  493. template <int D, int O> using ChooseComponents_ = typename ChooseComponents<D, O>::type;
  494. template <int D, int O> requires ((D>1) && (2*O>D))
  495. struct ChooseComponents<D, O>
  496. {
  497. static_assert(D>=O, "Bad dimension or form order.");
  498. using type = typename MapAntiCombination<ChooseComponents_<D, D-O>, D>::type;
  499. };
  500. // Works almost to the range of std::size_t.
  501. constexpr std::size_t
  502. n_over_p(std::size_t const n, std::size_t p)
  503. {
  504. if (p>n) {
  505. return 0;
  506. } else if (p>(n-p)) {
  507. p = n-p;
  508. }
  509. std::size_t v = 1;
  510. for (std::size_t i=0; i!=p; ++i) {
  511. v = v*(n-i)/(i+1);
  512. }
  513. return v;
  514. }
  515. // We form the basis for the result (Cr) and split it in pieces for Oa and Ob; there are (D over Oa) ways. Then we see where and with which signs these pieces are in the bases for Oa (Ca) and Ob (Cb), and form the product.
  516. template <int D, int Oa, int Ob>
  517. struct Wedge
  518. {
  519. constexpr static int Or = Oa+Ob;
  520. static_assert(Oa<=D && Ob<=D && Or<=D, "bad orders");
  521. constexpr static int Na = n_over_p(D, Oa);
  522. constexpr static int Nb = n_over_p(D, Ob);
  523. constexpr static int Nr = n_over_p(D, Or);
  524. // in lexicographic order. Can be used to sort Ca below with FindPermutation.
  525. using LexOrCa = mp::combinations<mp::iota<D>, Oa>;
  526. // the actual components used, which are in lex. order only in some cases.
  527. using Ca = mp::ChooseComponents_<D, Oa>;
  528. using Cb = mp::ChooseComponents_<D, Ob>;
  529. using Cr = mp::ChooseComponents_<D, Or>;
  530. // optimizations.
  531. constexpr static bool yields_expr = (Na>1) != (Nb>1);
  532. constexpr static bool yields_expr_a1 = yields_expr && Na==1;
  533. constexpr static bool yields_expr_b1 = yields_expr && Nb==1;
  534. constexpr static bool both_scalars = (Na==1 && Nb==1);
  535. constexpr static bool dot_plus = Na>1 && Nb>1 && Or==D && (Oa<Ob || (Oa>Ob && !ra::odd(Oa*Ob)));
  536. constexpr static bool dot_minus = Na>1 && Nb>1 && Or==D && (Oa>Ob && ra::odd(Oa*Ob));
  537. constexpr static bool general_case = (Na>1 && Nb>1) && ((Oa+Ob!=D) || (Oa==Ob));
  538. template <class Va, class Vb>
  539. using valtype = std::decay_t<decltype(std::declval<Va>()[0] * std::declval<Vb>()[0])>;
  540. template <class Xr, class Fa, class Va, class Vb>
  541. constexpr static valtype<Va, Vb>
  542. term(Va const & a, Vb const & b)
  543. {
  544. if constexpr (mp::len<Fa> > 0) {
  545. using Fa0 = mp::first<Fa>;
  546. using Fb = mp::complement_list<Fa0, Xr>;
  547. using Sa = mp::FindCombination<Fa0, Ca>;
  548. using Sb = mp::FindCombination<Fb, Cb>;
  549. constexpr int sign = Sa::sign * Sb::sign * mp::PermutationSign<mp::append<Fa0, Fb>, Xr>::value;
  550. static_assert(sign==+1 || sign==-1, "Bad sign in wedge term.");
  551. return valtype<Va, Vb>(sign)*a[Sa::where]*b[Sb::where] + term<Xr, mp::drop1<Fa>>(a, b);
  552. } else {
  553. return 0.;
  554. }
  555. }
  556. template <class Va, class Vb, class Vr, int wr>
  557. constexpr static void
  558. coeff(Va const & a, Vb const & b, Vr & r)
  559. {
  560. if constexpr (wr<Nr) {
  561. using Xr = mp::ref<Cr, wr>;
  562. using Fa = mp::combinations<Xr, Oa>;
  563. r[wr] = term<Xr, Fa>(a, b);
  564. coeff<Va, Vb, Vr, wr+1>(a, b, r);
  565. }
  566. }
  567. template <class Va, class Vb, class Vr>
  568. constexpr static void
  569. product(Va const & a, Vb const & b, Vr & r)
  570. {
  571. static_assert(Va::size()==Na, "Bad Va dim.");
  572. static_assert(Vb::size()==Nb, "Bad Vb dim.");
  573. static_assert(Vr::size()==Nr, "Bad Vr dim.");
  574. coeff<Va, Vb, Vr, 0>(a, b, r);
  575. }
  576. };
  577. // Euclidean space, only component shuffling.
  578. template <int D, int O>
  579. struct Hodge
  580. {
  581. using W = Wedge<D, O, D-O>;
  582. using Ca = typename W::Ca;
  583. using Cb = typename W::Cb;
  584. using Cr = typename W::Cr;
  585. using LexOrCa = typename W::LexOrCa;
  586. constexpr static int Na = W::Na;
  587. constexpr static int Nb = W::Nb;
  588. template <int i, class Va, class Vb>
  589. constexpr static void
  590. hodge_aux(Va const & a, Vb & b)
  591. {
  592. static_assert(i<=W::Na, "Bad argument to hodge_aux");
  593. if constexpr (i<W::Na) {
  594. using Cai = mp::ref<Ca, i>;
  595. static_assert(mp::len<Cai> == O, "Bad.");
  596. // sort Cai, because mp::complement only accepts sorted combinations.
  597. // ref<Cb, i> should be complementary to Cai, but I don't want to rely on that.
  598. using SCai = mp::ref<LexOrCa, mp::FindCombination<Cai, LexOrCa>::where>;
  599. using CompCai = mp::complement<SCai, D>;
  600. static_assert(mp::len<CompCai> == D-O, "Bad.");
  601. using fpw = mp::FindCombination<CompCai, Cb>;
  602. // for the sign see e.g. DoCarmo1991 I.Ex 10.
  603. using fps = mp::FindCombination<mp::append<Cai, mp::ref<Cb, fpw::where>>, Cr>;
  604. static_assert(fps::sign!=0, "Bad.");
  605. b[fpw::where] = decltype(a[i])(fps::sign)*a[i];
  606. hodge_aux<i+1>(a, b);
  607. }
  608. }
  609. };
  610. // The order of components is taken from Wedge<D, O, D-O>; this works for whatever order is defined there.
  611. // With lexicographic order, component order is reversed, but signs vary.
  612. // With the order given by ChooseComponents<>, fpw::where==i and fps::sign==+1 in hodge_aux(), always. Then hodge() becomes a free operation, (with one exception) and the next function hodge() can be used.
  613. template <int D, int O, class Va, class Vb>
  614. constexpr void
  615. hodgex(Va const & a, Vb & b)
  616. {
  617. static_assert(O<=D, "bad orders");
  618. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  619. static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
  620. mp::Hodge<D, O>::template hodge_aux<0>(a, b);
  621. }
  622. } // namespace ra::mp
// This depends on Wedge<>::Ca, Cb, Cr coming from ChooseComponents. hodgex() should always work, but this is cheaper.
// However, if 2*O=D, it is not possible to differentiate the bases by order and hodgex() must be used.
// Likewise, when O(N-O) is odd, Hodge from (2*O>D) to (2*O<D) changes sign, since **w = -w in that case, and the basis in the (2*O>D) case is selected to make Hodge(<)->Hodge(>) trivial; but can't do both!
  626. consteval bool trivial_hodge(int D, int O) { return 2*O!=D && ((2*O<D) || !ra::odd(O*(D-O))); }
  627. template <int D, int O, class Va, class Vb>
  628. constexpr void
  629. hodge(Va const & a, Vb & b)
  630. {
  631. if constexpr (trivial_hodge(D, O)) {
  632. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  633. static_assert(Vb::size()==mp::Hodge<D, O>::Nb, "error");
  634. b = a;
  635. } else {
  636. ra::mp::hodgex<D, O>(a, b);
  637. }
  638. }
  639. template <int D, int O, class Va> requires (trivial_hodge(D, O))
  640. constexpr Va const &
  641. hodge(Va const & a)
  642. {
  643. static_assert(Va::size()==mp::Hodge<D, O>::Na, "error");
  644. return a;
  645. }
  646. template <int D, int O, class Va> requires (!trivial_hodge(D, O))
  647. constexpr Va &
  648. hodge(Va & a)
  649. {
  650. Va b(a);
  651. ra::mp::hodgex<D, O>(b, a);
  652. return a;
  653. }
  654. template <int D, int Oa, int Ob, class A, class B> requires (is_scalar<A> && is_scalar<B>)
  655. constexpr auto
  656. wedge(A const & a, B const & b) { return a*b; }
  657. template <class A>
  658. using torank1 = std::conditional_t<is_scalar<A>, Small<std::decay_t<A>, 1>, A>;
  659. template <int D, int Oa, int Ob, class Va, class Vb> requires (!(is_scalar<Va> && is_scalar<Vb>))
  660. decltype(auto)
  661. wedge(Va const & a, Vb const & b)
  662. {
  663. Small<value_t<Va>, size_s<Va>()> aa = a;
  664. Small<value_t<Vb>, size_s<Vb>()> bb = b;
  665. using Ua = decltype(aa);
  666. using Ub = decltype(bb);
  667. using Wedge = mp::Wedge<D, Oa, Ob>;
  668. using valtype = typename Wedge::template valtype<Ua, Ub>;
  669. std::conditional_t<Wedge::Nr==1, valtype, Small<valtype, Wedge::Nr>> r;
  670. auto & a1 = reinterpret_cast<torank1<Ua> const &>(aa);
  671. auto & b1 = reinterpret_cast<torank1<Ub> const &>(bb);
  672. auto & r1 = reinterpret_cast<torank1<decltype(r)> &>(r);
  673. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  674. return r;
  675. }
  676. template <int D, int Oa, int Ob, class Va, class Vb, class Vr> requires (!(is_scalar<Va> && is_scalar<Vb>))
  677. void
  678. wedge(Va const & a, Vb const & b, Vr & r)
  679. {
  680. Small<value_t<Va>, size_s<Va>()> aa = a;
  681. Small<value_t<Vb>, size_s<Vb>()> bb = b;
  682. using Ua = decltype(aa);
  683. using Ub = decltype(bb);
  684. auto & r1 = reinterpret_cast<torank1<decltype(r)> &>(r);
  685. auto & a1 = reinterpret_cast<torank1<Ua> const &>(aa);
  686. auto & b1 = reinterpret_cast<torank1<Ub> const &>(bb);
  687. mp::Wedge<D, Oa, Ob>::product(a1, b1, r1);
  688. }
  689. template <class A, class B>
  690. constexpr auto
  691. cross(A const & a_, B const & b_)
  692. {
  693. constexpr int n = size_s<A>();
  694. static_assert(n==size_s<B>() && (2==n || 3==n));
  695. Small<std::decay_t<decltype(VALUE(a_))>, n> a = a_;
  696. Small<std::decay_t<decltype(VALUE(b_))>, n> b = b_;
  697. using W = mp::Wedge<n, 1, 1>;
  698. Small<std::decay_t<decltype(VALUE(a_) * VALUE(b_))>, W::Nr> r;
  699. W::product(a, b, r);
  700. if constexpr (1==W::Nr) {
  701. return r[0];
  702. } else {
  703. return r;
  704. }
  705. }
  706. template <class V>
  707. constexpr auto
  708. perp(V const & v)
  709. {
  710. static_assert(2==v.size(), "Dimension error.");
  711. return Small<std::decay_t<decltype(VALUE(v))>, 2> {v[1], -v[0]};
  712. }
  713. template <class V, class U>
  714. constexpr auto
  715. perp(V const & v, U const & n)
  716. {
  717. if constexpr (is_scalar<U>) {
  718. static_assert(2==v.size(), "Dimension error.");
  719. return Small<std::decay_t<decltype(VALUE(v) * n)>, 2> {v[1]*n, -v[0]*n};
  720. } else {
  721. static_assert(3==v.size(), "Dimension error.");
  722. return cross(v, n);
  723. }
  724. }
  725. } // namespace ra
  726. #undef RA_OPT