arrays.scm 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
; Helper functions for CBLAS bindings.
; (c) Daniel Llorens - 2014-2015, 2017, 2019, 2021
; This library is free software; you can redistribute it and/or modify it under
; the terms of the GNU Lesser General Public License as published by the Free
; Software Foundation; either version 3 of the License, or (at your option) any
; later version.
; This file is shared with (ffi blis arrays) in guile-ffi-cblis.
  8. (define-module (ffi cblas arrays)
  9. #:export (syntax->list
  10. srfi-4-type-size
  11. check-array check-2-arrays
  12. stride dim
  13. define-sdcz
  14. define-auto))
  15. (import (system foreign) (srfi srfi-1) (srfi srfi-11) (srfi srfi-26) (ice-9 match))
  16. (define (stride A i)
  17. (list-ref (shared-array-increments A) i))
  18. (define (dim A i)
  19. (list-ref (array-dimensions A) i))
  20. (define (check-array A rank type)
  21. (unless (= rank (array-rank A)) (throw 'bad-rank (array-rank A)))
  22. (unless (typed-array? A type) (throw 'bad-type type (array-type A))))
  23. (define (check-2-arrays A B rank type)
  24. (check-array A rank type)
  25. (check-array B rank type)
  26. (unless (= (array-length A) (array-length B))
  27. (throw 'bad-sizes (array-length A) (array-length B)))
  28. (unless (= 0 (caar (array-shape A)) (caar (array-shape B)))
  29. (throw 'bad-base-indices (array-length A) (array-length B))))
  30. (define (srfi-4-type-size stype)
  31. (case stype
  32. ((s8 u8 uv8) 1)
  33. ((s16 u16) 2)
  34. ((f32 s32 u32) 4)
  35. ((c32 f64 s64 u64) 8)
  36. ((c64) 16)
  37. (else (throw 'bad-srfi-4-type-type stype))))
  38. ; https://www.scheme.com/csug8/syntax.html §11.3
  39. (define syntax->list
  40. (lambda (ls)
  41. (syntax-case ls ()
  42. (() '())
  43. ((x . r) (cons #'x (syntax->list #'r))))))
  44. (eval-when (expand load eval)
  45. (define (subst-qmark stx-name t)
  46. (let* ((s (symbol->string (syntax->datum stx-name)))
  47. (i (string-index s #\?)))
  48. (datum->syntax stx-name (string->symbol (string-replace s (symbol->string t) i (+ i 1)))))))
  49. (define-syntax define-sdcz
  50. (lambda (x)
  51. (syntax-case x ()
  52. ((_ root n ...)
  53. (with-syntax ((definer (datum->syntax x (string->symbol (format #f "define-~a" (syntax->datum #'root))))))
  54. (cons #'begin
  55. (append-map
  56. (lambda (tag t)
  57. (let ((fun (map (cut subst-qmark <> t) (syntax->list #'(n ...)))))
  58. ; #`(quote #,(datum->syntax x tag)) to write out a symbol, but assembling docstrings seems harder (?)
  59. (list (cons* #'definer (datum->syntax x tag) fun)
  60. (cons* #'export fun))))
  61. '(f32 f64 c32 c64)
  62. '(s d c z))))))))
  63. (define-syntax define-auto
  64. (lambda (x)
  65. (syntax-case x ()
  66. ((_ (fun args ...) X ?fun)
  67. #`(begin
  68. (define (fun args ...)
  69. ((match (array-type X)
  70. #,@(map (lambda (tag t) (list #`(quote #,(datum->syntax x tag)) (subst-qmark #'?fun t)))
  71. '(f32 f64 c32 c64)
  72. '(s d c z)))
  73. args ...))
  74. (export fun))))))