matrix-multiplication.scm 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. (use-modules
  2. (ffi blis)
  3. (srfi srfi-1)
  4. (ice-9 match))
  5. ;; Build a data abstraction.
  (define (matrix-shape mat)
    "Return the shape of the matrix MAT, exactly as reported by `array-shape'."
    ;; Part of the matrix abstraction: callers ask for the shape of a matrix
    ;; without needing to know that matrices are represented as Guile arrays.
    (array-shape mat))
  (define (matrix-dimensions mat)
    "Return the dimensions of the matrix MAT, exactly as reported by
  `array-dimensions'."
    ;; Part of the matrix abstraction: the result is what get-dims-rows and
    ;; get-dims-cols take apart.
    (array-dimensions mat))
  (define (get-dims-rows dims)
    "Return the first element of the dimensions list DIMS, i.e. the row count."
    (car dims))
  (define (get-dims-cols dims)
    "Return the second element of the dimensions list DIMS, i.e. the column
  count."
    (cadr dims))
  18. ;; "For matrix multiplication, the number of columns in the first matrix must be
  19. ;; equal to the number of rows in the second matrix."
  20. ;; (https://en.wikipedia.org/wiki/Matrix_multiplication)
  (define matrix-multiply!
    (lambda (mat-a mat-b mat-res)
      "Multiply mat-a and mat-b and store the result in mat-res. Return mat-res.

  mat-a and mat-b are used untransposed. Any previous contents of mat-res are
  overwritten."
      ;; The general matrix multiply computes
      ;;   mat-res := alpha * mat-a * mat-b + beta * mat-res.
      ;; beta has to be 0.0 for a plain product: with beta = 1.0 the old
      ;; (possibly uninitialized) contents of mat-res would be added to the
      ;; result instead of being replaced by it.
      (let ([alpha 1.0] [beta 0.0])
        ;; Now we can make use of the library functions.
        (gemm! BLIS_NO_TRANSPOSE BLIS_NO_TRANSPOSE
               alpha
               mat-a mat-b
               beta
               mat-res)
        mat-res)))
  (define simple-matrix-multiply
    (lambda (mat-a mat-b)
      "Calculate the product of 2 matrices, automatically creating another matrix
  with the correct dimensions, which the result will be written to.

  Return the newly created f64 result matrix."
      (define mat-a-dim (matrix-dimensions mat-a))
      (define mat-b-dim (matrix-dimensions mat-b))
      (matrix-multiply! mat-a
                        mat-b
                        ;; The result of a matrix multiplication A x B has the
                        ;; shape (rows of A, columns of B). The library demands
                        ;; that we pass a matrix to write the result to as
                        ;; another argument.
                        ;;
                        ;; The matrix is filled with 0.0 rather than left
                        ;; unspecified: a gemm-style multiply may accumulate
                        ;; into its output (C := alpha*A*B + beta*C), so the
                        ;; initial contents must be well defined.
                        (make-typed-array 'f64
                                          0.0
                                          (get-dims-rows mat-a-dim)
                                          (get-dims-cols mat-b-dim)))))
  ;; Demo: build two matrices, multiply them and print the product.
  ;;
  ;; The matrices are typed arrays with elements of type 'f64, i.e. 64-bit
  ;; floats, usually called doubles. The left operand has 4 rows of 3 ones, the
  ;; right operand has 3 rows of 4 ones.
  (let ((lhs (list->typed-array 'f64
                                '(0 0)
                                '((1 1 1)
                                  (1 1 1)
                                  (1 1 1)
                                  (1 1 1))))
        (rhs (list->typed-array 'f64
                                '(0 0)
                                '((1 1 1 1)
                                  (1 1 1 1)
                                  (1 1 1 1)))))
    ;; Format the product with ~s followed by a newline, then write it to the
    ;; current output port.
    (display
     (simple-format #f "~s\n" (simple-matrix-multiply lhs rhs))))