12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- (define-module (prediction))
- (use-modules
- ;; SRFI-8 for `receive`
- (srfi srfi-8)
- (dataset)
- (data-point)
- (tree))
;; Predict the label of DATA-POINT by walking the decision tree TREE.
;;
;; TREE               -- node to start from (either a leaf or a split node).
;; DATA-POINT         -- the row whose label is being predicted.
;; LABEL-COLUMN-INDEX -- label column, forwarded to the leaf majority vote.
;;
;; At a split node, the point's value in column (node-split-feature-index tree)
;; is compared with (node-split-value tree) using `<`.  Strictly smaller values
;; descend into (node-left tree); everything else, including equality, descends
;; into (node-right tree).  At a leaf node the answer is the leaf's majority
;; label, computed by predict-at-leaf-node (0 or 1, see
;; dataset-majority-prediction).
(define-public predict
  (lambda (tree data-point label-column-index)
    (if (leaf-node? tree)
        ;; Reuse the shared leaf helper instead of duplicating its body.
        (predict-at-leaf-node tree label-column-index)
        (let ([feature-value
               (data-point-get-col data-point (node-split-feature-index tree))])
          ;; Choose the child first, then recurse once with the same arguments.
          (predict (if (< feature-value (node-split-value tree))
                       (node-left tree)
                       (node-right tree))
                   data-point
                   label-column-index)))))
- #;(define-public node-majority-prediction
- (lambda (node label-column-index)
- (dataset-majority-prediction (node-data node) label-column-index)))
;; Predict a label for every data point in DATA using the decision tree TREE.
;;
;; Each point is handed to predict together with LABEL-COLUMN-INDEX, and the
;; per-point results are collected by dataset-map.  The container type of the
;; result is whatever dataset-map builds (NOTE(review): presumably a list --
;; confirm against the dataset module).
(define-public (predict-dataset tree data label-column-index)
  ;; Predict for a single row; TREE and LABEL-COLUMN-INDEX stay fixed.
  (define (predict-point point)
    (predict tree point label-column-index))
  (dataset-map predict-point data))
;; Return the majority-label prediction for the data stored at node LEAF.
;;
;; LABEL-COLUMN-INDEX selects the label column passed to
;; dataset-majority-prediction.  The node is not checked with leaf-node?;
;; it is only read through node-data.
(define-public (predict-at-leaf-node leaf label-column-index)
  (let ([leaf-rows (node-data leaf)])
    (dataset-majority-prediction leaf-rows label-column-index)))
;; Binary majority vote over the label column of DATA.
;;
;; Rows are split by dataset-partition into those whose value in column
;; LABEL-COLUMN-INDEX is numerically equal (`=`) to 0 and all other rows.
;; Returns 1 only when the non-zero group is strictly larger; otherwise
;; (including ties) returns 0.  Labels are compared with `=`, so the column is
;; assumed to hold numbers (NOTE(review): presumably 0/1 labels -- confirm).
(define-public (dataset-majority-prediction data label-column-index)
  (call-with-values
      (lambda ()
        (dataset-partition
         (lambda (row)
           (= 0 (data-point-get-col row label-column-index)))
         data))
    (lambda (zero-labelled other-labelled)
      (if (> (dataset-length other-labelled)
             (dataset-length zero-labelled))
          1
          0))))
- #;(define-public predict-dataset
- (lambda (tree dataset label-column-index)
- (let iter ([remaining-dataset dataset])
- (cond
- [(dataset-empty? remaining-dataset) '()]
- [else
- (cons (predict tree (car remaining-dataset) label-column-index)
- (iter (cdr remaining-dataset)))]))))
|