test-wrank.C 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. // (c) Daniel Llorens - 2013-2015
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file test-wrank.C
  7. /// @brief Checks for ra:: arrays, especially cell rank > 0 operations.
  8. #include <iostream>
  9. #include <sstream>
  10. #include <iterator>
  11. #include <numeric>
  12. #include <atomic>
  13. #include "ra/mpdebug.H"
  14. #include "ra/complex.H"
  15. #include "ra/format.H"
  16. #include "ra/test.H"
  17. #include "ra/big.H"
  18. #include "ra/wrank.H"
  19. #include "ra/operators.H"
  20. #include "ra/io.H"
  21. using std::cout; using std::endl; using std::flush;
  22. using std::tuple; using real = double;
  23. using ra::dim_t;
  24. // Find the driver for given axis. This pattern is used in Ryn to find the size-giving argument for each axis.
  25. template <int iarg, class T>
  26. std::enable_if_t<(iarg==mp::len<std::decay_t<T>>), int>
  27. constexpr driver(T && t, int k)
  28. {
  29. assert(0 && "there was no driver"); abort();
  30. }
  31. template <int iarg, class T>
  32. std::enable_if_t<(iarg<mp::len<std::decay_t<T>>), int>
  33. constexpr driver(T && t, int k)
  34. {
  35. dim_t s = std::get<iarg>(t).size(k);
  36. return s>=0 ? iarg : driver<iarg+1>(t, k);
  37. }
