test-where.C 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. // (c) Daniel Llorens - 2014-2016
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. /// @file test-where.C
  7. /// @brief Tests for where() and pick().
  8. #include <atomic>
  9. #include "ra/operators.H"
  10. #include "ra/io.H"
  11. #include "ra/test.H"
  12. using std::cout; using std::endl;
  13. int main()
  14. {
  15. TestRecorder tr(std::cout);
  16. std::atomic<int> counter { 0 };
  17. auto count = [&counter](auto && x) -> decltype(auto) { ++counter; return x; };
  18. tr.section("pick");
  19. {
  20. ra::Small<double, 3> a0 = { 1, 2, 3 };
  21. ra::Small<double, 3> a1 = { 10, 20, 30 };
  22. ra::Small<int, 3> p = { 0, 1, 0 };
  23. ra::Small<double, 3> a(0.);
  24. counter = 0;
  25. a = pick(p, map(count, a0), map(count, a1));
  26. tr.test_eq(ra::Small<double, 3> { 1, 20, 3 }, a);
  27. tr.info("pick ETs execute only one branch per iteration").test_eq(3, int(counter));
  28. counter = 0;
  29. a = where(p, map(count, a0), map(count, a1));
  30. tr.test_eq(ra::Small<double, 3> { 10, 2, 30 }, a);
  31. tr.info("where() is implemented using pick ET").test_eq(3, int(counter));
  32. }
  33. tr.section("write to pick");
  34. {
  35. ra::Small<double, 2> a0 = { 1, 2 };
  36. ra::Small<double, 2> a1 = { 10, 20 };
  37. ra::Small<int, 2> const p = { 0, 1 };
  38. ra::Small<double, 2> const a = { 7, 9 };
  39. counter = 0;
  40. pick(p, map(count, a0), map(count, a1)) = a;
  41. tr.test_eq(2, int(counter));
  42. tr.test_eq(ra::Small<double, 2> { 7, 2 }, a0);
  43. tr.test_eq(ra::Small<double, 2> { 10, 9 }, a1);
  44. tr.test_eq(ra::Small<double, 2> { 7, 9 }, a);
  45. tr.test_eq(ra::Small<int, 2> { 0, 1 }, p);
  46. }
  47. tr.section("pick works as any other array expression");
  48. {
  49. ra::Small<double, 2> a0 = { 1, 2 };
  50. ra::Small<double, 2> a1 = { 10, 20 };
  51. ra::Small<int, 2> const p = { 0, 1 };
  52. ra::Small<double, 2> q = 3 + pick(p, a0, a1);
  53. tr.test_eq(ra::Small<int, 2> { 4, 23 }, q);
  54. }
  55. tr.section("pick with TensorIndex");
  56. {
  57. ra::Small<double, 2> a0 = { 1, 2 };
  58. ra::Small<double, 2> a1 = { 10, 20 };
  59. ra::Small<int, 2> const p = { 0, 1 };
  60. counter = 0;
  61. pick(p, map(count, a0), map(count, a1)) += ra::_0+5;
  62. tr.test_eq(2, int(counter));
  63. tr.test_eq(ra::Small<double, 2> { 6, 2 }, a0);
  64. tr.test_eq(ra::Small<double, 2> { 10, 26 }, a1);
  65. tr.test_eq(ra::Small<int, 2> { 0, 1 }, p);
  66. }
  67. tr.section("where, scalar W, array arguments in T/F");
  68. {
  69. std::array<double, 2> bb {1, 2};
  70. std::array<double, 2> cc {99, 99};
  71. auto b = ra::start(bb);
  72. auto c = ra::start(cc);
  73. cc[0] = cc[1] = 99;
  74. c = where(true, b, -b);
  75. tr.test_eq(1, cc[0]);
  76. tr.test_eq(2, cc[1]);
  77. // test against a bug where the op in where()'s Expr returned a dangling reference when both its args are rvalue refs. This was visible only at certain -O levels.
  78. cc[0] = cc[1] = 99;
  79. c = where(true, b-3, -b);
  80. tr.test_eq(-2, cc[0]);
  81. tr.test_eq(-1, cc[1]);
  82. }
  83. tr.section("where as rvalue");
  84. {
  85. tr.test_eq(ra::Unique<int, 1> { 1, 2, 2, 1 }, where(ra::Unique<bool, 1> { true, false, false, true }, 1, 2));
  86. tr.test_eq(ra::Unique<int, 1> { 17, 2, 3, 17 }
  87. , where(ra::_0>0 && ra::_0<3, ra::Unique<int, 1> { 1, 2, 3, 4 }, 17));
  88. // [raop00] TensorIndex returs value; so where()'s lambda must also return value.
  89. tr.test_eq(ra::Unique<int, 1> { 1, 2, 4, 7 }, where(ra::Unique<bool, 1> { true, false, false, true }, 2*ra::_0+1, 2*ra::_0));
  90. // Using frame matching... TODO directly with ==expr?
  91. ra::Unique<int, 2> a({4, 3}, ra::_0-ra::_1);
  92. ra::Unique<int, 2> b = where(ra::Unique<bool, 1> { true, false, false, true }, 99, a);
  93. tr.test_eq(ra::Unique<int, 2> ({4, 3}, { 99, 99, 99, 1, 0, -1, 2, 1, 0, 99, 99, 99 }), b);
  94. }
  95. tr.section("where nested");
  96. {
  97. {
  98. ra::Small<int, 3> a {-1, 0, 1};
  99. ra::Small<int, 3> b = where(a>=0, where(a<1, 77, 99), 44);
  100. tr.test_eq(ra::Small<int, 3> {44, 77, 99}, b);
  101. }
  102. {
  103. int a = 0;
  104. ra::Small<int, 2, 2> b = where(a>=0, where(a>=1, 99, 77), 44);
  105. tr.test_eq(ra::Small<int, 2, 2> {77, 77, 77, 77}, b);
  106. }
  107. }
  108. tr.section("where, scalar W, array arguments in T/F");
  109. {
  110. double a = 1./7;
  111. ra::Small<double, 2> b {1, 2};
  112. ra::Small<double, 2> c = where(a>0, b, 3.);
  113. tr.test_eq(ra::Small<double, 2> {1, 2}, c);
  114. }
  115. tr.section("where as lvalue, scalar");
  116. {
  117. double a=0, b=0;
  118. bool w = true;
  119. where(w, a, b) = 99;
  120. tr.test_eq(a, 99);
  121. tr.test_eq(b, 0);
  122. where(!w, a, b) = 77;
  123. tr.test_eq(99, a);
  124. tr.test_eq(77, b);
  125. }
  126. tr.section("where, scalar + rank 0 array");
  127. {
  128. ra::Small<double> a { 33. };
  129. double b = 22.;
  130. tr.test_eq(33, double(where(true, a, b)));
  131. tr.test_eq(22, double(where(true, b, a)));
  132. }
  133. tr.section("where as lvalue, xpr [raop01]");
  134. {
  135. ra::Unique<int, 1> a { 0, 0, 0, 0 };
  136. ra::Unique<int, 1> b { 0, 0, 0, 0 };
  137. where(ra::_0>0 && ra::_0<3, a, b) = 7;
  138. tr.test_eq(ra::Unique<int, 1> {0, 7, 7, 0}, a);
  139. tr.test_eq(ra::Unique<int, 1> {7, 0, 0, 7}, b);
  140. where(ra::_0<=0 || ra::_0>=3, a, b) += 2;
  141. tr.test_eq(ra::Unique<int, 1> {2, 7, 7, 2}, a);
  142. tr.test_eq(ra::Unique<int, 1> {7, 2, 2, 7}, b);
  143. // Both must be lvalues; TODO check that either of these is an error.
  144. // where(ra::_0>0 && ra::_0<3, ra::_0, a) = 99;
  145. // where(ra::_0>0 && ra::_0<3, a, ra::_0) = 99;
  146. }
  147. tr.section("where with rvalue TensorIndex, fails to compile with g++ 5.2 -Os, gives wrong result with -O0");
  148. {
  149. tr.test_eq(ra::Small<int, 2> {0, 1},
  150. where(ra::Unique<bool, 1> { true, false }, ra::TensorIndex<0>(), ra::TensorIndex<0>()));
  151. tr.test_eq(ra::Unique<int, 1> { 0, 2 }, where(ra::Unique<bool, 1> { true, false }, 3*ra::_0, 2*ra::_0));
  152. }
  153. return tr.summary();
  154. }