tensor-indices.C 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file tensor-indices.C
  3. /// @brief Compare TensorIndex with test/old.H:OldTensorIndex that required ply_index.
  4. // (c) Daniel Llorens - 2019
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #include <iostream>
  10. #include "ra/bench.H"
  11. #include "ra/operators.H"
  12. #include "ra/io.H"
  13. #include "ra/test.H"
  14. #include "test/old.H"
  15. using std::cout, std::endl, std::flush;
  16. int main()
  17. {
  18. TestRecorder tr;
  19. // tests for old ra::OldTensorIndex
  20. {
  21. tr.info("by_index(OldTensorIndex)").test(ra::by_index<decltype(ra::OldTensorIndex<0> {})>);
  22. tr.info("by_index(Expr)").test(ra::by_index<decltype(ra::OldTensorIndex<0> {}+ra::OldTensorIndex<1> {})>);
  23. }
  24. {
  25. ra::Unique<int, 2> a({3, 2}, ra::none);
  26. auto dyn = ra::expr([](int & a, int b) { a = b; }, a.iter(), ra::OldTensorIndex<0> {});
  27. static_assert(ra::by_index<decltype(dyn)>, "bad by_index test 1");
  28. ply_index(dyn);
  29. tr.test_eq(ra::_0, a);
  30. }
  31. {
  32. ra::Unique<int, 2> a({3, 2}, ra::none);
  33. auto dyn = ra::expr([](int & dest, int const & src) { dest = src; }, a.iter(), ra::OldTensorIndex<0> {});
  34. static_assert(ra::by_index<decltype(dyn)>, "bad by_index test 2");
  35. ply_index(dyn);
  36. tr.test_eq(ra::_0, a);
  37. }
  38. // rank 1
  39. {
  40. ra::Big<int, 1> a = {0, 0, 0};
  41. ra::ply_index(map([](auto && i) { std::cout << "i: " << i << std::endl; },
  42. a+ra::OldTensorIndex<0> {}));
  43. ra::ply_index(map([](auto && i) { std::cout << "i: " << i << std::endl; },
  44. a+ra::TensorIndex<0> {}));
  45. ra::ply_ravel(map([](auto && i) { std::cout << "i: " << i << std::endl; },
  46. a+ra::TensorIndex<0> {}));
  47. }
  48. // rank 2
  49. {
  50. ra::Big<int, 2> a = {{0, 0, 0}, {0, 0, 0}};
  51. ra::ply_index(map([](auto && i, auto && j) { std::cout << "i: " << i << ", " << j << std::endl; },
  52. a+ra::OldTensorIndex<0> {}, a+ra::OldTensorIndex<1> {}));
  53. ra::ply_index(map([](auto && i, auto && j) { std::cout << "i: " << i << ", " << j << std::endl; },
  54. a+ra::TensorIndex<0> {}, a+ra::TensorIndex<1> {}));
  55. ra::ply_ravel(map([](auto && i, auto && j) { std::cout << "i: " << i << ", " << j << std::endl; },
  56. a+ra::TensorIndex<0> {}, a+ra::TensorIndex<1> {}));
  57. }
  58. // benchmark
  59. auto taking_view =
  60. [](TestRecorder & tr, auto && a)
  61. {
  62. auto fa = [&a]()
  63. {
  64. int c = 0;
  65. ra::ply_index(ra::map([&c](auto && i, auto && j) { c += 2*i-j; },
  66. a+ra::OldTensorIndex<0> {}, a+ra::OldTensorIndex<1> {}));
  67. return c;
  68. };
  69. auto fb = [&a]()
  70. {
  71. int c = 0;
  72. ra::ply_index(ra::map([&c](auto && i, auto && j) { c += 2*i-j; },
  73. a+ra::TensorIndex<0> {}, a+ra::TensorIndex<1> {}));
  74. return c;
  75. };
  76. auto fc = [&a]()
  77. {
  78. int c = 0;
  79. ra::ply_ravel(ra::map([&c](auto && i, auto && j) { c += 2*i-j; },
  80. a+ra::TensorIndex<0> {}, a+ra::TensorIndex<1> {}));
  81. return c;
  82. };
  83. tr.test_eq(499500000, fa());
  84. tr.test_eq(499500000, fb());
  85. tr.test_eq(499500000, fc());
  86. auto bench = Benchmark {/* repeats */ 30, /* runs */ 30};
  87. bench.info("vala").report(std::cout, bench.run(fa), 1e-6);
  88. bench.info("valb").report(std::cout, bench.run(fb), 1e-6);
  89. bench.info("valc").report(std::cout, bench.run(fc), 1e-6);
  90. };
  91. ra::Big<int, 2> const a({1000, 1000}, 0);
  92. taking_view(tr, a);
  93. auto b = transpose<1, 0>(a);
  94. taking_view(tr, b);
  95. return tr.summary();
  96. }