// DebugFrameMatch<FM>: compile-time summary of an ra::Framematch type FM, for printing by framematch_demo.
// This primary template handles a terminal Framematch, one with no nested FM::FM. The specialization
// below handles the nested case and recurses into FM::FM.
// NOTE(review): the meanings of FM::live, FM::R_ and FM::driver are taken from their use here only;
// confirm against ra/wrank.H.
template <class FM, class Enable=void> struct DebugFrameMatch
{
    constexpr static bool terminal = true;
    // Rank list of the frame match and its depth, copied straight from FM.
    using R = typename FM::R;
    constexpr static int depth = FM::depth;
    // One entry per level: the driving argument of this (only) level.
    using framedrivers = mp::int_list<FM::driver>;
    // FM::driver repeated once per live axis of the driving argument (presumably MakeList_<N, T> = N copies of T).
    using axisdrivers = mp::MakeList_<mp::Ref_<typename FM::live, FM::driver>::value, mp::int_t<FM::driver>>;
    // Axis numbers within the driving argument for those live axes. Assumes Iota_<n, start> yields n
    // consecutive ints from start, with start = len of the driver's entry in FM::R_ — TODO confirm.
    using axisaxes = mp::Iota_<mp::Ref_<typename FM::live, FM::driver>::value, mp::len<mp::Ref_<typename FM::R_, FM::driver>>>;
    // (argument, axis) pairs formed by zipping the two lists above.
    using argindices = mp::Zip_<axisdrivers, axisaxes>;
};
// DebugFrameMatch specialization for a non-terminal Framematch, selected when FM::FM exists
// (a nested Framematch, e.g. from a wrank verb wrapping another wrank verb).
// Each list is this level's contribution followed by the nested level's, computed recursively.
template <class FM> struct DebugFrameMatch<FM, std::enable_if_t<mp::exists<typename FM::FM> > >
{
    // The nested Framematch and its debug summary.
    using FMC = typename FM::FM;
    using DFMC = DebugFrameMatch<FMC>;
    constexpr static bool terminal = false;
    using R = typename FM::R;
    constexpr static int depth = FM::depth;
    // This level's driver prepended to the nested levels' drivers.
    using framedrivers = mp::Cons_<mp::int_t<FM::driver>, typename DFMC::framedrivers>;
    // Same per-level construction as the primary template, with the nested lists appended.
    using axisdrivers = mp::Append_<mp::MakeList_<mp::Ref_<typename FM::live, FM::driver>::value, mp::int_t<FM::driver>>,
                                    typename DFMC::axisdrivers>;
    using axisaxes = mp::Append_<mp::Iota_<mp::Ref_<typename FM::live, FM::driver>::value, mp::len<mp::Ref_<typename FM::R_, FM::driver>>>,
                                 typename DFMC::axisaxes>;
    // (argument, axis) pairs over all levels.
    using argindices = mp::Zip_<axisdrivers, axisaxes>;
};
  62. template <class V, class A, class B>
  63. void framematch_demo(V && v, A && a, B && b)
  64. {
  65. using FM = ra::Framematch<std::decay_t<V>, tuple<decltype(a.iter()), decltype(b.iter())>>;
  66. using DFM = DebugFrameMatch<FM>;
  67. cout << "FM is terminal: " << DFM::terminal << endl;
  68. cout << "width of fm: " << mp::len<typename DFM::R> << ", depth: " << DFM::depth << endl;
  69. cout << "FM::R: " << mp::print_int_list<typename DFM::R> {} << endl;
  70. cout << "FM::framedrivers: " << mp::print_int_list<typename DFM::framedrivers> {} << endl;
  71. cout << "FM::axisdrivers: " << mp::print_int_list<typename DFM::axisdrivers> {} << endl;
  72. cout << "FM::axisaxes: " << mp::print_int_list<typename DFM::axisaxes> {} << endl;
  73. cout << "FM::argindices: " << mp::print_int_list<typename DFM::argindices> {} << endl;
  74. cout << endl;
  75. }
  76. template <class V, class A, class B>
  77. void nested_wrank_demo(V && v, A && a, B && b)
  78. {
  79. std::iota(a.begin(), a.end(), 10);
  80. std::iota(b.begin(), b.end(), 1);
  81. {
  82. using FM = ra::Framematch<V, tuple<decltype(a.iter()), decltype(b.iter())>>;
  83. cout << "width of fm: " << mp::len<typename FM::R> << ", depth: " << FM::depth << endl;
  84. cout << mp::print_int_list<typename FM::R> {} << endl;
  85. auto af0 = ra::applyframes<mp::Ref_<typename FM::R, 0>, FM::depth>::f(a.iter());
  86. auto af1 = ra::applyframes<mp::Ref_<typename FM::R, 1>, FM::depth>::f(b.iter());
  87. cout << sizeof(af0) << endl;
  88. cout << sizeof(af1) << endl;
  89. {
  90. auto ryn = ra::ryn<FM>(FM::op(v), af0, af1);
  91. cout << sizeof(ryn) << endl;
  92. cout << "ryn rank: " << ryn.rank() << endl;
  93. for (int k=0; k<ryn.rank(); ++k) {
  94. cout << ryn.size(k) << ": " << driver<0>(ryn.t, k) << endl;
  95. }
  96. // cout << mp::show_type<decltype(ra::ryn<FM>(FM::op(v), af0, af1))>::value << endl;
  97. cout << "\nusing (ryn &):\n";
  98. ra::ply_ravel(ryn);
  99. cout << endl;
  100. cout << "\nusing (ryn &&):\n";
  101. ra::ply_ravel(ra::ryn<FM>(FM::op(v), af0, af1));
  102. }
  103. {
  104. // cout << mp::show_type<decltype(ra::expr(v, a.iter(), b.iter()))>::value << endl;
  105. auto ryn = ra::expr(v, a.iter(), b.iter());
  106. cout << "ryn.shape(): " << ra::format_array(ryn.shape(), false) << endl;
  107. #define TEST(plier) \
  108. cout << "\n\nusing " STRINGIZE(plier) " (ryn &):\n"; \
  109. ra::plier(ryn); \
  110. cout << "\n\nusing " STRINGIZE(plier) " ply (ryn &&):\n"; \
  111. ra::plier(ra::expr(v, a.iter(), b.iter()));
  112. TEST(ply_ravel);
  113. TEST(ply_index);
  114. TEST(plyf);
  115. TEST(plyf_index);
  116. }
  117. cout << "\n\n" << endl;
  118. }
  119. }
