// -*- mode: c++; coding: utf-8 -*-
/// @file view-ops.hh
/// @brief Operations specific to Views
// (c) Daniel Llorens - 2013-2014, 2017
// This library is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.
#pragma once
#include "ra/concrete.hh"
#include <complex>
namespace ra {
  13. template <class P, rank_t RANK> inline
  14. View<P, RANK> reverse(View<P, RANK> const & view, int k)
  15. {
  16. View<P, RANK> r = view;
  17. auto & dim = r.dim[k];
  18. if (dim.size!=0) {
  19. r.p += dim.stride*(dim.size-1);
  20. dim.stride *= -1;
  21. }
  22. return r;
  23. }
  24. // dynamic transposed axes list.
  25. template <class P, rank_t RANK, class S> inline
  26. View<P, RANK_ANY> transpose_(S && s, View<P, RANK> const & view)
  27. {
  28. RA_CHECK(view.rank()==dim_t(ra::size(s)));
  29. auto rp = std::max_element(s.begin(), s.end());
  30. rank_t dstrank = (rp==s.end() ? 0 : *rp+1);
  31. View<P, RANK_ANY> r { decltype(r.dim)(dstrank, Dim { DIM_BAD, 0 }), view.data() };
  32. for (int k=0; int sk: s) {
  33. Dim & dest = r.dim[sk];
  34. dest.stride += view.dim[k].stride;
  35. dest.size = dest.size>=0 ? std::min(dest.size, view.dim[k].size) : view.dim[k].size;
  36. ++k;
  37. }
  38. return r;
  39. }
  40. template <class P, rank_t RANK, class S> inline
  41. View<P, RANK_ANY> transpose(S && s, View<P, RANK> const & view)
  42. {
  43. return transpose_(std::forward<S>(s), view);
  44. }
  45. // Note that we need the compile time values and not the sizes to deduce the rank of the output, so it would be useless to provide a builtin array shim as we do with reshape().
  46. template <class P, rank_t RANK> inline
  47. View<P, RANK_ANY> transpose(std::initializer_list<ra::rank_t> s, View<P, RANK> const & view)
  48. {
  49. return transpose_(s, view);
  50. }
// static transposed axes list.
// Like the dynamic transpose_(), but the axes list is given as template
// arguments, so the destination rank is a compile time constant.
template <int ... Iarg, class P, rank_t RANK> inline
auto transpose(View<P, RANK> const & view)
{
    // TODO Reduce to single check, either the first or the second.
    static_assert(RANK==RANK_ANY || RANK==sizeof...(Iarg), "bad output rank");
    RA_CHECK((view.rank()==sizeof...(Iarg)) && "bad output rank");
    // Compute the destination rank at compile time from the axes list.
    using dummy_s = mp::makelist<sizeof...(Iarg), mp::int_t<0>>;
    using ti = axes_list_indices<mp::int_list<Iarg ...>, dummy_s, dummy_s>;
    constexpr rank_t DSTRANK = mp::len<typename ti::dst>;
    // {DIM_BAD, 0} marks an uninitialized destination axis (size<0), so the
    // first source axis mapping there sets the size; strides accumulate from 0.
    View<P, DSTRANK> r { decltype(r.dim)(Dim { DIM_BAD, 0 }), view.data() };
    std::array<int, sizeof...(Iarg)> s {{ Iarg ... }};
    for (int k=0; int sk: s) {
        // Repeated axes take min size / summed strides, cf diag() below.
        Dim & dest = r.dim[sk];
        dest.stride += view.dim[k].stride;
        dest.size = dest.size>=0 ? std::min(dest.size, view.dim[k].size) : view.dim[k].size;
        ++k;
    }
    return r;
}
  70. template <class P, rank_t RANK> inline
  71. auto diag(View<P, RANK> const & view)
  72. {
  73. return transpose<0, 0>(view);
  74. }
  75. template <class P, rank_t RANK> inline
  76. bool is_ravel_free(View<P, RANK> const & a)
  77. {
  78. int r = a.rank()-1;
  79. for (; r>=0 && a.size(r)==1; --r) {}
  80. if (r<0) { return true; }
  81. ra::dim_t s = a.stride(r)*a.size(r);
  82. while (--r>=0) {
  83. if (1!=a.size(r)) {
  84. if (a.stride(r)!=s) {
  85. return false;
  86. }
  87. s *= a.size(r);
  88. }
  89. }
  90. return true;
  91. }
  92. template <class P, rank_t RANK> inline
  93. View<P, 1> ravel_free(View<P, RANK> const & a)
  94. {
  95. RA_CHECK(is_ravel_free(a));
  96. int r = a.rank()-1;
  97. for (; r>=0 && a.size(r)==1; --r) {}
  98. ra::dim_t s = r<0 ? 1 : a.stride(r);
  99. return ra::View<P, 1>({{size(a), s}}, a.p);
  100. }
  101. template <class P, rank_t RANK, class S> inline
  102. auto reshape_(View<P, RANK> const & a, S && sb_)
  103. {
  104. auto sb = concrete(std::forward<S>(sb_));
  105. // FIXME when we need to copy, accept/return Shared
  106. dim_t la = ra::size(a);
  107. dim_t lb = 1;
  108. for (int i=0; i<ra::size(sb); ++i) {
  109. if (sb[i]==-1) {
  110. dim_t quot = lb;
  111. for (int j=i+1; j<ra::size(sb); ++j) {
  112. quot *= sb[j];
  113. RA_CHECK(quot>0 && "cannot deduce dimensions");
  114. }
  115. auto pv = la/quot;
  116. RA_CHECK((la%quot==0 && pv>=0) && "bad placeholder");
  117. sb[i] = pv;
  118. lb = la;
  119. break;
  120. } else {
  121. lb *= sb[i];
  122. }
  123. }
  124. auto sa = shape(a);
  125. // FIXME should be able to reshape Scalar etc.
  126. View<P, ra::size_s(sb)> b(map([](auto i) { return Dim { DIM_BAD, 0 }; }, ra::iota(ra::size(sb))), a.data());
  127. rank_t i = 0;
  128. for (; i<a.rank() && i<b.rank(); ++i) {
  129. if (sa[a.rank()-i-1]!=sb[b.rank()-i-1]) {
  130. assert(is_ravel_free(a) && "reshape w/copy not implemented");
  131. if (la>=lb) {
  132. // FIXME View(SS const & s, P p). Cf [ra37].
  133. for_each([](auto & dim, auto && s) { dim.size = s; }, b.dim, sb);
  134. filldim(b.dim.size(), b.dim.end());
  135. for (int j=0; j!=b.rank(); ++j) {
  136. b.dim[j].stride *= a.stride(a.rank()-1);
  137. }
  138. return b;
  139. } else {
  140. assert(0 && "reshape case not implemented");
  141. }
  142. } else {
  143. // select
  144. b.dim[b.rank()-i-1] = a.dim[a.rank()-i-1];
  145. }
  146. }
  147. if (i==a.rank()) {
  148. // tile & return
  149. for (rank_t j=i; j<b.rank(); ++j) {
  150. b.dim[b.rank()-j-1] = { sb[b.rank()-j-1], 0 };
  151. }
  152. }
  153. return b;
  154. }
  155. template <class P, rank_t RANK, class S> inline
  156. auto reshape(View<P, RANK> const & a, S && sb_)
  157. {
  158. return reshape_(a, std::forward<S>(sb_));
  159. }
  160. // We need dimtype b/c {1, ...} deduces to int and that fails to match ra::dim_t.
  161. // We could use initializer_list to handle the general case, but that would produce a var rank result because its size cannot be deduced at compile time :-/. Unfortunately an initializer_list specialization would override this one, so we cannot provide it as a fallback.
  162. template <class P, rank_t RANK, class dimtype, int N> inline
  163. auto reshape(View<P, RANK> const & a, dimtype const (&sb_)[N])
  164. {
  165. return reshape_(a, sb_);
  166. }
  167. // lo = lower bounds, hi = upper bounds.
  168. // The stencil indices are in [0 lo+1+hi] = [-lo +hi].
  169. template <class LO, class HI, class P, rank_t N> inline
  170. View<P, rank_sum(N, N)>
  171. stencil(View<P, N> const & a, LO && lo, HI && hi)
  172. {
  173. View<P, rank_sum(N, N)> s;
  174. s.p = a.data();
  175. ra::resize(s.dim, 2*a.rank());
  176. RA_CHECK(every(lo>=0));
  177. RA_CHECK(every(hi>=0));
  178. for_each([](auto & dims, auto && dima, auto && lo, auto && hi)
  179. {
  180. RA_CHECK(dima.size>=lo+hi && "stencil is too large for array");
  181. dims = {dima.size-lo-hi, dima.stride};
  182. },
  183. ptr(s.dim.data()), a.dim, lo, hi);
  184. for_each([](auto & dims, auto && dima, auto && lo, auto && hi)
  185. { dims = {lo+hi+1, dima.stride}; },
  186. ptr(s.dim.data()+a.rank()), a.dim, lo, hi);
  187. return s;
  188. }
// Make last sizes of View<> be compile-time constants.
// Reinterprets a View<T*> as a View<super_t*>, folding the last SUPERR axes of
// `a` into single super_t elements. Those axes must exactly cover one super_t
// and be compact in memory (checked below).
template <class super_t, rank_t SUPERR, class T, rank_t RANK> inline
auto explode_(View<T *, RANK> const & a)
{
    // TODO Reduce to single check, either the first or the second.
    static_assert(RANK>=SUPERR || RANK==RANK_ANY, "rank of a is too low");
    RA_CHECK(a.rank()>=SUPERR && "rank of a is too low");
    View<super_t *, rank_sum(RANK, -SUPERR)> b;
    ra::resize(b.dim, a.rank()-SUPERR);
    // r = number of T elements spanned by the folded (last SUPERR) axes.
    dim_t r = 1;
    for (int i=0; i<SUPERR; ++i) {
        r *= a.size(i+b.rank());
    }
    RA_CHECK(r*sizeof(T)==sizeof(super_t) && "size of SUPERR axes doesn't match super type");
    // Kept axes keep their sizes; strides are rescaled from T to super_t units.
    for (int i=0; i<b.rank(); ++i) {
        RA_CHECK(a.stride(i) % r==0 && "stride of SUPERR axes doesn't match super type");
        b.dim[i].stride = a.stride(i) / r;
        b.dim[i].size = a.size(i);
    }
    // The innermost kept axis must step exactly one super_t per element.
    RA_CHECK((b.rank()==0 || a.stride(b.rank()-1)==r) && "super type is not compact in array");
    b.p = reinterpret_cast<super_t *>(a.data());
    return b;
}
  212. template <class super_t, class T, rank_t RANK> inline
  213. auto explode(View<T *, RANK> const & a)
  214. {
  215. return explode_<super_t, (std::is_same_v<super_t, std::complex<T>> ? 1 : rank_s<super_t>())>(a);
  216. }
  217. // FIXME Consider these in as namespace level generics in atom.hh
  218. template <class T> inline int gstride(int i) { if constexpr (is_scalar<T>) return 1; else return T::stride(i); }
  219. template <class T> inline int gsize(int i) { if constexpr (is_scalar<T>) return 1; else return T::size(i); }
// TODO This routine is not totally safe; the ranks below SUBR must be compact, which is not checked.
// Inverse of explode(): reinterpret a View<super_t*> as a View<sub_t*>,
// appending the SUBR axes that super_t contributes, plus one extra axis when
// super_t's scalar type packs several sub_t (e.g. splitting complex into its
// real components).
template <class sub_t, class super_t, rank_t RANK> inline
auto collapse(View<super_t *, RANK> const & a)
{
    using super_v = value_t<super_t>;
    using sub_v = value_t<sub_t>;
    // subtype>1 when one super scalar holds several sub_t.
    constexpr int subtype = sizeof(super_v)/sizeof(sub_t);
    constexpr int SUBR = rank_s<super_t>() - rank_s<sub_t>();
    View<sub_t *, rank_sum(RANK, SUBR+int(subtype>1))> b;
    resize(b.dim, a.rank()+SUBR+int(subtype>1));
    // r = sub_t elements per super_t; must divide exactly.
    constexpr dim_t r = sizeof(super_t)/sizeof(sub_t);
    static_assert(sizeof(super_t)==r*sizeof(sub_t), "cannot make axis of super_t from sub_t");
    // Original axes of a: strides rescaled from super_t to sub_t units.
    for (int i=0; i<a.rank(); ++i) {
        b.dim[i].stride = a.stride(i) * r;
        b.dim[i].size = a.size(i);
    }
    // t = sub scalars per super scalar; s = sub scalars per sub_t.
    constexpr int t = sizeof(super_v)/sizeof(sub_v);
    constexpr int s = sizeof(sub_t)/sizeof(sub_v);
    static_assert(t*sizeof(sub_v)>=1, "bad subtype");
    // Axes contributed by super_t itself, strides converted to sub_t units.
    for (int i=0; i<SUBR; ++i) {
        RA_CHECK(((gstride<super_t>(i)/s)*s==gstride<super_t>(i)) && "bad strides"); // TODO is actually static
        b.dim[a.rank()+i].stride = gstride<super_t>(i) / s * t;
        b.dim[a.rank()+i].size = gsize<super_t>(i);
    }
    // Extra innermost axis over the sub_t components of one super scalar.
    if (subtype>1) {
        b.dim[a.rank()+SUBR].stride = 1;
        b.dim[a.rank()+SUBR].size = t;
    }
    b.p = reinterpret_cast<sub_t *>(a.data());
    return b;
}
  251. // For functions that require compact arrays (TODO they really shouldn't).
  252. template <class A> inline
  253. bool const crm(A const & a)
  254. {
  255. return ra::size(a)==0 || is_c_order(a);
  256. }
} // namespace ra