matrix.h 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #ifndef __sti__matrix_h__
  2. #define __sti__matrix_h__
  /*
  Arbitrary-sized 2-dimensional float matrix operations.
  For a more extensive 4x4 and 3x3 matrix library (for 3D games),
  see my other repo, c3dlas.
  Unlike most of sti, the matrix library is explicitly prefixed
  with "sti_" because "matrix" is such a popular data type name
  in other libraries. I assume that some people will want to
  cherry-pick the sti matrix files, or will want to use sti with
  their own matrix functions without having to think about
  removing the sti versions.
  */
  14. #include <stdio.h> // fprintf
  // A dense 2-D float matrix: a small header followed inline by its elements.
  //
  // c, r: presumably the column and row counts (they mirror the parameters
  //       of sti_matrix_new(int c, int r)) -- confirm in matrix.c.
  // data: the element storage. It has no size of its own, so a matrix must
  //       be allocated with trailing room for its elements, e.g.
  //       malloc(sizeof(sti_matrix) + count * sizeof(float)), where count is
  //       presumably c * r. Storage order (row- vs column-major) is decided
  //       by matrix.c and is not visible here.
  //
  // data[] is a standard C99 flexible array member, replacing the former GNU
  // zero-length array `data[0]`; the header-plus-trailing-storage allocation
  // pattern above works unchanged.
  typedef struct sti_matrix {
      int c, r;
      float data[];
  } sti_matrix;
  19. sti_matrix* sti_matrix_new(int c, int r);
  20. sti_matrix* sti_matrix_same_size(sti_matrix* m);
  21. sti_matrix* sti_matrix_size_for_mul(sti_matrix* a, sti_matrix* b);
  22. sti_matrix* sti_matrix_copy(sti_matrix* m);
  23. // careful here...
  24. void sti_matrix_print(sti_matrix* m, FILE* f);
  25. void sti_matrix_clear(sti_matrix* m);
  26. void sti_matrix_set(sti_matrix* m, float v);
  27. void sti_matrix_load(sti_matrix* m, float* v);
  28. void sti_matrix_ident(sti_matrix* m);
  29. void sti_matrix_rand(sti_matrix* m, float min, float max);
  30. void sti_matrix_transpose(sti_matrix* a, sti_matrix* out);
  31. int sti_matrix_eq(sti_matrix* a, sti_matrix* b);
  32. // returns a newly allocated matrix of the proper size
  33. sti_matrix* sti_matrix_mul(sti_matrix* a, sti_matrix* b);
  34. // no checks for size match.
  35. void sti_matrix_mulp(sti_matrix* a, sti_matrix* b, sti_matrix* out);
  36. // multiplies a with the transpose of b
  37. sti_matrix* sti_matrix_mul_transb(sti_matrix* a, sti_matrix* b);
  38. // no checks for size match.
  39. void sti_matrix_mulp_transb(sti_matrix* a, sti_matrix* b, sti_matrix* out);
  40. void sti_matrix_add(sti_matrix* a, sti_matrix* b, sti_matrix* out);
  41. void sti_matrix_sub(sti_matrix* a, sti_matrix* b, sti_matrix* out);
  42. void sti_matrix_scalar_mul(sti_matrix* a, sti_matrix* b, sti_matrix* out);
  43. void sti_matrix_scale(sti_matrix* a, float s, sti_matrix* out);
  44. // apply e^a[n]
  45. void sti_matrix_exp(sti_matrix* a, sti_matrix* out);
  46. // simple flat sum of all values in the matrix
  47. float sti_matrix_sum(sti_matrix* a);
  48. void sti_matrix_softmax(sti_matrix* a, sti_matrix* out);
  49. void sti_matrix_min(sti_matrix* a, float minval, sti_matrix* out);
  50. void sti_matrix_max(sti_matrix* a, float maxval, sti_matrix* out);
  51. void sti_matrix_clamp(sti_matrix* a, float minval, float maxval, sti_matrix* out);
  52. void sti_matrix_relu_0(sti_matrix* a, sti_matrix* out);
  53. void sti_matrix_relu_half(sti_matrix* a, sti_matrix* out);
  54. void sti_matrix_relu_n(sti_matrix* a, float n, sti_matrix* out);
  55. // mean squared error: SUM( (a - b)^2 )
  56. float sti_matrix_mse(sti_matrix* a, sti_matrix* b);
  57. #endif // __sti__matrix_h__