123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502 |
- // -*- mode: c++; coding: utf-8 -*-
- /// @file operators.H
- /// @brief Sugar for ra:: expression templates.
- // (c) Daniel Llorens - 2014-2017
- // This library is free software; you can redistribute it and/or modify it under
- // the terms of the GNU General Public License as published by the Free
- // Software Foundation; either version 3 of the License, or (at your option) any
- // later version.
- #pragma once
- // FIXME Dependence on specific ra:: types should maybe be elsewhere.
- #include "ra/global.H"
- #include "ra/complex.H"
- #include "ra/wrank.H"
- #include "ra/pick.H"
- #include "ra/view-ops.H"
- #if defined(RA_CHECK_BOUNDS) && RA_CHECK_BOUNDS==0 // bounds checks were explicitly switched off
- #define CHECK_BOUNDS( cond ) /* expands to nothing, so cond is never evaluated */
- #else // RA_CHECK_BOUNDS undefined or nonzero: checks are on
- #define CHECK_BOUNDS( cond ) RA_ASSERT( cond, 0 ) // forwards the condition to RA_ASSERT
- #endif
- #include "ra/optimize.H"
- #ifndef RA_OPTIMIZE
- #define RA_OPTIMIZE 1 // enabled by default
- #endif
- #if RA_OPTIMIZE==1 // OPTIMIZE(expr) becomes a call optimize(expr)
- #define OPTIMIZE optimize
- #else // OPTIMIZE(expr) becomes the parenthesized expr, i.e. a no-op
- #define OPTIMIZE
- #endif
- namespace ra {
- // Overload of transpose() that receives its integer arguments as an
- // mp::int_list tag value; it unpacks Iarg ... and forwards a to
- // transpose<Iarg ...>().
- template <int ... Iarg, class A>
- inline constexpr decltype(auto)
- transpose(mp::int_list<Iarg ...>, A && a)
- {
-   return transpose<Iarg ...>(std::forward<A>(a));
- }
- // Dereferences ra::start(a).flat(), i.e. yields the first element reachable
- // through a. Used below for the rank-0 overloads and to deduce element types.
- // Declared directly in ra:: instead of in an unnamed namespace: this is a
- // header, and the inline templates below that call FLAT must name the same
- // entity in every translation unit (one-definition rule).
- template <class A> inline constexpr
- decltype(auto) FLAT(A && a)
- {
-   return *(ra::start(std::forward<A>(a)).flat());
- }
- // I considered three options for lookup.
- // 1. define these in a class that ArrayIterator or Container or Slice types derive from. This was done for an old library I had (vector-ops.H). It results in the narrowest scope, but since those types are used in the definition (ra::Expr is an ArrayIterator), it requires lots of forwarding and traits:: .
- // 2. raw ADL doesn't work because some ra:: types use ! != etc for different things (e.g. Flat). Possible solution: don't ever use + != == for Flat.
- // 3. enable_if'd ADL is what you see here.
- // --------------------------------
- // Array versions of operators and functions
- // --------------------------------
- // We need the zero/scalar specializations because the scalar/scalar operators
- // may be templated (e.g. complex<>), so they won't be found when an implicit
- // conversion from zero->scalar is also needed. That is, without those
- // specializations, ra::View<complex, 0> * complex will fail.
- // These depend on OPNAME defined in optimize.H and used there to match ET patterns.
- #define DEF_NAMED_BINARY_OP(OP, OPNAME) /* defines two overloads of operator OP; OPNAME is the functor type passed to map() */ \
- template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0> \
- inline constexpr auto operator OP(A && a, B && b) \
- { \
- return OPTIMIZE(map(OPNAME(), std::forward<A>(a), std::forward<B>(b))); /* lazy expression, passed through OPTIMIZE */ \
- } \
- template <class A, class B, std::enable_if_t<ra_zero<A, B>, int> =0> \
- inline constexpr auto operator OP(A && a, B && b) \
- { \
- return FLAT(a) OP FLAT(b); /* ra_zero case: apply OP directly to the two elements */ \
- }
- DEF_NAMED_BINARY_OP(+, plus) // operator+ maps plus()
- DEF_NAMED_BINARY_OP(-, minus) // operator- maps minus()
- DEF_NAMED_BINARY_OP(*, times) // operator* maps times()
- DEF_NAMED_BINARY_OP(/, slash) // operator/ maps slash()
- #undef DEF_NAMED_BINARY_OP // the helper macro is local to this header
- #define DEF_BINARY_OP(OP) /* like DEF_NAMED_BINARY_OP, but maps a generic lambda and does not go through OPTIMIZE */ \
- template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0> \
- inline auto operator OP(A && a, B && b) \
- { \
- return map([](auto && a, auto && b) { return a OP b; }, /* elementwise OP; the lambda parameters shadow the outer a, b */ \
- std::forward<A>(a), std::forward<B>(b)); \
- } \
- template <class A, class B, std::enable_if_t<ra_zero<A, B>, int> =0> \
- inline auto operator OP(A && a, B && b) \
- { \
- return FLAT(a) OP FLAT(b); /* ra_zero case: apply OP directly to the two elements */ \
- }
- FOR_EACH(DEF_BINARY_OP, >, <, >=, <=, ==, !=, |, &, ^) // comparison and bitwise operators
- #undef DEF_BINARY_OP // the helper macro is local to this header
- #define DEF_UNARY_OP(OP) /* defines unary operator OP for is_ra_pos_rank<A> operands; no ra_zero overload is defined here */ \
- template <class A, std::enable_if_t<is_ra_pos_rank<A>, int> =0> \
- inline auto operator OP(A && a) \
- { \
- return map([](auto && a) { return OP a; }, std::forward<A>(a)); /* lazy elementwise OP */ \
- }
- FOR_EACH(DEF_UNARY_OP, !, +, -) // TODO Make + into nop.
- #undef DEF_UNARY_OP // the helper macro is local to this header
- // When OP(a) isn't found from ra::, the deduction from rank(0) -> scalar doesn't work.
- // TODO Cf [ref:examples/useret.C:0].
- #define DEF_NAME_OP(OP) /* lifts the named function OP to array arguments; results are returned by value (plain auto) */ \
- using ::OP; /* make the global-namespace OP visible inside ra:: for the calls below */ \
- template <class ... A, std::enable_if_t<ra_pos_and_any<A ...>, int> =0> \
- inline auto OP(A && ... a) \
- { \
- return map([](auto && ... a) { return OP(a ...); }, std::forward<A>(a) ...); /* lazy elementwise OP */ \
- } \
- template <class ... A, std::enable_if_t<ra_zero<A ...>, int> =0> \
- inline auto OP(A && ... a) \
- { \
- return OP(FLAT(a) ...); /* ra_zero case: call OP on the elements themselves */ \
- }
- FOR_EACH(DEF_NAME_OP, rel_error, pow, xI, conj, sqr, sqrm, sqrt, cos, sin) // each name must already be declared in the global namespace (see using ::OP)
- FOR_EACH(DEF_NAME_OP, exp, expm1, log, log1p, log10, isfinite, isnan, isinf)
- FOR_EACH(DEF_NAME_OP, max, min, abs, odd, asin, acos, atan, atan2, clamp)
- FOR_EACH(DEF_NAME_OP, cosh, sinh, tanh, arg)
- #undef DEF_NAME_OP // undefined so that the macro can be redefined with a different body below
- #define DEF_NAME_OP(OP) /* variant of the previous DEF_NAME_OP that uses decltype(auto), so a reference returned by OP is not decayed to a value */ \
- using ::OP; /* make the global-namespace OP visible inside ra:: for the calls below */ \
- template <class ... A, std::enable_if_t<ra_pos_and_any<A ...>, int> =0> \
- inline auto OP(A && ... a) \
- { \
- return map([](auto && ... a) -> decltype(auto) { return OP(a ...); }, std::forward<A>(a) ...); /* lambda keeps the exact return type of OP */ \
- } \
- template <class ... A, std::enable_if_t<ra_zero<A ...>, int> =0> \
- inline decltype(auto) OP(A && ... a) \
- { \
- return OP(FLAT(a) ...); /* ra_zero case: call OP on the elements themselves */ \
- }
- FOR_EACH(DEF_NAME_OP, real_part, imag_part) // the two functions lifted with reference-preserving return types
- #undef DEF_NAME_OP // the helper macro is local to this header
- // Elementwise conversion: returns a map expression that constructs a T from
- // each element of a.
- template <class T, class A>
- inline auto
- cast(A && a)
- {
-   return map([](auto && x) { return T(x); },
-              std::forward<A>(a));
- }
- // TODO could be useful to deduce T as tuple of value_types (&).
- // Returns a map expression whose elements are T { x ... }, brace-initialized
- // from the corresponding elements of the arguments.
- template <class T, class ... A>
- inline auto
- pack(A && ... a)
- {
-   return map([](auto && ... x) { return T { x ... }; },
-              std::forward<A>(a) ...);
- }
- // FIXME Inelegant story wrt plain array / nested array :-/
- // Returns a map expression over i that yields a.at(ix) for each element ix
- // of i. a is captured by reference, so it has to outlive the returned
- // expression.
- template <class A, class I>
- inline auto
- at(A && a, I && i)
- {
-   return map([&a](auto && ix) -> decltype(auto) { return a.at(ix); },
-              i);
- }
- // Elementwise selection between t and f driven by w, enabled when at least
- // one argument satisfies is_ra_pos_rank. w is converted to bool and handed to
- // pick() as the selector, with f passed before t.
- template <class W, class T, class F,
-           std::enable_if_t<is_ra_pos_rank<W> || is_ra_pos_rank<T> || is_ra_pos_rank<F>, int> =0>
- inline auto
- where(W && w, T && t, F && f)
- {
-   return pick(cast<bool>(start(std::forward<W>(w))),
-               start(std::forward<F>(f)),
-               start(std::forward<T>(t)));
- }
- // Elementwise logical and, built on where(): a selects between bool(b) and
- // the constant false.
- template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0>
- inline auto
- operator &&(A && a, B && b)
- {
-   return where(std::forward<A>(a),
-                cast<bool>(std::forward<B>(b)),
-                false);
- }
- // Elementwise logical or, built on where(): a selects between the constant
- // true and bool(b).
- template <class A, class B, std::enable_if_t<ra_pos_and_any<A, B>, int> =0>
- inline auto
- operator ||(A && a, B && b)
- {
-   return where(std::forward<A>(a),
-                true,
-                cast<bool>(std::forward<B>(b)));
- }
- // ra_zero<A, B> versions of && and ||: apply OP to the FLAT elements of the
- // two operands. The array versions are defined separately through where().
- #define DEF_SHORTCIRCUIT_BINARY_OP(OP) \
- template <class A, class B, std::enable_if_t<ra_zero<A, B>, int> =0> \
- inline auto operator OP(A && a, B && b) \
- { \
- return FLAT(a) OP FLAT(b); \
- }
- // No trailing ';' after FOR_EACH: the expansion consists of function
- // definitions, so a semicolon would only add an empty declaration.
- FOR_EACH(DEF_SHORTCIRCUIT_BINARY_OP, &&, ||)
- #undef DEF_SHORTCIRCUIT_BINARY_OP
- // --------------------------------
- // Some whole-array reductions.
- // TODO First rank reductions? Variable rank reductions?
- // --------------------------------
- // Whole-array 'or'. Each element is converted to bool and mapped to the
- // tuple (flag, flag); early() receives that expression with false as its
- // second argument.
- // NOTE(review): presumably early() returns the second tuple member of the
- // first element whose first member is true, else that second argument —
- // confirm against early().
- template <class A>
- inline bool
- any(A && a)
- {
-   return early(map([](bool flag) { return std::make_tuple(flag, flag); },
-                    std::forward<A>(a)),
-                false);
- }
- // Whole-array 'and'. Each element is converted to bool and mapped to the
- // tuple (!flag, flag); early() receives that expression with true as its
- // second argument.
- // NOTE(review): presumably early() returns the second tuple member of the
- // first element whose first member is true, else that second argument —
- // confirm against early().
- template <class A>
- inline bool
- every(A && a)
- {
-   return early(map([](bool flag) { return std::make_tuple(!flag, flag); },
-                    std::forward<A>(a)),
-                true);
- }
- // FIXME variable rank? see J 'index of' (x i. y), etc.
- // Pairs each element of a with its position, taken from
- // ra::iota(start(a).size(0)), as the tuple (bool(element), position) and
- // hands that to early() with ra::dim_t(-1) as its second argument.
- // NOTE(review): presumably the result is the position of the first true
- // element, or -1 if there is none — confirm against early().
- template <class A>
- inline auto
- index(A && a)
- {
-   return early(map([](auto && elem, auto && pos) { return std::make_tuple(bool(elem), pos); },
-                    std::forward<A>(a),
-                    ra::iota(start(a).size(0))),
-                ra::dim_t(-1));
- }
- // [ma108]
- // Compares a and b element by element. Equal pairs map to (false, true);
- // the other pairs map to (true, x<y). early() receives that expression with
- // false as its second argument.
- // NOTE(review): presumably the result is x<y at the first differing pair and
- // false if all pairs are equal — confirm against early().
- template <class A, class B>
- inline bool
- lexicographical_compare(A && a, B && b)
- {
-   return early(map([](auto && x, auto && y)
-                    { return x==y ? std::make_tuple(false, true) : std::make_tuple(true, x<y); },
-                    a, b),
-                false);
- }
- // FIXME only works with numeric types.
- using std::min;
- // Smallest element of a, returned by value. The running minimum starts at
- // +infinity when value_t<A> has one, otherwise at numeric_limits::max(), and
- // that starting value is what comes back if for_each visits no elements.
- template <class A>
- inline auto
- amin(A && a)
- {
-   using T = value_t<A>;
-   using limits = std::numeric_limits<T>;
-   T smallest = limits::has_infinity ? limits::infinity() : limits::max();
-   for_each([&smallest](auto && x) { if (x<smallest) { smallest = x; } }, a);
-   return smallest;
- }
- using std::max;
- // Largest element of a, returned by value. The running maximum starts at
- // -infinity when value_t<A> has one, otherwise at numeric_limits::lowest(),
- // and that starting value is what comes back if for_each visits no elements.
- template <class A>
- inline auto
- amax(A && a)
- {
-   using T = value_t<A>;
-   using limits = std::numeric_limits<T>;
-   T largest = limits::has_infinity ? -limits::infinity() : limits::lowest();
-   for_each([&largest](auto && x) { if (largest<x) { largest = x; } }, a);
-   return largest;
- }
- // FIXME encapsulate this kind of reference-reduction.
- // FIXME expr/ply mechanism doesn't allow partial iteration (adv then continue).
- // Returns *best, where best points at an element x of a for which
- // less(x, *best) held last; it starts at the first element. Because the
- // result is a dereferenced pointer under decltype(auto), it refers to the
- // element itself. CHECK_BOUNDS requires a to be non-empty.
- template <class A, class Less = std::less<value_t<A>>>
- inline decltype(auto)
- refmin(A && a, Less && less = std::less<value_t<A>>())
- {
-   CHECK_BOUNDS(a.size()>0);
-   decltype(auto) it = ra::start(a);
-   auto best = &(*it.flat());
-   for_each([&less, &best](auto & x) { if (less(x, *best)) { best = &x; } }, it);
-   return *best;
- }
- // Counterpart of refmin(): returns *best, where best is moved to an element
- // x whenever less(*best, x) holds; it starts at the first element. The
- // result refers to the element itself. CHECK_BOUNDS requires a to be
- // non-empty.
- template <class A, class Less = std::less<value_t<A>>>
- inline decltype(auto)
- refmax(A && a, Less && less = std::less<value_t<A>>())
- {
-   CHECK_BOUNDS(a.size()>0);
-   decltype(auto) it = ra::start(a);
-   auto best = &(*it.flat());
-   for_each([&less, &best](auto & x) { if (less(*best, x)) { best = &x; } }, it);
-   return *best;
- }
- // Adds up all elements of a with +=, starting from a value-initialized
- // value_t<A>, and returns the total by value.
- template <class A>
- inline constexpr auto
- sum(A && a)
- {
-   value_t<A> total {};
-   for_each([&total](auto && x) { total += x; }, a);
-   return total;
- }
- // Multiplies all elements of a together with *=, starting from
- // value_t<A>(1.), and returns the product by value.
- template <class A>
- inline constexpr auto
- prod(A && a)
- {
-   value_t<A> product(1.);
-   for_each([&product](auto && x) { product *= x; }, a);
-   return product;
- }
- // Sum over a of sqrm() applied elementwise.
- template <class A>
- inline auto
- reduce_sqrm(A && a)
- {
-   return sum(sqrm(a));
- }
- // std::sqrt of reduce_sqrm(a).
- template <class A>
- inline auto
- norm2(A && a)
- {
-   return std::sqrt(reduce_sqrm(a));
- }
- // Accumulates fma(x, y, acc) over matching elements of a and b, starting
- // from 0. The accumulator has the decayed type of the product of the first
- // elements.
- template <class A, class B>
- inline auto
- dot(A && a, B && b)
- {
-   std::decay_t<decltype(FLAT(a) * FLAT(b))> acc(0.);
-   for_each([&acc](auto && x, auto && y) { acc = fma(x, y, acc); }, a, b);
-   return acc;
- }
- // Like dot(), but accumulates with fma_conj(x, y, acc); the accumulator has
- // the decayed type of conj(first element of a) * (first element of b).
- template <class A, class B>
- inline auto
- cdot(A && a, B && b)
- {
-   std::decay_t<decltype(conj(FLAT(a)) * FLAT(b))> acc(0.);
-   for_each([&acc](auto && x, auto && y) { acc = fma_conj(x, y, acc); }, a, b);
-   return acc;
- }
- // --------------------
- // Wedge product
- // TODO Handle the simplifications dot_plus, yields_scalar, etc. just as vec::wedge does.
- // --------------------
- template <class A>
- struct torank1 // type map: a scalar A becomes Small<std::decay_t<A>, 1>, any other A is left as it is
- {
- using type = std::conditional_t<is_scalar<A>, Small<std::decay_t<A>, 1>, A>;
- };
- template <class Wedge, class Va, class Vb, class Enable=void>
- struct fromrank1 // result type of Wedge on Va, Vb: the bare valtype when Wedge::Nr==1, otherwise Small<valtype, Wedge::Nr>
- {
- using valtype = typename Wedge::template valtype<Va, Vb>; // element type as reported by Wedge
- using type = std::conditional_t<Wedge::Nr==1, valtype, Small<valtype, Wedge::Nr>>;
- };
- #define DECL_WEDGE(condition) /* signature of the value-returning wedge(); the macro argument is unused in the expansion and only labels the case */ \
- template <int D, int Oa, int Ob, class Va, class Vb, \
- std::enable_if_t<!(is_scalar<Va> && is_scalar<Vb>), int> =0> \
- decltype(auto) wedge(Va const & a, Vb const & b)
- DECL_WEDGE(general_case)
- {
- Small<std::decay_t<decltype(FLAT(a))>, size_s(a)> aa = a; // local Small copy of a, sized by size_s(a)
- Small<std::decay_t<decltype(FLAT(b))>, size_s(b)> bb = b; // local Small copy of b, sized by size_s(b)
- using Ua = decltype(aa);
- using Ub = decltype(bb);
- typename fromrank1<fun::Wedge<D, Oa, Ob>, Ua, Ub>::type r; // result: bare value or Small, as chosen by fromrank1
- auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r); // r viewed through torank1, so a scalar result is addressed as a 1-element Small
- auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa); // Ua is already a Small, so torank1 leaves the type unchanged unless Small counts as scalar
- auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
- fun::Wedge<D, Oa, Ob>::product(a1, b1, r1); // writes the product into r through r1
- return r; // r is a local object, so decltype(auto) yields a value here
- }
- #undef DECL_WEDGE // undefined so that it can be redefined for the next overload
- #define DECL_WEDGE(condition) /* signature of the wedge() overload that writes into r; the macro argument is unused in the expansion and only labels the case */ \
- template <int D, int Oa, int Ob, class Va, class Vb, class Vr, \
- std::enable_if_t<!(is_scalar<Va> && is_scalar<Vb>), int> =0> \
- void wedge(Va const & a, Vb const & b, Vr & r)
- DECL_WEDGE(general_case)
- {
- Small<std::decay_t<decltype(FLAT(a))>, size_s(a)> aa = a; // local Small copy of a, sized by size_s(a)
- Small<std::decay_t<decltype(FLAT(b))>, size_s(b)> bb = b; // local Small copy of b, sized by size_s(b)
- using Ua = decltype(aa);
- using Ub = decltype(bb);
- auto & r1 = reinterpret_cast<typename torank1<decltype(r)>::type &>(r); // NOTE(review): decltype(r) is Vr & here, not Vr — confirm torank1/is_scalar handle the reference type as intended
- auto & a1 = reinterpret_cast<typename torank1<Ua>::type const &>(aa);
- auto & b1 = reinterpret_cast<typename torank1<Ub>::type const &>(bb);
- fun::Wedge<D, Oa, Ob>::product(a1, b1, r1); // writes the product into the caller's r through r1
- }
- #undef DECL_WEDGE // the helper macro is local to this header
- // cross() for two operands of static size 2. Both are copied into Small
- // vectors, fun::Wedge<2, 1, 1>::product fills a 1-element result, and that
- // single element is returned by value.
- template <class A, class B, std::enable_if_t<size_s<A>()==2 && size_s<B>()==2, int> =0>
- inline auto
- cross(A const & a_, B const & b_)
- {
-   Small<std::decay_t<decltype(FLAT(a_))>, 2> lhs = a_;
-   Small<std::decay_t<decltype(FLAT(b_))>, 2> rhs = b_;
-   Small<std::decay_t<decltype(FLAT(a_) * FLAT(b_))>, 1> out;
-   fun::Wedge<2, 1, 1>::product(lhs, rhs, out);
-   return out[0];
- }
- // cross() for two operands of static size 3. Both are copied into Small
- // vectors and fun::Wedge<3, 1, 1>::product fills the 3-element result, which
- // is returned by value.
- template <class A, class B, std::enable_if_t<size_s<A>()==3 && size_s<B>()==3, int> =0>
- inline auto
- cross(A const & a_, B const & b_)
- {
-   Small<std::decay_t<decltype(FLAT(a_))>, 3> lhs = a_;
-   Small<std::decay_t<decltype(FLAT(b_))>, 3> rhs = b_;
-   Small<std::decay_t<decltype(FLAT(a_) * FLAT(b_))>, 3> out;
-   fun::Wedge<3, 1, 1>::product(lhs, rhs, out);
-   return out;
- }
- // Returns the 2-vector {v[1], -v[0]} as a Small, for v of static size 2.
- // The size check uses size_s<V>(), as the cross() overloads do, rather than
- // v.size(): v is a reference parameter, and naming it inside a constant
- // expression is only guaranteed to work since P2280 (C++23).
- template <class V>
- inline auto perp(V const & v)
- {
-   static_assert(size_s<V>()==2, "dimension error");
-   return Small<std::decay_t<decltype(FLAT(v))>, 2> {v[1], -v[0]};
- }
- // Scalar-n overload: returns the 2-vector {v[1]*n, -v[0]*n} as a Small, for
- // v of static size 2.
- // The size check uses size_s<V>(), as the cross() overloads do, rather than
- // v.size(): v is a reference parameter, and naming it inside a constant
- // expression is only guaranteed to work since P2280 (C++23).
- template <class V, class U, std::enable_if_t<is_scalar<U>, int> =0>
- inline auto perp(V const & v, U const & n)
- {
-   static_assert(size_s<V>()==2, "dimension error");
-   return Small<std::decay_t<decltype(FLAT(v) * n)>, 2> {v[1]*n, -v[0]*n};
- }
- // Non-scalar-n overload: for v of static size 3, returns cross(v, n).
- // The size check uses size_s<V>(), as the cross() overloads do, rather than
- // v.size(): v is a reference parameter, and naming it inside a constant
- // expression is only guaranteed to work since P2280 (C++23).
- template <class V, class U, std::enable_if_t<!is_scalar<U>, int> =0>
- inline auto perp(V const & v, U const & n)
- {
-   static_assert(size_s<V>()==3, "dimension error");
-   return cross(v, n);
- }
- // --------------------
- // Other whole-array ops.
- // --------------------
- // normv() for is_slice types: divides a by norm2(a) and materializes the
- // quotient with concrete().
- template <class A, std::enable_if_t<is_slice<A>, int> =0>
- inline auto
- normv(A const & a)
- {
-   return concrete(a/norm2(a));
- }
- // normv() for is_ra types that are not slices: materializes a with
- // concrete() first, then divides that copy in place by its own norm2().
- template <class A, std::enable_if_t<!is_slice<A> && is_ra<A>, int> =0>
- inline auto
- normv(A const & a)
- {
-   auto unit = concrete(a);
-   unit /= norm2(unit);
-   return unit;
- }
- // FIXME benchmark w/o allocation and do Small/Big versions if it's worth it.
- template <class A, class B, class C>
- inline void
- gemm(A const & a, B const & b, C & c) // three-argument form: the innermost operation is c += a*b, so results are accumulated into the caller's c
- {
- for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { c += a*b; })), c, a, b); // NOTE(review): c is never cleared here, so presumably the caller must initialize it — confirm
- }
- // FIXME branch gemm on Ryn::size_s(), but that's bugged.
- #define MMTYPE decltype(from(times(), a(ra::all, 0), b(0, ra::all))) // type of the from(times(), ...) term used below; only valid where a and b are in scope
- // default for row-major x row-major. See bench-gemm.C for variants.
- template <class S, class T>
- inline auto
- gemm(ra::View<S, 2> const & a, ra::View<T, 2> const & b) // returns a new array of shape {M, N}
- {
- int const M = a.size(0); // first extent of a
- int const N = b.size(1); // second extent of b
- int const K = a.size(1); // second extent of a, used as the bound of the accumulation loop
- // no with_same_shape b/c cannot index 0 for type if A/B are empty
- auto c = with_shape<MMTYPE>({M, N}, decltype(a(0, 0)*b(0, 0))()); // filled with a value-initialized element of the product type
- for (int k=0; k<K; ++k) {
- c += from(times(), a(ra::all, k), b(k, ra::all)); // one from(times(), ...) term per k, built from a(ra::all, k) and b(k, ra::all)
- }
- return c;
- }
- // we still want the Small version to be different.
- // Two-argument gemm for operands with static sizes: returns an
- // A::size(0) x B::size(1) ra::Small whose (i, j) element is
- // dot(a(i), b(ra::all, j)).
- // M and N are taken from A::size(0) / B::size(1), exactly as in the return
- // type, rather than from a.size(0) / b.size(1): a and b are reference
- // parameters, and naming them inside a constant expression is only
- // guaranteed to work since P2280 (C++23).
- template <class A, class B>
- inline ra::Small<std::decay_t<decltype(FLAT(std::declval<A>()) * FLAT(std::declval<B>()))>, A::size(0), B::size(1)>
- gemm(A const & a, B const & b)
- {
-   constexpr int M = A::size(0);
-   constexpr int N = B::size(1);
-   // no with_same_shape b/c cannot index 0 for type if A/B are empty
-   auto c = with_shape<MMTYPE>({M, N}, ra::none);
-   for (int i=0; i<M; ++i) {
-     for (int j=0; j<N; ++j) {
-       c(i, j) = dot(a(i), b(ra::all, j));
-     }
-   }
-   return c;
- }
- #undef MMTYPE
- // Vector times matrix: starts from a zero-filled result of shape
- // {b.size(1)} and adds a[row]*b(row) for every row index below b.size(0).
- template <class A, class B>
- inline auto
- gevm(A const & a, B const & b)
- {
-   int const nrows = b.size(0);
-   int const ncols = b.size(1);
-   // no with_same_shape b/c cannot index 0 for type if A/B are empty
-   auto result = with_shape<decltype(a[0]*b(0, ra::all))>({ncols}, 0);
-   for (int row=0; row<nrows; ++row) {
-     result += a[row]*b(row);
-   }
-   return result;
- }
- // FIXME a must be a view, so it doesn't work with e.g. gemv(conj(a), b).
- // Matrix times vector: starts from a zero-filled result of shape
- // {a.size(0)} and adds a(ra::all, col) * b[col] for every col index below
- // a.size(1).
- template <class A, class B>
- inline auto
- gemv(A const & a, B const & b)
- {
-   int const nrows = a.size(0);
-   int const ncols = a.size(1);
-   // no with_same_shape b/c cannot index 0 for type if A/B are empty
-   auto result = with_shape<decltype(a(ra::all, 0)*b[0])>({nrows}, 0);
-   for (int col=0; col<ncols; ++col) {
-     result += a(ra::all, col) * b[col];
-   }
-   return result;
- }
- } // namespace ra
- #undef OPTIMIZE // defined near the top of this header; removed so it does not leak to includers
- #undef CHECK_BOUNDS // likewise defined near the top of this header
|