// Test driver: each tr.section(...) block exercises one aspect of ra::wrank / Framematch / ryn.
// Results are recorded in the TestRecorder tr, and tr.summary() is returned as the exit status.
int main()
{
    TestRecorder tr;
    auto plus2real = [](real a, real b) { return a + b; };
    // Print the entries of the rank list R of a wrank verb and of a wrank of that verb.
    tr.section("declaring verbs");
    {
        auto v = ra::wrank<0, 1>(plus2real);
        cout << mp::Ref_<decltype(v)::R, 0>::value << endl;
        cout << mp::Ref_<decltype(v)::R, 1>::value << endl;
        auto vv = ra::wrank<1, 1>(v);
        cout << mp::Ref_<decltype(vv)::R, 0>::value << endl;
        cout << mp::Ref_<decltype(vv)::R, 1>::value << endl;
    }
    // Inspect Framematch for several verbs on two 3x2 arrays, then run one by hand through ryn.
    tr.section("using Framematch");
    {
        ra::Unique<real, 2> a({3, 2}, ra::unspecified);
        ra::Unique<real, 2> b({3, 2}, ra::unspecified);
        std::iota(a.begin(), a.end(), 10);
        std::iota(b.begin(), b.end(), 1);
        {
            framematch_demo(plus2real, a, b);
            framematch_demo(ra::wrank<0, 0>(plus2real), a, b);
            framematch_demo(ra::wrank<0, 1>(plus2real), a, b);
            framematch_demo(ra::wrank<1, 0>(plus2real), a, b);
            framematch_demo(ra::wrank<1, 1>(plus2real), a, b);
        }
        // Despite the name, this prints a - b.
        auto plus2real_print = [](real a, real b) { cout << (a - b) << " "; };
        {
            auto v = ra::wrank<0, 2>(plus2real_print);
            using FM = ra::Framematch<decltype(v), tuple<decltype(a.iter()), decltype(b.iter())>>;
            cout << "width of fm: " << mp::len<FM::R> << ", depth: " << FM::depth << endl;
            cout << mp::print_int_list<FM::R> {} << endl;
            auto af0 = ra::applyframes<mp::Ref_<FM::R, 0>, FM::depth>::f(a.iter());
            auto af1 = ra::applyframes<mp::Ref_<FM::R, 1>, FM::depth>::f(b.iter());
            cout << sizeof(af0) << endl;
            cout << sizeof(af1) << endl;
            auto ryn = ra::ryn<FM>(FM::op(v), af0, af1);
            cout << sizeof(ryn) << "\n" << endl;
            cout << "ryn rank: " << ryn.rank() << endl;
            // Size of each axis and the index of the argument that gives it.
            for (int k=0; k<ryn.rank(); ++k) {
                cout << ryn.size(k) << ": " << driver<0>(ryn.t, k) << endl;
            }
            ra::ply_ravel(ryn);
        }
    }
    // The following sections only print; they check that each verb/shape combination compiles and runs.
    tr.section("wrank tests 0-1");
    {
        auto minus2real_print = [](real a, real b) { cout << (a - b) << " "; };
        nested_wrank_demo(ra::wrank<0, 1>(minus2real_print),
                          ra::Unique<real, 1>({3}, ra::unspecified),
                          ra::Unique<real, 1>({4}, ra::unspecified));
        nested_wrank_demo(ra::wrank<0, 1>(ra::wrank<0, 0>(minus2real_print)),
                          ra::Unique<real, 1>({3}, ra::unspecified),
                          ra::Unique<real, 1>({3}, ra::unspecified));
    }
    tr.section("wrank tests 1-0");
    {
        auto minus2real_print = [](real a, real b) { cout << (a - b) << " "; };
        nested_wrank_demo(ra::wrank<1, 0>(minus2real_print),
                          ra::Unique<real, 1>({3}, ra::unspecified),
                          ra::Unique<real, 1>({4}, ra::unspecified));
        nested_wrank_demo(ra::wrank<1, 0>(ra::wrank<0, 0>(minus2real_print)),
                          ra::Unique<real, 1>({3}, ra::unspecified),
                          ra::Unique<real, 1>({4}, ra::unspecified));
    }
    tr.section("wrank tests 0-0 (nop), case 1 - exact match");
    {
        // This uses the applyframes specialization for 'do nothing' (TODO if there's one).
        auto minus2real_print = [](real a, real b) { cout << (a - b) << " "; };
        nested_wrank_demo(ra::wrank<0, 0>(minus2real_print),
                          ra::Unique<real, 1>({3}, ra::unspecified),
                          ra::Unique<real, 1>({3}, ra::unspecified));
    }
    tr.section("wrank tests 0-0 (nop), case 2 - non-exact frame match");
    {
        // This uses the applyframes specialization for 'do nothing' (TODO if there's one).
        auto minus2real_print = [](real a, real b) { cout << (a - b) << " "; };
        nested_wrank_demo(ra::wrank<0, 0>(minus2real_print),
                          ra::Unique<real, 2>({3, 4}, ra::unspecified),
                          ra::Unique<real, 1>({3}, ra::unspecified));
        nested_wrank_demo(ra::wrank<0, 0>(minus2real_print),
                          ra::Unique<real, 1>({3}, ra::unspecified),
                          ra::Unique<real, 2>({3, 4}, ra::unspecified));
    }
    // Outer product a(i)-b(j) written three ways, checked against hand-computed tables.
    tr.section("wrank tests 1-1-0, init array with outer product");
    {
        auto minus2real = [](real & c, real a, real b) { c = a-b; };
        ra::Unique<real, 1> a({3}, ra::unspecified);
        ra::Unique<real, 1> b({4}, ra::unspecified);
        std::iota(a.begin(), a.end(), 10);
        std::iota(b.begin(), b.end(), 1);
        cout << "a: " << a << endl;
        cout << "b: " << b << endl;
        ra::Unique<real, 2> c({3, 4}, ra::unspecified);
        // Write into c through a 3-argument verb: c cells of rank 1, a cells of rank 0, b whole.
        ra::ply(ra::expr(ra::wrank<1, 0, 1>(minus2real), c.iter(), a.iter(), b.iter()));
        cout << "c: " << c << endl;
        real checkc34[3*4] = { /* 10-[1 2 3 4] */ 9, 8, 7, 6,
                               /* 11-[1 2 3 4] */ 10, 9, 8, 7,
                               /* 12-[1 2 3 4] */ 11, 10, 9, 8 };
        tr.test(std::equal(checkc34, checkc34+3*4, c.begin()));
        // Same result built directly from the expression; the frame comes from a (3), so the shape is 3x4.
        ra::Unique<real, 2> d34(ra::expr(ra::wrank<0, 1>(std::minus<real>()), a.iter(), b.iter()));
        cout << "d34: " << d34 << endl;
        tr.test(std::equal(checkc34, checkc34+3*4, d34.begin()));
        real checkc43[3*4] = { /* [10 11 12]-1 */ 9, 10, 11,
                               /* [10 11 12]-2 */ 8, 9, 10,
                               /* [10 11 12]-3 */ 7, 8, 9,
                               /* [10 11 12]-4 */ 6, 7, 8 };
        // Swapped cell ranks: the frame comes from b (4), so the result is 4x3.
        ra::Unique<real, 2> d43(ra::expr(ra::wrank<1, 0>(std::minus<real>()), a.iter(), b.iter()));
        cout << "d43: " << d43 << endl;
        tr.test(d43.size(0)==4 && d43.size(1)==3);
        tr.test(std::equal(checkc43, checkc43+3*4, d43.begin()));
    }
    // Outer subscripting c(a(i), b(j)) through a wrank verb, used both as rvalue and as lvalue.
    tr.section("recipe for unbeatable subscripts in _from_ operator");
    {
        ra::Unique<int, 1> a({3}, ra::unspecified);
        ra::Unique<int, 1> b({4}, ra::unspecified);
        std::iota(a.begin(), a.end(), 10);
        std::iota(b.begin(), b.end(), 1);
        ra::Unique<real, 2> c({100, 100}, ra::unspecified);
        // Row-major iota over a 100x100 array gives c(i, j) == 100*i + j.
        std::iota(c.begin(), c.end(), 0);
        real checkd[3*4] = { 1001, 1002, 1003, 1004, 1101, 1102, 1103, 1104, 1201, 1202, 1203, 1204 };
        // default auto is value, so need to specify.
        // (decltype(auto) makes the lambda return c(a, b) as is, so EXPR can be assigned to below.)
        // NOTE(review): EXPR and BENCH are never #undef'd.
#define EXPR ra::expr(ra::wrank<0, 1>([&c](int a, int b) -> decltype(auto) { return c(a, b); } ), \
        a.iter(), b.iter())
        // Round-trip the expression through text I/O and compare.
        std::ostringstream os;
        os << EXPR << endl;
        ra::Unique<real, 2> cc {};
        std::istringstream is(os.str());
        is >> cc;
        cout << "cc: " << cc << endl;
        tr.test(std::equal(checkd, checkd+3*4, cc.begin()));
        ra::Unique<real, 2> d(EXPR);
        cout << "d: " << d << endl;
        tr.test(std::equal(checkd, checkd+3*4, d.begin()));
        // Using expr as lvalue.
        EXPR = 7.;
        cout << EXPR << endl;
        // expr-way BUG use of test_eq fails (??)
        // NOTE(review): assert() is compiled out under NDEBUG and is not recorded by tr;
        // the loop below repeats the same check through tr.test.
        assert(every(c==where(ra::_0>=10 && ra::_0<=12 && ra::_1>=1 && ra::_1<=4, 7, ra::_0*100+ra::_1)));
        // looping...
        bool valid = true;
        for (int i=0; i<c.size(0); ++i) {
            for (int j=0; j<c.size(1); ++j) {
                valid = valid && ((i>=10 && i<=12 && j>=1 && j<=4 ? 7 : i*100+j) == c(i, j));
            }
        }
        tr.test(valid);
    }
    // Placeholder section; no checks yet.
    tr.section("rank conjunction / empty");
    {
    }
    // The rank of the expression is usable at compile time. Sum of 8 products 1*2 == 16.
    tr.section("static rank() in ra::Ryn");
    {
        ra::Unique<real, 3> a({2, 2, 2}, 1.);
        ra::Unique<real, 3> b({2, 2, 2}, 2.);
        real y = 0;
        auto e = ra::expr(ra::wrank<0, 0>([&y](real const a, real const b) { y += a*b; }), a.iter(), b.iter());
        static_assert(3==e.rank(), "bad rank in static rank expr");
        ra::ply_ravel(ra::expr(ra::wrank<0, 0>([&y](real const a, real const b) { y += a*b; }), a.iter(), b.iter()));
        tr.test_eq(16, y);
    }
    // Matrix product via nested wrank verbs, compared with gemm.
    tr.section("outer product variants");
    {
        ra::Big<real, 2> a({2, 3}, ra::_0 - ra::_1);
        ra::Big<real, 2> b({3, 2}, ra::_1 - 2*ra::_0);
        ra::Big<real, 2> c1 = gemm(a, b);
        cout << "matrix a * b: \n" << c1 << endl;
        // matrix product as outer product + reduction (no reductions yet, so manually).
        {
            ra::Big<real, 3> d = ra::expr(ra::wrank<1, 2>(ra::wrank<0, 1>(ra::times())), start(a), start(b));
            cout << "d(i,k,j) = a(i,k)*b(k,j): \n" << d << endl;
            ra::Big<real, 2> c2({d.size(0), d.size(2)}, 0.);
            for (int k=0; k<d.size(1); ++k) {
                c2 += d(ra::all, k, ra::all);
            }
            tr.test_eq(c1, c2);
        }
        // do the k-reduction by plying with wrank.
        {
            ra::Big<real, 2> c2({a.size(0), b.size(1)}, 0.);
            ra::ply(ra::expr(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto & c, auto && a, auto && b) { c += a*b; })),
                             start(c2), start(a), start(b)));
            cout << "sum_k a(i,k)*b(k,j): \n" << c2 << endl;
            tr.test_eq(c1, c2);
        }
    }
    // 5-point Laplacian stencil: explicit loops (f_raw) vs. a rank-2 verb over a stencil view (f_sumprod).
    tr.section("stencil test for ApplyFrames::keep_stride. Reduced from test/bench-stencil2.C");
    {
        int nx = 4;
        int ny = 4;
        int ts = 4; // must be even b/c of swap
        // Interior index ranges 1..n-2.
        auto I = ra::iota(nx-2, 1);
        auto J = ra::iota(ny-2, 1);
        constexpr ra::Small<real, 3, 3> mask = { 0, 1, 0,
                                                 1, -4, 1,
                                                 0, 1, 0 };
        real value = 1;
        auto f_raw = [&](ra::View<real, 2> & A, ra::View<real, 2> & Anext, ra::View<real, 4> & Astencil)
        {
            for (int t=0; t<ts; ++t) {
                for (int i=1; i+1<nx; ++i) {
                    for (int j=1; j+1<ny; ++j) {
                        Anext(i, j) = -4*A(i, j)
                            + A(i+1, j) + A(i, j+1)
                            + A(i-1, j) + A(i, j-1);
                    }
                }
                // Swap buffers by swapping data pointers.
                std::swap(A.p, Anext.p);
            }
        };
        auto f_sumprod = [&](ra::View<real, 2> & A, ra::View<real, 2> & Anext, ra::View<real, 4> & Astencil)
        {
            for (int t=0; t!=ts; ++t) {
                // Re-point the stencil view at A's current buffer, which changes with each swap.
                Astencil.p = A.data();
                Anext(I, J) = 0; // TODO miss notation for sum-of-axes without preparing destination...
                Anext(I, J) += map(ra::wrank<2, 2>(ra::times()), Astencil, mask);
                std::swap(A.p, Anext.p);
            }
        };
        // Reset A and Anext, run f, and compare A against ref.
        auto bench = [&](auto & A, auto & Anext, auto & Astencil, auto && ref, auto && tag, auto && f)
        {
            A = value;
            Anext = 0.;
            f(A, Anext, Astencil);
            tr.info(tag).test_rel_error(ref, A, 1e-11);
        };
        ra::Big<real, 2> Aref;
        ra::Big<real, 2> A({nx, ny}, 1.);
        ra::Big<real, 2> Anext({nx, ny}, 0.);
        auto Astencil = stencil(A, 1, 1);
        cout << "Astencil " << format_array(Astencil(0, 0, ra::dots<2>), true, "|", " ") << endl;
#define BENCH(ref, op) bench(A, Anext, Astencil, ref, STRINGIZE(op), op);
        // The first run compares A with itself, so it only produces the reference; Aref copies it.
        BENCH(A, f_raw);
        Aref = ra::Big<real, 2>(A);
        BENCH(Aref, f_sumprod);
    }
    // from() builds an outer table f(i, j) from 1-D arguments.
    tr.section("Iota with dead axes");
    {
        ra::Big<int, 2> a = from([](auto && i, auto && j) { return i-j; }, ra::iota(3), ra::iota(3));
        tr.test_eq(ra::Big<int, 2>({3, 3}, {0, -1, -2, 1, 0, -1, 2, 1, 0}), a);
    }
    tr.section("Vector with dead axes");
    {
        std::vector<int> i = {0, 1, 2};
        ra::Big<int, 2> a = ra::from([](auto && i, auto && j) { return i-j; }, i, i);
        tr.test_eq(ra::Big<int, 2>({3, 3}, {0, -1, -2, 1, 0, -1, 2, 1, 0}), a);
    }
    tr.section("no arguments -> zero rank");
    {
        int x = ra::from([]() { return 3; });
        tr.test_eq(3, x);
    }
    // Count how often the mapped argument functions are called; one call per element would give 7 and 9.
    // The count checks are marked FIXME and skipped.
    tr.section("counting ops");
    {
        std::atomic<int> i { 0 };
        auto fi = [&i](auto && x) { ++i; return x; };
        std::atomic<int> j { 0 };
        auto fj = [&j](auto && x) { ++j; return x; };
        ra::Big<int, 2> a = from(ra::minus(), map(fi, ra::iota(7)), map(fj, ra::iota(9)));
        tr.test_eq(ra::_0-ra::_1, a);
        tr.info("FIXME").skip().test_eq(7, int(i));
        tr.info("FIXME").skip().test_eq(9, int(j));
    }
    return tr.summary();
}