reshape.C 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file reshape.C
  3. /// @brief Tests for reshape().
  4. // (c) Daniel Llorens - 2017
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #include "ra/operators.H"
  10. #include "ra/io.H"
  11. #include "ra/test.H"
  12. #include "ra/mpdebug.H"
  13. #include <memory>
  14. using std::cout, std::endl;
  15. namespace ra {
  16. std::ostream & operator<<(std::ostream & o, ra::Dim const d)
  17. {
  18. o << "{" << d.size << ", " << d.stride << "}";
  19. return o;
  20. }
  21. } // namespace ra
  22. int main()
  23. {
  24. TestRecorder tr(std::cout);
  25. tr.section("reshape");
  26. {
  27. ra::Big<int, 3> aa({2, 3, 3}, ra::_0*3+ra::_1);
  28. auto a = aa(ra::all, ra::all, 0);
  29. tr.info("ravel_free").test_eq(ra::iota(6), ravel_free(a));
  30. tr.test_eq(ra::scalar(a.p), ra::scalar(ravel_free(a).p));
  31. // select.
  32. tr.info("reshape select").test_eq(ra::Big<int, 1> {0, 1, 2}, reshape(a, ra::Small<int, 1> {3}));
  33. tr.test_eq(ra::scalar(a.p), ra::scalar(reshape(a, ra::Small<int, 1> {3}).p));
  34. // tile.
  35. auto tilea = reshape(a, ra::Small<int, 3> {2, 2, 3});
  36. tr.info("reshape select").test_eq(ra::Big<int, 3>({2, 2, 3}, {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}), tilea);
  37. tr.info("some tile-reshapes are free (I)").test_eq(0, tilea.stride(0));
  38. tr.info("some tile-reshapes are free (II)").test_eq(ra::scalar(a.data()), ra::scalar(tilea.data()));
  39. // reshape with free ravel
  40. tr.info("reshape w/free ravel I").test_eq(ra::Big<int, 2>({3, 2}, {0, 1, 2, 3, 4, 5}), reshape(a, ra::Small<int, 2> {3, 2}));
  41. tr.test_eq(ra::scalar(a.p), ra::scalar(reshape(a, ra::Small<int, 2> {3, 2}).p));
  42. tr.info("reshape w/free ravel II").test_eq(ra::Big<int, 3>({2, 1, 2}, {0, 1, 2, 3}), reshape(a, ra::Small<int, 3> {2, 1, 2}));
  43. tr.test_eq(ra::scalar(a.p), ra::scalar(reshape(a, ra::Small<int, 3> {2, 1, 2}).p));
  44. tr.info("reshape w/free ravel III").test_eq(ra::Big<int, 2>({3, 2}, {0, 1, 2, 3, 4, 5}), reshape(a, ra::Small<int, 2> {-1, 2}));
  45. tr.test_eq(ra::scalar(a.p), ra::scalar(reshape(a, ra::Small<int, 2> {-1, 2}).p));
  46. tr.info("reshape w/free ravel IV").test_eq(ra::Big<int, 2>({2, 3}, {0, 1, 2, 3, 4, 5}), reshape(a, ra::Small<int, 2> {2, -1}));
  47. tr.test_eq(ra::scalar(a.p), ra::scalar(reshape(a, ra::Small<int, 2> {2, -1}).p));
  48. tr.info("reshape w/free ravel V").test_eq(ra::Big<int, 3>({2, 1, 3}, {0, 1, 2, 3, 4, 5}), reshape(a, ra::Small<int, 3> {2, -1, 3}));
  49. tr.test_eq(ra::scalar(a.p), ra::scalar(reshape(a, ra::Small<int, 3> {2, -1, 3}).p));
  50. }
  51. tr.section("reshape from var rank to fixed rank");
  52. {
  53. ra::Big<int> a({2, 3}, ra::_0*3+ra::_1);
  54. auto b = reshape(a, ra::Small<int, 1> {3});
  55. tr.info("reshape select").test_eq(ra::Big<int, 1> {0, 1, 2}, b);
  56. tr.test_eq(ra::scalar(a.p), ra::scalar(b.p));
  57. tr.info("reshape can fix rank").test_eq(1, ra::ra_traits<decltype(b)>::rank_s());
  58. }
  59. tr.section("reshape from var rank to fixed rank using the initializer_list shim");
  60. {
  61. ra::Big<int> a({2, 3}, ra::_0*3+ra::_1);
  62. auto b = reshape(a, {3, 2});
  63. tr.info("reshape").test_eq(ra::Big<int, 2> {{0, 1}, {2, 3}, {4, 5}}, b);
  64. tr.test_eq(ra::scalar(a.p), ra::scalar(b.p));
  65. tr.info("reshape can return fixed rank (2)").test_eq(2, ra::ra_traits<decltype(b)>::rank_s());
  66. auto c = reshape(a, {3l, 2l}); // check deduction works regardless
  67. tr.info("reshape").test_eq(ra::Big<int, 2> {{0, 1}, {2, 3}, {4, 5}}, c);
  68. tr.test_eq(ra::scalar(a.p), ra::scalar(c.p));
  69. tr.info("reshape can return fixed rank (3)").test_eq(2, ra::ra_traits<decltype(c)>::rank_s());
  70. }
  71. tr.section("reshape from var rank to var rank");
  72. {
  73. ra::Big<int> a({2, 3}, ra::_0*3+ra::_1);
  74. auto b = reshape(a, ra::Big<int, 1> {3});
  75. tr.info("reshape select").test_eq(ra::Big<int, 1> {0, 1, 2}, b);
  76. tr.test_eq(ra::scalar(a.p), ra::scalar(b.p));
  77. tr.info("reshape can return var rank (1)").test_eq(ra::RANK_ANY, ra::ra_traits<decltype(b)>::rank_s());
  78. }
  79. tr.section("reshape to fixed rank to var rank");
  80. {
  81. // FIXME warning w/ gcc 6.3 in bootstrap.H inside() [ra32]. Apparent root is in decl of b in reshape().
  82. ra::Big<int, 2> a({2, 3}, ra::_0*3+ra::_1);
  83. auto b = reshape(a, ra::Big<int, 1> {3});
  84. tr.info("reshape select").test_eq(ra::Big<int, 1> {0, 1, 2}, b);
  85. tr.test_eq(ra::scalar(a.p), ra::scalar(b.p));
  86. tr.info("reshape can return var rank").test_eq(ra::RANK_ANY, ra::ra_traits<decltype(b)>::rank_s());
  87. }
  88. tr.section("conversion from var rank to fixed rank");
  89. {
  90. ra::Big<int> a({2, 3}, ra::_0*3+ra::_1);
  91. ra::View<int, 2> b = a;
  92. tr.info("fixing rank").test_eq(ra::_0*3+ra::_1, b);
  93. tr.info("fixing rank is view").test(a.data()==b.data());
  94. }
  95. return tr.summary();
  96. }