matrix.c 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. #include "matrix.h"
  2. #include <assert.h>
  3. #include <math.h>
  4. #include <stdlib.h> // malloc
  5. #include <string.h> // memset
  6. #define M(m, _c, _r) m->data[(_c) + (_r) * m->c]
  7. sti_matrix* sti_matrix_new(int c, int r) {
  8. sti_matrix* mat;
  9. mat = malloc(sizeof(*mat) + sizeof(mat->data[0]) * r * c);
  10. mat->r = r;
  11. mat->c = c;
  12. return mat;
  13. }
  14. sti_matrix* sti_matrix_same_size(sti_matrix* m) {
  15. return sti_matrix_new(m->c, m->r);
  16. }
  17. sti_matrix* sti_matrix_size_for_mul(sti_matrix* a, sti_matrix* b) {
  18. return sti_matrix_new(b->c, a->r);
  19. }
  20. sti_matrix* sti_matrix_copy(sti_matrix* m) {
  21. sti_matrix* mat = sti_matrix_same_size(m);
  22. memcpy(mat->data, m->data, sizeof(*mat->data) * m->r * m->c);
  23. return mat;
  24. }
  25. // careful here...
  26. void sti_matrix_print(sti_matrix* m, FILE* f) {
  27. for(long r = 0; r < m->r; r++) {
  28. for(long c = 0; c < m->c; c++) {
  29. fprintf(f, "%.2f ", m->data[c + m->c * r]);
  30. }
  31. fprintf(f, "\n");
  32. }
  33. }
  34. void sti_matrix_clear(sti_matrix* m) {
  35. memset(m->data, 0, sizeof(m->data) * m->c * m->r);
  36. }
  37. void sti_matrix_set(sti_matrix* m, float v) {
  38. if(v == 0) {
  39. memset(m->data, 0, sizeof(m->data) * m->c * m->r);
  40. return;
  41. }
  42. long sz = m->c * m->r;
  43. for(int i = 0; i < sz; i++) {
  44. m->data[i] = v;
  45. }
  46. }
  47. void sti_matrix_load(sti_matrix* m, float* v) {
  48. memcpy(m->data, v, sizeof(m->data[0]) * m->c * m->r);
  49. }
  50. void sti_matrix_ident(sti_matrix* m) {
  51. for(int i = 0; i < m->c; i++)
  52. for(int j = 0; j < m->r; j++) {
  53. m->data[i + j * m->c] = i == j;
  54. }
  55. }
  56. void sti_matrix_rand(sti_matrix* m, float min, float max) {
  57. long len = m->c * m->r;
  58. float sz = max - min;
  59. for(long n = 0; n < len; n++) {
  60. float x = ((float)rand() * sz) / (float)RAND_MAX;
  61. m->data[n] = min + x;
  62. }
  63. }
  64. void sti_matrix_transpose(sti_matrix* a, sti_matrix* out) {
  65. assert(a->c * a->r <= out->c * out->r);
  66. out->r = a->c;
  67. out->c = a->r;
  68. for(int r = 0; r < a->r; r++)
  69. for(int c = r; c < a->c; c++) {
  70. float tmp;
  71. if(c < a->c) tmp = M(a, c, r);
  72. if(c < out->c) M(out, c, r) = M(a, r, c);
  73. if(c < a->c) M(out, r, c) = tmp;
  74. }
  75. }
  76. int sti_matrix_eq(sti_matrix* a, sti_matrix* b) {
  77. if(a->r != b->r || a->c != b->c) return 0;
  78. long len = a->c * a->r;
  79. for(long n = 0; n < len; n++) {
  80. if(a->data[n] != b->data[n]) return 0;
  81. }
  82. return 1;
  83. }
  84. sti_matrix* sti_matrix_mul(sti_matrix* a, sti_matrix* b) {
  85. sti_matrix* o;
  86. if(a->c != b->r) return NULL;
  87. o = sti_matrix_new(b->c, a->r);
  88. sti_matrix_mulp(a, b, o);
  89. return o;
  90. }
  91. // no checks for size match.
  92. void sti_matrix_mulp(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  93. long klim = a->c;
  94. for(int c = 0; c < b->c; c++)
  95. for(int r = 0; r < a->r; r++) {
  96. M(out, c, r) = 0;
  97. for(int k = 0; k < klim; k++) {
  98. M(out, c, r) += M(a, k, r) * M(b, c, k);
  99. }
  100. }
  101. }
  102. // multiplies a with the transpose of b
  103. sti_matrix* sti_matrix_mul_transb(sti_matrix* a, sti_matrix* b) {
  104. sti_matrix* o;
  105. if(a->c != b->c) return NULL;
  106. o = sti_matrix_new(b->r, a->r);
  107. sti_matrix_mulp_transb(a, b, o);
  108. return o;
  109. }
  110. void sti_matrix_mulp_transb(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  111. long klim = a->c;
  112. for(int c = 0; c < b->r; c++)
  113. for(int r = 0; r < a->r; r++) {
  114. M(out, c, r) = 0;
  115. for(int k = 0; k < klim; k++) {
  116. M(out, c, r) += M(a, k, r) * M(b, k, c);
  117. }
  118. }
  119. }
  120. #define MIN(a, b) (a < b ? a : b)
  121. void sti_matrix_add(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  122. int c = MIN(out->c, MIN(a->c, b->c));
  123. int r = MIN(out->r, MIN(a->r, b->r));
  124. for(int j = 0; j < r; j++)
  125. for(int i = 0; i < c; i++) {
  126. M(out, i, j) = M(a, i, j) + M(a, i, j);
  127. }
  128. }
  129. void sti_matrix_sub(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  130. int c = MIN(out->c, MIN(a->c, b->c));
  131. int r = MIN(out->r, MIN(a->r, b->r));
  132. for(int j = 0; j < r; j++)
  133. for(int i = 0; i < c; i++) {
  134. M(out, i, j) = M(a, i, j) - M(a, i, j);
  135. }
  136. }
  137. void sti_matrix_scalar_mul(sti_matrix* a, sti_matrix* b, sti_matrix* out) {
  138. long sz = a->c * a->r;
  139. for(int i = 0; i < sz; i++) {
  140. out->data[i] = a->data[i] * b->data[i];
  141. }
  142. }
  143. void sti_matrix_scale(sti_matrix* a, float s, sti_matrix* out) {
  144. long sz = a->c * a->r;
  145. for(int i = 0; i < sz; i++) {
  146. out->data[i] = a->data[i] * s;
  147. }
  148. }
  149. // apply e^a[n]
  150. void sti_matrix_exp(sti_matrix* a, sti_matrix* out) {
  151. long sz = a->c * a->r;
  152. for(int i = 0; i < sz; i++) {
  153. out->data[i] = expf(a->data[i]);
  154. }
  155. }
  156. // simple flat sum of all values in the matrix
  157. float sti_matrix_sum(sti_matrix* a) {
  158. long sz = a->c * a->r;
  159. float sum = 0;
  160. for(int i = 0; i < sz; i++) {
  161. sum += a->data[i];
  162. }
  163. return sum;
  164. }
  165. void sti_matrix_softmax(sti_matrix* a, sti_matrix* out) {
  166. long sz = a->c * a->r;
  167. float sum = 0;
  168. for(int i = 0; i < sz; i++) {
  169. out->data[i] = expf(a->data[i]);
  170. sum += out->data[i];
  171. }
  172. float invsum = 1.0 / sum;
  173. for(int i = 0; i < sz; i++) {
  174. out->data[i] *= invsum;
  175. }
  176. }
  177. void sti_matrix_min(sti_matrix* a, float minval, sti_matrix* out) {
  178. long sz = a->c * a->r;
  179. for(int i = 0; i < sz; i++) {
  180. out->data[i] = fminf(a->data[i], minval);
  181. }
  182. }
  183. void sti_matrix_max(sti_matrix* a, float maxval, sti_matrix* out) {
  184. long sz = a->c * a->r;
  185. for(int i = 0; i < sz; i++) {
  186. out->data[i] = fmaxf(a->data[i], maxval);
  187. }
  188. }
  189. void sti_matrix_clamp(sti_matrix* a, float minval, float maxval, sti_matrix* out) {
  190. long sz = a->c * a->r;
  191. for(int i = 0; i < sz; i++) {
  192. out->data[i] = fminf(minval, fmaxf(a->data[i], maxval));
  193. }
  194. }
  195. void sti_matrix_tanh_clamp(sti_matrix* a, sti_matrix* out) {
  196. long sz = a->c * a->r;
  197. for(int i = 0; i < sz; i++) {
  198. out->data[i] = tanhf(a->data[i]);
  199. }
  200. }
  201. void sti_matrix_relu_0(sti_matrix* a, sti_matrix* out) {
  202. long sz = a->c * a->r;
  203. for(int i = 0; i < sz; i++) {
  204. out->data[i] = fmax(0, a->data[i]);
  205. }
  206. }
  207. void sti_matrix_relu_half(sti_matrix* a, sti_matrix* out) {
  208. long sz = a->c * a->r;
  209. for(int i = 0; i < sz; i++) {
  210. out->data[i] = fmax(0, a->data[i] - .5f) + .5f;
  211. }
  212. }
  213. void sti_matrix_relu_n(sti_matrix* a, float n, sti_matrix* out) {
  214. long sz = a->c * a->r;
  215. for(int i = 0; i < sz; i++) {
  216. out->data[i] = fmax(0, a->data[i] - n) + n;
  217. }
  218. }
  219. void sti_matrix_silu(sti_matrix* a, sti_matrix* out) {
  220. long sz = a->c * a->r;
  221. for(int i = 0; i < sz; i++) {
  222. out->data[i] = a->data[i] / (1.f + expf(-a->data[i]));
  223. }
  224. }
  225. // mean squared error: SUM( (a - b)^2 )
  226. float sti_matrix_mse(sti_matrix* a, sti_matrix* b) {
  227. long sz = a->c * a->r;
  228. float sum = 0;
  229. for(int i = 0; i < sz; i++) {
  230. float x = a->data[i] - b->data[i];
  231. sum += x * x;
  232. }
  233. return sum / sz;
  234. }