prediction.scm 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. (define-module (prediction))
  2. (use-modules
  3. ;; SRFI-8 for `receive`
  4. (srfi srfi-8)
  5. (dataset)
  6. (data-point)
  7. (tree))
  ;; predict : tree data-point label-column-index -> prediction
  ;; Walks TREE from the root down to a leaf. At each internal node it goes to
  ;; the left child when DATA-POINT's value at the node's split feature is
  ;; strictly less than the node's split value, and to the right child
  ;; otherwise. At the leaf it returns `dataset-majority-prediction' of the
  ;; data stored there, using LABEL-COLUMN-INDEX as the label column.
  (define-public predict
    (lambda (tree data-point label-column-index)
      (let descend ([node tree])
        (if (leaf-node? node)
            (dataset-majority-prediction (node-data node) label-column-index)
            (let ([feature-value
                   (data-point-get-col data-point
                                       (node-split-feature-index node))])
              (descend (if (< feature-value (node-split-value node))
                           (node-left node)
                           (node-right node))))))))
  18. #;(define-public node-majority-prediction
  19. (lambda (node label-column-index)
  20. (dataset-majority-prediction (node-data node) label-column-index)))
  ;; predict-dataset : tree dataset label-column-index -> result of dataset-map
  ;; Runs `predict' with TREE and LABEL-COLUMN-INDEX on every data point of
  ;; DATA, and returns whatever `dataset-map' builds from those predictions.
  (define-public predict-dataset
    (lambda (tree data label-column-index)
      (let ([predict-one
             (lambda (point)
               (predict tree point label-column-index))])
        (dataset-map predict-one data))))
  ;; predict-at-leaf-node : leaf label-column-index -> prediction
  ;; Returns the majority label of the data stored at LEAF. This is the same
  ;; value `predict' produces once it reaches that leaf.
  (define-public predict-at-leaf-node
    (lambda (leaf label-column-index)
      (let ([leaf-data (node-data leaf)])
        (dataset-majority-prediction leaf-data label-column-index))))
  ;; dataset-majority-prediction : dataset label-column-index -> 0 or 1
  ;; Splits DATA into the points whose label column is `=' to 0 and all other
  ;; points. Returns 1 only when the second group is strictly larger;
  ;; otherwise (including a tie) it returns 0.
  ;; NOTE(review): this assumes `dataset-partition' yields the matching points
  ;; as its first value (as SRFI-1 `partition' does). It also treats labels as
  ;; binary numbers (0 vs. anything else), so non-numeric labels would make
  ;; `=' signal an error -- confirm with callers.
  (define-public dataset-majority-prediction
    (lambda (data label-column-index)
      (call-with-values
          (lambda ()
            (dataset-partition
             (lambda (point)
               (= (data-point-get-col point label-column-index) 0))
             data))
        (lambda (zero-labelled other-labelled)
          (if (> (dataset-length other-labelled)
                 (dataset-length zero-labelled))
              1
              0)))))
  39. #;(define-public predict-dataset
  40. (lambda (tree dataset label-column-index)
  41. (let iter ([remaining-dataset dataset])
  42. (cond
  43. [(dataset-empty? remaining-dataset) '()]
  44. [else
  45. (cons (predict tree (car remaining-dataset) label-column-index)
  46. (iter (cdr remaining-dataset)))]))))