123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- // -*- mode: c++; coding: utf-8 -*-
- /// @file wrank.H
- /// @brief Rank conjunction for expression templates.
- // (c) Daniel Llorens - 2013-2017, 2019
- // This library is free software; you can redistribute it and/or modify it under
- // the terms of the GNU Lesser General Public License as published by the Free
- // Software Foundation; either version 3 of the License, or (at your option) any
- // later version.
- #pragma once
- #include "ra/expr.H"
- #if defined(RA_CHECK_BOUNDS) && RA_CHECK_BOUNDS==0
- #define CHECK_BOUNDS( cond )
- #else
- #define CHECK_BOUNDS( cond ) RA_ASSERT( cond, 0 )
- #endif
- // TODO Adopt frame matching as in Match (no driver).
- namespace ra {
// An operation bundled with the cell ranks it should be applied at.
// cranks: type-level list of cell ranks, re-exported as the member type R.
// Op:     the wrapped operation; it may itself be a Verb, in which case the
//         frame matching below drills through the nested .op members.
template <class cranks, class Op>
struct Verb
{
    // The wrapped operation. This is the only data member, so Verb is
    // brace-initialized directly from the operation.
    Op op;
    // Cell ranks attached to op.
    using R = cranks;
};
- template <class cranks, class Op> inline constexpr
- auto wrank(cranks cranks_, Op && op)
- {
- return Verb<cranks, Op> { std::forward<Op>(op) };
- }
- template <rank_t ... crank, class Op> inline constexpr
- auto wrank(Op && op)
- {
- return Verb<mp::int_list<crank ...>, Op> { std::forward<Op>(op) };
- }
- template <class A> using ValidRank = mp::int_t<(A::value>=0)>;
- template <class R, class skip, class frank>
- using AddFrameAxes = mp::append<R, mp::iota<frank::value, skip::value>>;
- template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
- struct Framematch_def;
- template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
- using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
- // FIXME Replace the frame matching mechanism by driverless, per-axis Match
// Binary selector handed to mp::IndexOf: value is 0 when A wins, i.e. when
// gt_rank(A::value, B::value) holds, and 1 otherwise.
template <class A, class B>
struct max_i
{
    constexpr static bool first_wins = gt_rank(A::value, B::value);
    constexpr static int value = first_wins ? 0 : 1;
};
- template <class T> using largest_i_tuple = mp::IndexOf<max_i, T>;
- // Get a list (per argument) of lists of live axes.
- // The last frame match is not done; that relies on rest axis handling of each argument (ignoring axis spec beyond their own rank). TODO Reexamine that.
- // Case where V has rank.
- template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
- struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
- {
- static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "bad args");
- using T = std::tuple<Ti ...>;
- using R_ = std::tuple<Ri ...>;
- // TODO functions of arg rank, negative, inf.
- // live = number of live axes on this frame, for each argument.
- using live = mp::int_list<(std::decay_t<Ti>::rank_s()-mp::len<Ri>-crank::value) ...>;
- static_assert(mp::apply<mp::andb, mp::map<ValidRank, live>>::value, "bad ranks");
- // select driver for this stage.
- constexpr static int driver = largest_i_tuple<live>::value;
- // add actual axes to result.
- using skips = mp::makelist<sizeof...(Ti), mp::int_t<skip>>;
- using FM = Framematch<W, T, mp::map<AddFrameAxes, R_, skips, live>,
- skip + mp::ref<live, driver>::value>;
- using R = typename FM::R;
- constexpr static int depth = mp::ref<live, driver>::value + FM::depth;
- // drill down in V to get innermost Op (cf [ra31]).
- template <class VV> static decltype(auto) op(VV && v) { return FM::op(std::forward<VV>(v).op); }
- };
- // Terminal case where V doesn't have rank (is a raw op()).
- template <class V, class ... Ti, class ... Ri, rank_t skip>
- struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
- {
- using R_ = std::tuple<Ri ...>;
- // TODO -crank::value when the actual verb rank is used (e.g. to use cell_iterator<A, that_rank> instead of just begin()).
- using live = mp::int_list<(std::decay_t<Ti>::rank_s()-mp::len<Ri>) ...>;
- static_assert(mp::apply<mp::andb, mp::map<ValidRank, live>>::value, "bad ranks");
- constexpr static int driver = largest_i_tuple<live>::value;
- using skips = mp::makelist<sizeof...(Ti), mp::int_t<skip>>;
- using R = mp::map<AddFrameAxes, R_, skips, live>;
- constexpr static int depth = mp::ref<live, driver>::value;
- template <class VV> static decltype(auto) op(VV && v) { return std::forward<VV>(v); }
- };
// Produces a zero-valued stride of type T.
// General case: T is constructed from the literal 0.
template <class T>
struct zerostride
{
    constexpr static T f()
    {
        return T(0);
    }
};

