123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323 |
#include "simple/geom/vector.hpp"
#include "simple/geom/bool_algebra.hpp"
#include <cassert>
#include <cstddef>
#include <fstream>
#include <vector>
using namespace simple;
using geom::vector;
using int2 = vector<int,2>;
using int3 = vector<int,3>;
using int2x2 = vector<int2, 2>;
using int3x3 = vector<int3, 3>;
// Extracts every element of `v` from `is`, in iteration order. Nested vectors
// are handled by recursion, because each element is itself read with `>>`.
// Extraction stops at the first failed read and later elements are left
// untouched. The returned stream's state tells the caller whether the whole
// vector was read.
// NOTE(review): the template is unconstrained, so it is a candidate for any
// `is >> x` in this file. Exact non-template overloads (e.g. istream's own
// `operator>>(int&)`) still win overload resolution.
template <typename Vector>
std::istream& operator>>(std::istream& is, Vector& v)
{
	for(auto&& element : v)
	{
		is >> element;
		if(!is)
			break;
	}
	return is;
}
- void SquareMatrixMultiplication()
- {
- std::vector<int3x3> matrices;
- std::ifstream test_data("square_matrix.data");
- int3x3 matrix;
- while(test_data >> matrix)
- matrices.push_back(matrix);
- assert(matrices.size() > 3);
- assert(matrices.size() % 3 == 0);
- for(auto i = matrices.begin(); i != matrices.end(); i+=3)
- {
- int3x3 A = *i;
- int3x3 B = *(i+1);
- int3x3 AxB = *(i+2);
- assert(( B(A) == AxB ));
- }
- }
- void MatrixVectorMultiplication()
- {
- struct test_case
- {
- int3x3 matrix;
- int3 in, out;
- };
- std::vector<test_case> tests;
- std::ifstream test_data("matrix_vector.data");
- while(test_data)
- {
- test_case test;
- test_data >> test.in;
- test_data >> test.matrix;
- test_data >> test.out;
- tests.push_back(test);
- }
- assert(tests.size() > 2);
- tests.pop_back();
- for(auto&& [matrix, in, out] : tests)
- assert( out == matrix(in) );
- }
- void DotProduct()
- {
- struct test_case
- {
- int3 in1, in2;
- int out;
- };
- std::vector<test_case> tests;
- std::ifstream test_data("dot_product.data");
- while(test_data)
- {
- test_case test;
- test_data >> test.in1;
- test_data >> test.in2;
- test_data >> test.out;
- tests.push_back(test);
- }
- assert(tests.size() > 2);
- tests.pop_back();
- for(auto&& [in1, in2, out] : tests)
- {
- assert( out == in1(in2) );
- assert( out == in2(in1) );
- }
- }
- void NonSquareMatrixMultiplication()
- {
- using int2x3 = vector<int2, 3>;
- using int3x2 = vector<int3, 2>;
- using int3x5 = vector<int3, 5>;
- using int2x5 = vector<int2, 5>;
- int2x3 a{ int2x3::array {{
- {1, 2},
- {2, 1},
- {1, 2},
- }}};
- int3x5 b{ int3x5::array {{
- {1, 2, 3},
- {3, 1, 2},
- {2, 3, 1},
- {3, 2, 1},
- {1, 3, 2}
- }}};
- int2x5 ans{ int2x5::array {{
- {8, 10},
- {7, 11},
- {9, 9},
- {8, 10},
- {9, 9}
- }}};
- assert ( ans == a(b) );
- struct test_case
- {
- int3x2 in1;
- int2x3 in2;
- int2x2 out;
- };
- std::vector<test_case> tests;
- std::ifstream test_data("matrix.data");
- while(test_data)
- {
- test_case test;
- test_data >> test.in1;
- test_data >> test.in2;
- test_data >> test.out;
- tests.push_back(test);
- }
- assert(tests.size() > 2);
- tests.pop_back();
- for(auto&& [in1, in2, out] : tests)
- assert( out == in2(in1) );
- }
// TODO: add tests for the remaining vector operations.
- void RowColumnVectorAndMatrix()
- {
- const vector row(0.1f, 0.2f, 0.3f);
- auto matrix = vector {
- vector(1.0f, 2.0f, 3.0f),
- vector(4.0f, 5.0f, 6.0f),
- vector(7.0f, 8.0f, 9.0f),
- };
- assert(( matrix + row ==
- vector{
- vector(1.1f, 2.2f, 3.3f),
- vector(4.1f, 5.2f, 6.3f),
- vector(7.1f, 8.2f, 9.3f),
- }
- ));
- assert(( row + matrix ==
- vector{
- vector(1.1f, 2.2f, 3.3f),
- vector(4.1f, 5.2f, 6.3f),
- vector(7.1f, 8.2f, 9.3f),
- }
- ));
- matrix += row;
- assert(( matrix ==
- vector{
- vector(1.1f, 2.2f, 3.3f),
- vector(4.1f, 5.2f, 6.3f),
- vector(7.1f, 8.2f, 9.3f),
- }
- ));
- const vector column{
- vector(0.1f),
- vector(0.2f),
- vector(0.3f),
- };
- matrix = vector {
- vector(1.0f, 2.0f, 3.0f),
- vector(4.0f, 5.0f, 6.0f),
- vector(7.0f, 8.0f, 9.0f),
- };
- assert(( matrix + column ==
- vector{
- vector(1.1f, 2.1f, 3.1f),
- vector(4.2f, 5.2f, 6.2f),
- vector(7.3f, 8.3f, 9.3f),
- }
- ));
- assert(( column + matrix ==
- vector{
- vector(1.1f, 2.1f, 3.1f),
- vector(4.2f, 5.2f, 6.2f),
- vector(7.3f, 8.3f, 9.3f),
- }
- ));
- matrix += column;
- assert(( matrix ==
- vector{
- vector(1.1f, 2.1f, 3.1f),
- vector(4.2f, 5.2f, 6.2f),
- vector(7.3f, 8.3f, 9.3f),
- }
- ));
- assert
- (
- vector
- (
- vector(10),
- vector(20),
- vector(30)
- )
- +
- vector(1,2,3)
- ==
- vector
- (
- vector(11, 12, 13),
- vector(21, 22, 23),
- vector(31, 32, 33)
- )
- );
- assert
- (
- vector(1,2,3)
- +
- vector
- (
- vector(10),
- vector(20),
- vector(30)
- )
- ==
- vector
- (
- vector(11, 12, 13),
- vector(21, 22, 23),
- vector(31, 32, 33)
- )
- );
- }
- void PolynomialMultiplication()
- {
- const vector p1(1, -1, 3, 2);
- const vector p2{
- vector(4),
- vector(2),
- vector(1),
- vector(-5),
- };
- // get a matrix all combination
- const auto all_combos = p1 * p2;
- constexpr auto degree = std::max(p1.dimensions, p2.dimensions);
- auto result = vector<int, degree + degree - 1>{};
- // sum the secondary diagonals of the matrix
- constexpr size_t x = -1;
- result += all_combos[0].mix<0,1,2,3,x,x,x>(0);
- result += all_combos[1].mix<x,0,1,2,3,x,x>(0);
- result += all_combos[2].mix<x,x,0,1,2,3,x>(0);
- result += all_combos[3].mix<x,x,x,0,1,2,3>(0);
- assert(result == vector(4,-2,11,8,12,-13,-10));
- }
- void StructredBinding()
- {
- {
- vector v(1,2,3);
- auto [a,b,c] = v;
- assert( a == 1);
- assert( b == 2);
- assert( c == 3);
- }
- {
- vector v(1,2,3);
- const auto [a,b,c] = v;
- assert( a == 1);
- assert( b == 2);
- assert( c == 3);
- }
- }
- constexpr bool Constexprness()
- {
- constexpr int3x3 A{}, B{};
- constexpr int3 a{}, b{};
- void(A(B)); void(B(A)); void(A(a)); void(B(a)); void(a(b));
- return true;
- }
- int main()
- {
- SquareMatrixMultiplication();
- MatrixVectorMultiplication();
- DotProduct();
- NonSquareMatrixMultiplication();
- RowColumnVectorAndMatrix();
- PolynomialMultiplication();
- static_assert(Constexprness());
- return 0;
- }
|