from.C 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
// -*- mode: c++; coding: utf-8 -*-
/// @file from.C
/// @brief Checks for index selectors, both immediate and delayed.
// (c) Daniel Llorens - 2014
// This library is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.
#include <iostream>
#include <iterator>
#include <tuple>
#include "ra/mpdebug.H"
#include "ra/complex.H"
#include "ra/format.H"
#include "ra/test.H"
#include "ra/big.H"
#include "ra/operators.H"
#include "ra/io.H"
using std::cout, std::endl, std::flush, std::tuple;
  19. using real = double;
  20. template <int rank=ra::RANK_ANY> using Ureal = ra::Unique<real, rank>;
  21. using Vint = ra::Unique<int, 1>;
  22. int main()
  23. {
  24. TestRecorder tr(std::cout);
  25. tr.section("shortcuts");
  26. {
  27. auto check_selection_shortcuts = [&tr](auto && a)
  28. {
  29. tr.info("a()").test_eq(Ureal<2>({4, 4}, ra::_0-ra::_1), a());
  30. tr.info("a(2, :)").test_eq(Ureal<1>({4}, 2-ra::_0), a(2, ra::all));
  31. tr.info("a(2)").test_eq(Ureal<1>({4}, 2-ra::_0), a(2));
  32. tr.info("a(:, 3)").test_eq(Ureal<1>({4}, ra::_0-3), a(ra::all, 3));
  33. tr.info("a(:, :)").test_eq(Ureal<2>({4, 4}, ra::_0-ra::_1), a(ra::all, ra::all));
  34. tr.info("a(:)").test_eq(Ureal<2>({4, 4}, ra::_0-ra::_1), a(ra::all));
  35. tr.info("a(1)").test_eq(Ureal<1>({4}, 1-ra::_0), a(1));
  36. tr.info("a(2, 2)").test_eq(0, a(2, 2));
  37. tr.info("a(0:2:, 0:2:)").test_eq(Ureal<2>({2, 2}, 2*(ra::_0-ra::_1)),
  38. a(ra::iota(2, 0, 2), ra::iota(2, 0, 2)));
  39. tr.info("a(1:2:, 0:2:)").test_eq(Ureal<2>({2, 2}, 2*ra::_0+1-2*ra::_1),
  40. a(ra::iota(2, 1, 2), ra::iota(2, 0, 2)));
  41. tr.info("a(0:2:, :)").test_eq(Ureal<2>({2, 4}, 2*ra::_0-ra::_1),
  42. a(ra::iota(2, 0, 2), ra::all));
  43. tr.info("a(0:2:)").test_eq(a(ra::iota(2, 0, 2), ra::all), a(ra::iota(2, 0, 2)));
  44. };
  45. check_selection_shortcuts(Ureal<2>({4, 4}, ra::_0-ra::_1));
  46. check_selection_shortcuts(Ureal<>({4, 4}, ra::_0-ra::_1));
  47. }
  48. tr.section("ra::Iota<int> or ra::Iota<ra::dim_t> are both beatable");
  49. {
  50. Ureal<2> a({4, 4}, 0.);
  51. {
  52. ra::Iota<int> i(2, 1);
  53. auto b = a(i);
  54. tr.test_eq(2, b.dim[0].size);
  55. tr.test_eq(4, b.dim[1].size);
  56. tr.test_eq(4, b.dim[0].stride);
  57. tr.test_eq(1, b.dim[1].stride);
  58. }
  59. {
  60. ra::Iota<ra::dim_t> i(2, 1);
  61. auto b = a(i);
  62. tr.test_eq(2, b.dim[0].size);
  63. tr.test_eq(4, b.dim[1].size);
  64. tr.test_eq(4, b.dim[0].stride);
  65. tr.test_eq(1, b.dim[1].stride);
  66. }
  67. }
  68. tr.section("trivial case");
  69. {
  70. ra::Big<int, 3> a({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2);
  71. tr.test_eq(ra::_0*100 + ra::_1*10 + ra::_2, from(a));
  72. }
  73. tr.section("beatable multi-axis selectors, var size");
  74. {
  75. static_assert(ra::is_beatable<ra::dots_t<0>>::value, "dots_t<0> is beatable");
  76. ra::Big<int, 3> a({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2);
  77. tr.info("a(ra::dots<0> ...)").test_eq(a(0), a(ra::dots<0>, 0));
  78. tr.info("a(ra::dots<0> ...)").test_eq(a(1), a(ra::dots<0>, 1));
  79. tr.info("a(ra::dots<1> ...)").test_eq(a(ra::all, 0), a(ra::dots<1>, 0));
  80. tr.info("a(ra::dots<1> ...)").test_eq(a(ra::all, 1), a(ra::dots<1>, 1));
  81. tr.info("a(ra::dots<2> ...)").test_eq(a(ra::all, ra::all, 0), a(ra::dots<2>, 0));
  82. tr.info("a(ra::dots<2> ...)").test_eq(a(ra::all, ra::all, 1), a(ra::dots<2>, 1));
  83. tr.info("a(0)").test_eq(a(0, ra::all, ra::all), a(0));
  84. tr.info("a(1)").test_eq(a(1, ra::all, ra::all), a(1));
  85. tr.info("a(0, ra::dots<2>)").test_eq(a(0, ra::all, ra::all), a(0, ra::dots<2>));
  86. tr.info("a(1, ra::dots<2>)").test_eq(a(1, ra::all, ra::all), a(1, ra::dots<2>));
  87. }
  88. tr.section("beatable multi-axis selectors, fixed size");
  89. {
  90. static_assert(ra::is_beatable<ra::dots_t<0>>::value, "dots_t<0> is beatable");
  91. ra::Small<int, 2, 3, 4> a = ra::_0*100 + ra::_1*10 + ra::_2;
  92. tr.info("a(ra::dots<0> ...)").test_eq(a(0), a(ra::dots<0>, 0));
  93. tr.info("a(ra::dots<0> ...)").test_eq(a(1), a(ra::dots<0>, 1));
  94. tr.info("a(ra::dots<1> ...)").test_eq(a(ra::all, 0), a(ra::dots<1>, 0));
  95. tr.info("a(ra::dots<1> ...)").test_eq(a(ra::all, 1), a(ra::dots<1>, 1));
  96. tr.info("a(ra::dots<2> ...)").test_eq(a(ra::all, ra::all, 0), a(ra::dots<2>, 0));
  97. tr.info("a(ra::dots<2> ...)").test_eq(a(ra::all, ra::all, 1), a(ra::dots<2>, 1));
  98. tr.info("a(0)").test_eq(a(0, ra::all, ra::all), a(0));
  99. tr.info("a(1)").test_eq(a(1, ra::all, ra::all), a(1));
  100. tr.info("a(0, ra::dots<2>)").test_eq(a(0, ra::all, ra::all), a(0, ra::dots<2>));
  101. tr.info("a(1, ra::dots<2>)").test_eq(a(1, ra::all, ra::all), a(1, ra::dots<2>));
  102. }
  103. tr.section("insert, var size");
  104. {
  105. static_assert(ra::is_beatable<ra::insert_t<1>>::value, "insert_t<1> is beatable");
  106. ra::Big<int, 3> a({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2);
  107. tr.info("a(ra::insert<0> ...)").test_eq(a(0), a(ra::insert<0>, 0));
  108. ra::Big<int, 4> a1({1, 2, 3, 4}, ra::_1*100 + ra::_2*10 + ra::_3);
  109. tr.info("a(ra::insert<1> ...)").test_eq(a1, a(ra::insert<1>));
  110. ra::Big<int, 4> a2({2, 1, 3, 4}, ra::_0*100 + ra::_2*10 + ra::_3);
  111. tr.info("a(ra::all, ra::insert<1>, ...)").test_eq(a2, a(ra::all, ra::insert<1>));
  112. ra::Big<int, 5> a3({2, 1, 1, 3, 4}, ra::_0*100 + ra::_3*10 + ra::_4);
  113. tr.info("a(ra::all, ra::insert<2>, ...)").test_eq(a3, a(ra::all, ra::insert<2>));
  114. tr.info("a(0, ra::insert<1>, ...)").test_eq(a1(ra::all, 0), a(0, ra::insert<1>));
  115. tr.info("a(ra::insert<1>, 0, ...)").test_eq(a1(ra::all, 0), a(ra::insert<1>, 0));
  116. ra::Big<int, 4> aa1({2, 2, 3, 4}, a(ra::insert<1>));
  117. tr.info("insert with undefined size 0").test_eq(a, aa1(0));
  118. tr.info("insert with undefined size 1").test_eq(a, aa1(1));
  119. }
  120. tr.section("insert, var rank");
  121. {
  122. static_assert(ra::is_beatable<ra::insert_t<1>>::value, "insert_t<1> is beatable");
  123. ra::Big<int> a({2, 3, 4}, ra::_0*100 + ra::_1*10 + ra::_2);
  124. tr.info("a(ra::insert<0> ...)").test_eq(a(0), a(ra::insert<0>, 0));
  125. ra::Big<int> a1({1, 2, 3, 4}, ra::_1*100 + ra::_2*10 + ra::_3);
  126. tr.info("a(ra::insert<1> ...)").test_eq(a1, a(ra::insert<1>));
  127. ra::Big<int> a2({2, 1, 3, 4}, ra::_0*100 + ra::_2*10 + ra::_3);
  128. tr.info("a(ra::all, ra::insert<1>, ...)").test_eq(a2, a(ra::all, ra::insert<1>));
  129. ra::Big<int> a3({2, 1, 1, 3, 4}, ra::_0*100 + ra::_3*10 + ra::_4);
  130. tr.info("a(ra::all, ra::insert<2>, ...)").test_eq(a3, a(ra::all, ra::insert<2>));
  131. tr.info("a(0, ra::insert<1>, ...)").test_eq(a1(ra::all, 0), a(0, ra::insert<1>));
  132. tr.info("a(ra::insert<1>, 0, ...)").test_eq(a1(ra::all, 0), a(ra::insert<1>, 0));
  133. }
  134. tr.section("unbeatable, 1D");
  135. {
  136. auto check_selection_unbeatable_1 = [&tr](auto && a)
  137. {
  138. using CT = ra::Small<real, 4>;
  139. tr.info("a(i ...)").test_eq(CT {a[3], a[2], a[0], a[1]}, a(Vint {3, 2, 0, 1}));
  140. tr.info("a(i ...)").test_eq(CT {a[3], a[2], a[0], a[1]}, from(a, Vint {3, 2, 0, 1}));
  141. a = 0.;
  142. a(Vint {3, 2, 0, 1}) = CT {9, 7, 1, 4};
  143. tr.info("a(i ...) as lvalue").test_eq(CT {1, 4, 7, 9}, a);
  144. a = 0.;
  145. from(a, Vint {3, 2, 0, 1}) = CT {9, 7, 1, 4};
  146. tr.info("from(a i ...) as lvalue").test_eq(CT {1, 4, 7, 9}, a);
  147. a = 0.;
  148. from(a, Vint {3, 2, 0, 1}) = 77.;
  149. tr.info("from(a i ...) as lvalue, rank extend of right hand").test_eq(a, 77.);
  150. ra::Small<real, 2, 2> c = from(a, ra::Small<int, 2, 2> {3, 2, 0, 1});
  151. tr.info("a([x y; z w])").test_eq(ra::Small<real, 2, 2> {a[3], a[2], a[0], a[1]}, c);
  152. };
  153. check_selection_unbeatable_1(Ureal<1> {7, 9, 3, 4});
  154. check_selection_unbeatable_1(ra::Small<real, 4> {7, 9, 3, 4});
  155. check_selection_unbeatable_1(Ureal<>({4}, {7, 9, 3, 4}));
  156. }
  157. tr.section("unbeatable, 2D");
  158. {
  159. auto check_selection_unbeatable_2 = [&tr](auto && a)
  160. {
  161. using CT22 = ra::Small<real, 2, 2>;
  162. using CT2 = ra::Small<real, 2>;
  163. tr.info("a([0 1], [0 1])").test_eq(CT22 {a(0, 0), a(0, 1), a(1, 0), a(1, 1)},
  164. from(a, Vint {0, 1}, Vint {0, 1}));
  165. tr.info("a([0 1], [1 0])").test_eq(CT22 {a(0, 1), a(0, 0), a(1, 1), a(1, 0)},
  166. from(a, Vint {0, 1}, Vint {1, 0}));
  167. tr.info("a([1 0], [0 1])").test_eq(CT22 {a(1, 0), a(1, 1), a(0, 0), a(0, 1)},
  168. from(a, Vint {1, 0}, Vint {0, 1}));
  169. tr.info("a([1 0], [1 0])").test_eq(CT22 {a(1, 1), a(1, 0), a(0, 1), a(0, 0)},
  170. from(a, Vint {1, 0}, Vint {1, 0}));
  171. // TODO This is a nested array, which is a problem, we would use it just as from(a, [0 1], [0 1]).
  172. std::cout << "TODO [" << from(a, Vint {0, 1}) << "]" << std::endl;
  173. a = 0.;
  174. from(a, Vint {1, 0}, Vint {1, 0}) = CT22 {9, 7, 1, 4};
  175. tr.info("a([1 0], [1 0]) as lvalue").test_eq(CT22 {4, 1, 7, 9}, a);
  176. from(a, Vint {1, 0}, Vint {1, 0}) *= CT22 {9, 7, 1, 4};
  177. tr.info("a([1 0], [1 0]) as lvalue, *=").test_eq(CT22 {16, 1, 49, 81}, a);
  178. // Note the difference with J amend, which requires x in (x m} y) ~ (y[m] = x) to be a suffix of y[m]; but we apply the general mechanism which is prefix matching.
  179. from(a, Vint {1, 0}, Vint {1, 0}) = CT2 {9, 7};
  180. tr.info("a([1 0], [1 0]) as lvalue, rank extend of right hand").test_eq(CT22 {7, 7, 9, 9}, a);
  181. // TODO Test cases with rank!=1, starting with this couple which should work the same.
  182. std::cout << "-> " << from(a, Vint{1, 0}, 0) << std::endl;
  183. a = CT22 {4, 1, 7, 9};
  184. tr.info("a(rank1, rank0)").test_eq(ra::Small<real, 2>{9, 1}, from(a, Vint{1, 0}, ra::Small<int>(1).iter()));
  185. tr.info("a(rank0, rank1)").test_eq(ra::Small<real, 2>{9, 7}, from(a, ra::Small<int>(1).iter(), Vint{1, 0}));
  186. };
  187. check_selection_unbeatable_2(Ureal<2>({2, 2}, {1, 2, 3, 4}));
  188. check_selection_unbeatable_2(ra::Small<real, 2, 2>({1, 2, 3, 4}));
  189. check_selection_unbeatable_2(Ureal<>({2, 2}, {1, 2, 3, 4}));
  190. }
  191. tr.section("mixed scalar/unbeatable, 2D -> 1D");
  192. {
  193. auto check_selection_unbeatable_mixed = [&tr](auto && a)
  194. {
  195. using CT2 = ra::Small<real, 2>;
  196. tr.info("from(a [0 1], 1)").test_eq(CT2 {a(0, 1), a(1, 1)}, from(a, Vint {0, 1}, 1));
  197. tr.info("from(a [1 0], 1)").test_eq(CT2 {a(1, 1), a(0, 1)}, from(a, Vint {1, 0}, 1));
  198. tr.info("from(a 1, [0 1])").test_eq(CT2 {a(1, 0), a(1, 1)}, from(a, 1, Vint {0, 1}));
  199. tr.info("from(a 1, [1 0])").test_eq(CT2 {a(1, 1), a(1, 0)}, from(a, 1, Vint {1, 0}));
  200. tr.info("a([0 1], 1)").test_eq(CT2 {a(0, 1), a(1, 1)}, a(Vint {0, 1}, 1));
  201. tr.info("a([1 0], 1)").test_eq(CT2 {a(1, 1), a(0, 1)}, a(Vint {1, 0}, 1));
  202. tr.info("a(1, [0 1])").test_eq(CT2 {a(1, 0), a(1, 1)}, a(1, Vint {0, 1}));
  203. tr.info("a(1, [1 0])").test_eq(CT2 {a(1, 1), a(1, 0)}, a(1, Vint {1, 0}));
  204. };
  205. check_selection_unbeatable_mixed(Ureal<2>({2, 2}, {1, 2, 3, 4}));
  206. check_selection_unbeatable_mixed(ra::Small<real, 2, 2>({1, 2, 3, 4}));
  207. }
  208. tr.section("mixed unbeatable/dots, 2D -> 2D (TODO)");
  209. {
  210. // auto check_selection_unbeatable_dots = [&tr](auto && a)
  211. // {
  212. // using CT2 = ra::Small<real, 2>;
  213. // tr.info("a({0, 0}, ra::all)").test_eq(a(CT2 {0, 0}, ra::all), a(CT2 {0, 0}, CT2 {0, 1}));
  214. // tr.info("a({0, 1}, ra::all)").test_eq(a(CT2 {0, 1}, ra::all), a(CT2 {0, 1}, CT2 {0, 1}));
  215. // tr.info("a({1, 0}, ra::all)").test_eq(a(CT2 {1, 0}, ra::all), a(CT2 {1, 0}, CT2 {0, 1}));
  216. // tr.info("a({1, 1}, ra::all)").test_eq(a(CT2 {1, 1}, ra::all), a(CT2 {1, 1}, CT2 {0, 1}));
  217. // };
  218. // TODO doesn't work because dots_t<> can only be beaten on, not iterated on, and the beating cases are missing.
  219. // check_selection_unbeatable_dots(Ureal<2>({2, 2}, {1, 2, 3, 4}));
  220. // check_selection_unbeatable_dots(ra::Small<real, 2, 2>({1, 2, 3, 4}));
  221. }
  222. tr.section("unbeatable, 3D & higher");
  223. {
  224. // see src/test/bench-from.C for examples of higher-D.
  225. }
  226. tr.section("TensorIndex / where TODO elsewhere");
  227. {
  228. Ureal<2> a({4, 4}, 1.);
  229. a(3, 3) = 7.;
  230. tr.test(every(ra::map([](auto a, int i, int j)
  231. {
  232. return a==(i==3 && j==3 ? 7. : 1.);
  233. },
  234. a, ra::_0, ra::_1)));
  235. tr.test_eq(where(ra::_0==3 && ra::_1==3, 7., 1.), a);
  236. }
  237. // The implementation of from() uses FrameMatch / ApplyFrames and can't handle this yet.
  238. tr.section("TensorIndex<i> as subscript, using ra::Expr directly.");
  239. {
  240. auto i = ra::_0;
  241. auto j = ra::_1;
  242. Ureal<2> a({4, 3}, i-j);
  243. Ureal<2> b({3, 4}, 0.);
  244. b = map([&a](int i, int j) { return a(i, j); }, j, i);
  245. tr.test_eq(i-j, a);
  246. tr.test_eq(j-i, b);
  247. }
  248. tr.section("TensorIndex<i> as subscripts, 1 subscript TODO elsewhere");
  249. {
  250. Ureal<1> a {1, 4, 2, 3};
  251. Ureal<1> b({4}, 0.);
  252. // these work b/c there's another term to drive the expr.
  253. b = a(3-ra::_0);
  254. tr.test_eq(Ureal<1> {3, 2, 4, 1}, b);
  255. b(3-ra::_0) = a;
  256. tr.test_eq(Ureal<1> {3, 2, 4, 1}, b);
  257. }
  258. tr.section("TODO TensorIndex<i> as subscripts, 2 subscript (case I)");
  259. {
  260. Ureal<2> a({4, 4}, ra::_0-ra::_1);
  261. Ureal<2> b({4, 4}, -99.);
  262. cout << a << endl;
  263. cout << b << endl;
  264. // b = a(ra::_0, ra::_0);
  265. }
  266. tr.section("TODO TensorIndex<i> as subscripts, 2 subscript (case II)");
  267. {
  268. Ureal<2> a({4, 4}, ra::_0-ra::_1);
  269. Ureal<2> b({4, 4}, 0.);
  270. cout << a << endl;
  271. cout << b << endl;
  272. // TODO these instantiate flat() when they should not (FIXME was for old OldTensorIndex; recheck)
  273. // tr.info("by_index(Ryn)").test(ra::by_index<decltype(a(ra::_1, ra::_0))>);
  274. // cout << mp::ref<decltype(a(ra::_1, ra::_0))>::rank_s() << endl;
  275. // these don't work because a(j, i) has rank 3 = [(w=1)+1 + (w=0)+1] and so it drives, but tensorindex exprs shouldn't ever drive.
  276. // tr.info("by_index(Ryn)").test(ra::by_index<decltype(b+a(ra::_1, ra::_0))>);
  277. // cout << mp::ref<decltype(b+a(ra::_1, ra::_0))::T, 0>::rank_s() << endl;
  278. // cout << mp::ref<decltype(b+a(ra::_1, ra::_0))::T, 1>::rank_s() << endl;
  279. cout << mp::ref<decltype(ra::_1)>::rank_s() << endl;
  280. // b = a(ra::_1, ra::_0);
  281. }
  282. // Small(Iota) isn't beaten because the the output type cannot depend on argument values. So we treat it as a common expr.
  283. tr.section("ra::Small(Iota)");
  284. {
  285. ra::Small<real, 4> a = ra::_0;
  286. tr.test_eq(a(ra::iota(2, 1)), Ureal<1> { 1, 2 });
  287. }
  288. // Indirection operator using list of coordinates.
  289. tr.section("at() indirection");
  290. {
  291. ra::Big<int, 2> A({4, 4}, 0), B({4, 4}, 10*ra::_0 + ra::_1);
  292. using coord = ra::Small<int, 2>;
  293. ra::Big<coord, 1> I = { {1, 1}, {2, 2} };
  294. at(A, I) = at(B, I);
  295. tr.test_eq(ra::Big<int>({4, 4}, {0, 0, 0, 0, /**/ 0, 11, 0, 0, /**/ 0, 0, 22, 0, /**/ 0, 0, 0, 0}), A);
  296. // TODO this is why we need ops to have explicit rank.
  297. at(A, ra::scalar(coord{3, 2})) = 99.;
  298. tr.test_eq(ra::Big<int>({4, 4}, {0, 0, 0, 0, /**/ 0, 11, 0, 0, /**/ 0, 0, 22, 0, /**/ 0, 0, 99, 0}), A);
  299. }
  300. // From the manual [ra30]
  301. {
  302. ra::Big<int, 2> A = {{100, 101}, {110, 111}, {120, 121}};
  303. ra::Big<ra::Small<int, 2>, 2> i = {{{0, 1}, {2, 0}}, {{1, 0}, {2, 1}}};
  304. ra::Big<int, 2> B = at(A, i);
  305. tr.test_eq(ra::Big<int, 2> {{101, 120}, {110, 121}}, at(A, i));
  306. }
  307. return tr.summary();
  308. }