tensor-indices.cc 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. // -*- mode: c++; coding: utf-8 -*-
  2. /// @file tensor-indices.cc
  3. /// @brief Tests for TensorIndex.
  4. // (c) Daniel Llorens - 2019-2020
  5. // This library is free software; you can redistribute it and/or modify it under
  6. // the terms of the GNU Lesser General Public License as published by the Free
  7. // Software Foundation; either version 3 of the License, or (at your option) any
  8. // later version.
  9. #include <iostream>
  10. #include "ra/bench.hh"
  11. #include "ra/ra.hh"
  12. #include "ra/test.hh"
  13. using std::cout, std::endl, std::flush, ra::TestRecorder;
  14. int main()
  15. {
  16. TestRecorder tr;
  17. // rank 1
  18. {
  19. ra::Big<int, 1> a = {0, 0, 0};
  20. ra::ply(map([](auto && i) { std::cout << "i: " << i << std::endl; },
  21. a+ra::TensorIndex<0> {}));
  22. ra::ply_ravel(map([](auto && i) { std::cout << "i: " << i << std::endl; },
  23. a+ra::TensorIndex<0> {}));
  24. }
  25. // rank 2
  26. {
  27. ra::Big<int, 2> a = {{0, 0, 0}, {0, 0, 0}};
  28. ra::ply(map([](auto && i, auto && j) { std::cout << "i: " << i << ", " << j << std::endl; },
  29. a+ra::TensorIndex<0> {}, a+ra::TensorIndex<1> {}));
  30. ra::ply_ravel(map([](auto && i, auto && j) { std::cout << "i: " << i << ", " << j << std::endl; },
  31. a+ra::TensorIndex<0> {}, a+ra::TensorIndex<1> {}));
  32. }
  33. // benchmark
  34. auto taking_view =
  35. [](TestRecorder & tr, auto && a)
  36. {
  37. auto fa = [&a]()
  38. {
  39. int c = 0;
  40. ra::ply(ra::map([&c](auto && i, auto && j) { c += 2*i-j; },
  41. a+ra::TensorIndex<0> {}, a+ra::TensorIndex<1> {}));
  42. return c;
  43. };
  44. auto fb = [&a]()
  45. {
  46. int c = 0;
  47. ra::ply_ravel(ra::map([&c](auto && i, auto && j) { c += 2*i-j; },
  48. a+ra::TensorIndex<0> {}, a+ra::TensorIndex<1> {}));
  49. return c;
  50. };
  51. tr.test_eq(499500000, fa());
  52. tr.test_eq(499500000, fb());
  53. auto bench = Benchmark {/* repeats */ 30, /* runs */ 30};
  54. bench.info("vala").report(std::cout, bench.run(fa), 1e-6);
  55. bench.info("valb").report(std::cout, bench.run(fb), 1e-6);
  56. };
  57. ra::Big<int, 2> const a({1000, 1000}, 0);
  58. taking_view(tr, a);
  59. auto b = transpose<1, 0>(a);
  60. taking_view(tr, b);
  61. return tr.summary();
  62. }