where.C 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file where.C
  3. /// @brief Tests for where() and pick().
  4. // (c) Daniel Llorens - 2014-2016
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #include <atomic>
  10. #include "ra/operators.H"
  11. #include "ra/io.H"
  12. #include "ra/test.H"
  13. using std::cout, std::endl;
  14. int main()
  15. {
  16. TestRecorder tr(std::cout);
  17. std::atomic<int> counter { 0 };
  18. auto count = [&counter](auto && x) -> decltype(auto) { ++counter; return x; };
  19. tr.section("pick");
  20. {
  21. ra::Small<double, 3> a0 = { 1, 2, 3 };
  22. ra::Small<double, 3> a1 = { 10, 20, 30 };
  23. ra::Small<int, 3> p = { 0, 1, 0 };
  24. ra::Small<double, 3> a(0.);
  25. counter = 0;
  26. a = pick(p, map(count, a0), map(count, a1));
  27. tr.test_eq(ra::Small<double, 3> { 1, 20, 3 }, a);
  28. tr.info("pick ETs execute only one branch per iteration").test_eq(3, int(counter));
  29. counter = 0;
  30. a = where(p, map(count, a0), map(count, a1));
  31. tr.test_eq(ra::Small<double, 3> { 10, 2, 30 }, a);
  32. tr.info("where() is implemented using pick ET").test_eq(3, int(counter));
  33. }
  34. tr.section("write to pick");
  35. {
  36. ra::Small<double, 2> a0 = { 1, 2 };
  37. ra::Small<double, 2> a1 = { 10, 20 };
  38. ra::Small<int, 2> const p = { 0, 1 };
  39. ra::Small<double, 2> const a = { 7, 9 };
  40. counter = 0;
  41. pick(p, map(count, a0), map(count, a1)) = a;
  42. tr.test_eq(2, int(counter));
  43. tr.test_eq(ra::Small<double, 2> { 7, 2 }, a0);
  44. tr.test_eq(ra::Small<double, 2> { 10, 9 }, a1);
  45. tr.test_eq(ra::Small<double, 2> { 7, 9 }, a);
  46. tr.test_eq(ra::Small<int, 2> { 0, 1 }, p);
  47. }
  48. tr.section("pick works as any other array expression");
  49. {
  50. ra::Small<double, 2> a0 = { 1, 2 };
  51. ra::Small<double, 2> a1 = { 10, 20 };
  52. ra::Small<int, 2> const p = { 0, 1 };
  53. ra::Small<double, 2> q = 3 + pick(p, a0, a1);
  54. tr.test_eq(ra::Small<int, 2> { 4, 23 }, q);
  55. }
  56. tr.section("pick with TensorIndex");
  57. {
  58. ra::Small<double, 2> a0 = { 1, 2 };
  59. ra::Small<double, 2> a1 = { 10, 20 };
  60. ra::Small<int, 2> const p = { 0, 1 };
  61. counter = 0;
  62. pick(p, map(count, a0), map(count, a1)) += ra::_0+5;
  63. tr.test_eq(2, int(counter));
  64. tr.test_eq(ra::Small<double, 2> { 6, 2 }, a0);
  65. tr.test_eq(ra::Small<double, 2> { 10, 26 }, a1);
  66. tr.test_eq(ra::Small<int, 2> { 0, 1 }, p);
  67. }
  68. tr.section("where, scalar W, array arguments in T/F");
  69. {
  70. std::array<double, 2> bb {1, 2};
  71. std::array<double, 2> cc {99, 99};
  72. auto b = ra::start(bb);
  73. auto c = ra::start(cc);
  74. cc[0] = cc[1] = 99;
  75. c = where(true, b, -b);
  76. tr.test_eq(1, cc[0]);
  77. tr.test_eq(2, cc[1]);
  78. // test against a bug where the op in where()'s Expr returned a dangling reference when both its args are rvalue refs. This was visible only at certain -O levels.
  79. cc[0] = cc[1] = 99;
  80. c = where(true, b-3, -b);
  81. tr.test_eq(-2, cc[0]);
  82. tr.test_eq(-1, cc[1]);
  83. }
  84. tr.section("where as rvalue");
  85. {
  86. tr.test_eq(ra::Unique<int, 1> { 1, 2, 2, 1 }, where(ra::Unique<bool, 1> { true, false, false, true }, 1, 2));
  87. tr.test_eq(ra::Unique<int, 1> { 17, 2, 3, 17 }
  88. , where(ra::_0>0 && ra::_0<3, ra::Unique<int, 1> { 1, 2, 3, 4 }, 17));
  89. // [raop00] TensorIndex returs value; so where()'s lambda must also return value.
  90. tr.test_eq(ra::Unique<int, 1> { 1, 2, 4, 7 }, where(ra::Unique<bool, 1> { true, false, false, true }, 2*ra::_0+1, 2*ra::_0));
  91. // Using frame matching... TODO directly with ==expr?
  92. ra::Unique<int, 2> a({4, 3}, ra::_0-ra::_1);
  93. ra::Unique<int, 2> b = where(ra::Unique<bool, 1> { true, false, false, true }, 99, a);
  94. tr.test_eq(ra::Unique<int, 2> ({4, 3}, { 99, 99, 99, 1, 0, -1, 2, 1, 0, 99, 99, 99 }), b);
  95. }
  96. tr.section("where nested");
  97. {
  98. {
  99. ra::Small<int, 3> a {-1, 0, 1};
  100. ra::Small<int, 3> b = where(a>=0, where(a<1, 77, 99), 44);
  101. tr.test_eq(ra::Small<int, 3> {44, 77, 99}, b);
  102. }
  103. {
  104. int a = 0;
  105. ra::Small<int, 2, 2> b = where(a>=0, where(a>=1, 99, 77), 44);
  106. tr.test_eq(ra::Small<int, 2, 2> {77, 77, 77, 77}, b);
  107. }
  108. }
  109. tr.section("where, scalar W, array arguments in T/F");
  110. {
  111. double a = 1./7;
  112. ra::Small<double, 2> b {1, 2};
  113. ra::Small<double, 2> c = where(a>0, b, 3.);
  114. tr.test_eq(ra::Small<double, 2> {1, 2}, c);
  115. }
  116. tr.section("where as lvalue, scalar");
  117. {
  118. double a=0, b=0;
  119. bool w = true;
  120. where(w, a, b) = 99;
  121. tr.test_eq(a, 99);
  122. tr.test_eq(b, 0);
  123. where(!w, a, b) = 77;
  124. tr.test_eq(99, a);
  125. tr.test_eq(77, b);
  126. }
  127. tr.section("where, scalar + rank 0 array");
  128. {
  129. ra::Small<double> a { 33. };
  130. double b = 22.;
  131. tr.test_eq(33, double(where(true, a, b)));
  132. tr.test_eq(22, double(where(true, b, a)));
  133. }
  134. tr.section("where as lvalue, xpr [raop01]");
  135. {
  136. ra::Unique<int, 1> a { 0, 0, 0, 0 };
  137. ra::Unique<int, 1> b { 0, 0, 0, 0 };
  138. where(ra::_0>0 && ra::_0<3, a, b) = 7;
  139. tr.test_eq(ra::Unique<int, 1> {0, 7, 7, 0}, a);
  140. tr.test_eq(ra::Unique<int, 1> {7, 0, 0, 7}, b);
  141. where(ra::_0<=0 || ra::_0>=3, a, b) += 2;
  142. tr.test_eq(ra::Unique<int, 1> {2, 7, 7, 2}, a);
  143. tr.test_eq(ra::Unique<int, 1> {7, 2, 2, 7}, b);
  144. // Both must be lvalues; TODO check that either of these is an error.
  145. // where(ra::_0>0 && ra::_0<3, ra::_0, a) = 99;
  146. // where(ra::_0>0 && ra::_0<3, a, ra::_0) = 99;
  147. }
  148. tr.section("where with rvalue TensorIndex, fails to compile with g++ 5.2 -Os, gives wrong result with -O0");
  149. {
  150. tr.test_eq(ra::Small<int, 2> {0, 1},
  151. where(ra::Unique<bool, 1> { true, false }, ra::TensorIndex<0>(), ra::TensorIndex<0>()));
  152. tr.test_eq(ra::Unique<int, 1> { 0, 2 }, where(ra::Unique<bool, 1> { true, false }, 3*ra::_0, 2*ra::_0));
  153. }
  154. tr.section("&& and || are short-circuiting");
  155. {
  156. using bool4 = ra::Small<bool, 4>;
  157. bool4 a {true, true, false, false}, b {true, false, true, false};
  158. int i = 0;
  159. tr.test_eq(bool4 {true, false, false, false}, a && map([&](auto && b) { ++i; return b; }, b));
  160. tr.info("short circuit test for &&").test_eq(2, i);
  161. i = 0;
  162. tr.test_eq(bool4 {true, true, true, false}, a || map([&](auto && b) { ++i; return b; }, b));
  163. tr.info("short circuit test for &&").test_eq(2, i);
  164. }
  165. // These tests should fail at compile time. No way to check them yet [ra42].
  166. // tr.section("size checks");
  167. // {
  168. // ra::Small<int, 3> a = { 1, 2, 3 };
  169. // ra::Small<int, 3> b = { 4, 5, 6 };
  170. // ra::Small<int, 2> c = 0; // ok if 2 -> 3; the test is for that case.
  171. // where(a>b, a, c) += b;
  172. // tr.test_eq(ra::Small<int, 3> { 1, 2, 3 }, a);
  173. // tr.test_eq(ra::Small<int, 3> { 4, 5, 6 }, b);
  174. // }
  175. return tr.summary();
  176. }