// Tuple case: a tuple whose every element is the zero stride of its own type
// (applied recursively, so nested tuples are handled too).
template <class ... T>
struct zerostride<std::tuple<T ...>>
{
    constexpr static std::tuple<T ...> f()
    {
        return std::make_tuple(zerostride<T>::f() ...);
    }
};
- // Wraps each argument of an expression using wrank.
- template <class LiveAxes, int depth, class A>
- struct ApplyFrames
- {
- A a;
- constexpr static int live(int k) { return mp::int_list_index<LiveAxes>(k); }
- template <class I>
- constexpr decltype(auto) at(I const & i)
- {
- return a.at(mp::map_indices<std::array<dim_t, mp::len<LiveAxes>>, LiveAxes>(i));
- }
- constexpr dim_t size(int k) const
- {
- int l = live(k);
- return l>=0 ? a.size(l) : DIM_BAD;
- }
- constexpr void adv(rank_t k, dim_t d)
- {
- if (int l = live(k); l>=0) {
- a.adv(l, d);
- }
- }
- constexpr auto stride(int k) const
- {
- int l = live(k);
- return l>=0 ? a.stride(l) : zerostride<decltype(a.stride(l))>::f();
- }
- constexpr bool keep_stride(dim_t step, int z, int j) const
- {
- int wz = live(z);
- int wj = live(j);
- return wz>=0 && wj>=0 && a.keep_stride(step, wz, wj);
- }
- constexpr decltype(auto) flat() { return a.flat(); }
- constexpr static dim_t size_s(int k)
- {
- int l = live(k);
- return l>=0 ? std::decay_t<A>::size_s(l) : DIM_BAD;
- }
- constexpr static rank_t rank() { return depth; } // TODO Invalid for RANK_ANY [ra07]
- constexpr static rank_t rank_s() { return depth; } // TODO Invalid for RANK_ANY [ra07]
- };
- // No-op case. TODO Maybe apply to any Iota<n> where n<=depth.
- // TODO If A is cell_iterator, etc. beat LiveAxes directly on that... same for an eventual transpose_expr<>.
- template <class LiveAxes, int depth, class A>
- decltype(auto) applyframes(A && a)
- {
- if constexpr (std::is_same_v<LiveAxes, mp::iota<depth>>) {
- return std::forward<A>(a);
- } else {
- return ApplyFrames<LiveAxes, depth, A> { std::forward<A>(a) };
- }
- }
- template <class V, class ... T, int ... i> inline constexpr
- auto ryn(mp::int_list<i ...>, V && v, T && ... t)
- {
- using FM = Framematch<V, std::tuple<T ...>>;
- return expr(FM::op(std::forward<V>(v)),
- applyframes<mp::ref<typename FM::R, i>, FM::depth>(std::forward<T>(t)) ...);
- }
- // TODO partial specialization means no universal ref :-/
- #define DEF_EXPR_VERB(MOD) \
- template <class cranks, class Op, class ... P> inline constexpr \
- auto expr(Verb<cranks, Op> MOD v, P && ... t) \
- { \
- return ryn(mp::iota<sizeof...(P)> {}, std::forward<decltype(v)>(v), std::forward<P>(t) ...); \
- }
- FOR_EACH(DEF_EXPR_VERB, &&, &, const &)
- #undef DEF_EXPR_VERB
- // ---------------------------
- // from, after APL, like (from) in guile-ploy
- // TODO integrate with is_beatable shortcuts, operator() in the various array types.
- // ---------------------------
- template <class I>
- struct index_rank_
- {
- using type = mp::int_t<std::decay_t<I>::rank_s()>; // see ra_traits for ra::types (?)
- static_assert(type::value!=RANK_ANY, "dynamic rank unsupported");
- static_assert(size_s<I>()!=DIM_BAD, "undelimited extent subscript unsupported");
- };
- template <class I> using index_rank = typename index_rank_<I>::type;
- template <class II, int drop, class Op> inline constexpr
- decltype(auto) from_partial(Op && op)
- {
- if constexpr (drop==mp::len<II>) {
- return std::forward<Op>(op);
- } else {
- return wrank(mp::append<mp::makelist<drop, mp::int_t<0>>, mp::drop<II, drop>> {},
- from_partial<II, drop+1>(std::forward<Op>(op)));
- }
- }
- // TODO we should be able to do better by slicing at each dimension, etc. But verb<> only supports rank-0 for the innermost op.
- template <class A, class ... I> inline constexpr
- auto from(A && a, I && ... i)
- {
- if constexpr (0==sizeof...(i)) {
- return a();
- } else if constexpr (1==sizeof...(i)) {
- // support dynamic rank for 1 arg only (see test in test/from.C).
- return expr(std::forward<A>(a), start(std::forward<I>(i) ...));
- } else {
- using II = mp::map<index_rank, mp::tuple<decltype(start(std::forward<I>(i))) ...>>;
- return expr(from_partial<II, 1>(std::forward<A>(a)), start(std::forward<I>(i)) ...);
- }
- }
- } // namespace ra
- #undef CHECK_BOUNDS
|