123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- (define-module (decision-tree))
- (use-modules
- ;; SRFI-1 for list procedures
- ((srfi srfi-1) #:prefix srfi1:)
- ;; SRFI-8 for `receive` form
- (srfi srfi-8)
- (utils csv)
- (utils display)
- (utils math)
- (utils string)
- (utils list)
- (dataset)
- (data-point)
- (tree)
- (metrics)
- (prediction)
- (split-quality-measure)
- ;; custom parallelism module
- (parallelism)
- ;; R6RS exception handling using conditions
- (rnrs exceptions)
- (rnrs conditions))
- ;; =======================
- ;; DECISION TREE ALGORITHM
- ;; =======================
(define-public split-data
  (lambda (data index value)
    "Partition DATA on the column at INDEX around VALUE.

Returns a two-element list: first the data points whose value in that column
is less than VALUE, then all remaining data points, as produced by
`dataset-partition'."
    (call-with-values
        (lambda ()
          (dataset-partition
           (lambda (point)
             (< (data-point-get-col point index) value))
           data))
      (lambda (below-value rest)
        (list below-value rest)))))
(define-public select-min-cost-split
  (lambda splits
    "Return the split with the lowest `split-cost' among SPLITS.

The splits are passed as individual arguments.  When several splits share the
lowest cost, the one appearing latest in the argument list is returned.
Raises an error condition (who: select-min-cost-split) when called with no
splits."
    (if (null? splits)
        (raise
         (condition
          (make-error)
          (make-message-condition
           "cannot get minimum cost split given no splits")
          (make-irritants-condition splits)
          ;; The who-condition identifies the procedure that detected the
          ;; problem, not the name of its argument.
          (make-who-condition 'select-min-cost-split)))
        ;; Keep the current best only while it is strictly cheaper; on a tie
        ;; the later candidate replaces it.
        (srfi1:fold (lambda (candidate best-so-far)
                      (if (< (split-cost best-so-far) (split-cost candidate))
                          best-so-far
                          candidate))
                    (car splits)
                    (cdr splits)))))
(define-public get-best-split-for-column
  (lambda* (data
            label-column-index
            column-index
            #:key
            (split-quality-proc gini-index))
    "Return the lowest-cost split of DATA on the column at COLUMN-INDEX.

Every value occurring in that column is tried as a split value.  DATA is
partitioned with `split-data', the two subsets are scored by calling
SPLIT-QUALITY-PROC with the subsets and LABEL-COLUMN-INDEX, and the cheapest
split is kept (ties resolved by `select-min-cost-split').  If the column has
no values, a placeholder split with infinite cost is returned."
    ;; FUTURE TODO: Instead of trying every value of the column, a heuristic
    ;; could try only a few candidate split values (chosen at random, or
    ;; evenly dividing the value range of the column; they need not occur in
    ;; the data set).  The resulting tree may not be perfect, but building it
    ;; would be much faster than trying all the values of the split feature.
    ;;
    ;; TODO: Parallelism: the split values could be evaluated in parallel.
    ;; One calculation unit per split value is probably too much overhead, so
    ;; an extra argument could configure how many split values each unit
    ;; processes.
    (let ([split-for-value
           ;; Build the Split record for splitting DATA at SPLIT-VALUE.
           ;; FUTURE TODO: The record is often discarded right away because a
           ;; cheaper split is already known; it could be skipped in that
           ;; case, at the price of less readable code.
           (lambda (split-value)
             (let* ([subsets (split-data data column-index split-value)]
                    [cost (split-quality-proc subsets label-column-index)])
               (make-split column-index split-value subsets cost)))])
      (let loop ([remaining-values (dataset-get-col data column-index)]
                 ;; Placeholder with the worst possible (positively infinite)
                 ;; cost, so that any real split replaces it.
                 [best-split (make-split 0 +inf.0 (list '() '()) +inf.0)])
        (if (dataset-column-empty? remaining-values)
            best-split
            (loop (dataset-column-rest remaining-values)
                  (select-min-cost-split
                   best-split
                   (split-for-value
                    (dataset-column-first remaining-values)))))))))
(define-public get-best-split
  (lambda* (data
            feature-column-indices
            label-column-index
            #:key
            (split-quality-proc gini-index))
    "Return the lowest-cost split of DATA over the feature columns.

The best split of each column in FEATURE-COLUMN-INDICES is computed with
`get-best-split-for-column' (columns are processed via `run-in-parallel'), and
the cheapest of those is chosen with `select-min-cost-split'.  If
FEATURE-COLUMN-INDICES is empty or #f, every column of the first data point
except LABEL-COLUMN-INDEX is used as a feature column."
    (let ([column-indices
           (if (or (not feature-column-indices)
                   (null? feature-column-indices))
               ;; Fallback: all columns except the label column.
               ;; NOTE(review): assumes every such column holds values that
               ;; can be compared with `<' (see split-data) -- confirm for
               ;; data sets with non-numeric columns.
               (filter (lambda (column-index)
                         (not (= column-index label-column-index)))
                       (srfi1:iota (data-point-length (dataset-first data))))
               feature-column-indices)])
      (apply select-min-cost-split
             ;; NOTE: parallelism
             (run-in-parallel
              (lambda (column-index)
                (get-best-split-for-column data
                                           label-column-index
                                           column-index
                                           #:split-quality-proc split-quality-proc))
              column-indices)))))
(define-public fit
  (lambda* (#:key
            train-data
            (feature-column-indices '())
            label-column-index
            (max-depth 6)
            (min-data-points 12)
            (min-data-points-ratio 0.02)
            (min-impurity-split (expt 10 -7))
            (stop-at-no-impurity-improvement #t))
    ;; Grow a decision tree on TRAIN-DATA by recursive binary splitting and
    ;; return its root, built from `make-node' and `make-leaf-node'.
    ;;
    ;; TRAIN-DATA             -- the training data set.
    ;; FEATURE-COLUMN-INDICES -- columns considered for splitting; passed on
    ;;                           unchanged to `get-best-split'.
    ;; LABEL-COLUMN-INDEX     -- column holding the label.
    ;; MAX-DEPTH, MIN-DATA-POINTS, MIN-DATA-POINTS-RATIO, MIN-IMPURITY-SPLIT,
    ;; STOP-AT-NO-IMPURITY-IMPROVEMENT -- early-stopping parameters, see the
    ;;                           stop criteria below.
    ;;
    ;; Split quality is always measured with `gini-index'.  Progress is logged
    ;; with `display' / `displayln'.
    (define all-data-length (dataset-length train-data))
    #|
    STOP CRITERIA, checked before searching for a split:
    - maximum tree depth reached
    - minimum number of data points in a subset
    - minimum ratio of data points in this subset
    - only one class in a subset (cannot be split any further and does not
      need to be split)
    STOP CRITERIA, checked on the best split found:
    - no improvement in impurity compared to the previous split
    - impurity of the previous split below MIN-IMPURITY-SPLIT
    - the split would leave one of the two subsets empty
    |#
    (define all-same-label?
      (lambda (subset)
        (displayln "checking for stop condition: all-same-label?")
        ;; FUTURE TODO: No longer assume that the label column is always an
        ;; integer or a number.
        (column-uniform? (dataset-get-col subset label-column-index) =)))
    (define insufficient-data-points-for-split?
      (lambda (subset)
        (displayln "checking for stop condition: insufficient-data-points-for-split?")
        (let ([number-of-data-points (dataset-length subset)])
          (or (<= number-of-data-points min-data-points)
              (< number-of-data-points 2)))))
    (define max-depth-reached?
      (lambda (current-depth)
        (displayln "checking for stop condition: max-depth-reached?")
        (>= current-depth max-depth)))
    (define insufficient-data-points-ratio-for-split?
      (lambda (subset)
        (displayln "checking for stop condition: insufficient-data-points-ratio-for-split?")
        (<= (/ (dataset-length subset) all-data-length) min-data-points-ratio)))
    (define no-improvement?
      (lambda (previous-split-impurity split-impurity)
        (displayln "checking for stop condition: no-improvement?")
        (and (<= previous-split-impurity split-impurity)
             stop-at-no-impurity-improvement)))
    (define insufficient-impurity?
      (lambda (impurity)
        (displayln "checking for stop condition: insufficient-impurity?")
        (< impurity min-impurity-split)))
    ;; A split with an empty side separates nothing.  It would produce an
    ;; empty leaf plus a child identical to its parent, so it must not be
    ;; used.  Such a split arises when the split value is the column minimum.
    (define split-has-empty-subset?
      (lambda (split)
        (let ([subsets (split-subsets split)])
          (or (zero? (dataset-length (car subsets)))
              (zero? (dataset-length (cadr subsets)))))))
    ;; Log REASON and the size of SUBSET, then turn SUBSET into a leaf.
    (define stop-with-leaf
      (lambda (reason subset)
        (displayln reason)
        (displayln (string-append "INFO: still got "
                                  (number->string (dataset-length subset))
                                  " data points"))
        (make-leaf-node subset)))
    #|
    Here we do the recursive splitting.
    |#
    (define recursive-split
      (lambda (subset current-depth previous-split-impurity)
        (display "recursive split on depth: ") (displayln current-depth)
        ;; Before splitting further, check the early-stopping conditions.
        ;; TODO: Parallelism: the stop criteria could be checked in parallel,
        ;; but they are probably cheap enough that the overhead is not worth
        ;; it.
        (displayln "will check for stop conditions now")
        (cond
         [(max-depth-reached? current-depth)
          (stop-with-leaf "STOPPING CONDITION: maximum depth" subset)]
         [(insufficient-data-points-for-split? subset)
          (stop-with-leaf "STOPPING CONDITION: insuficient number of data points" subset)]
         [(insufficient-data-points-ratio-for-split? subset)
          (stop-with-leaf "STOPPING CONDITION: insuficient ratio of data points" subset)]
         [(all-same-label? subset)
          (stop-with-leaf "STOPPING CONDITION: all same label" subset)]
         [else
          (displayln (string-append "INFO: CONTINUING SPLITT: still got "
                                    (number->string (dataset-length subset))
                                    " data points"))
          (let* ([best-split
                  (get-best-split subset
                                  feature-column-indices
                                  label-column-index
                                  #:split-quality-proc gini-index)]
                 [best-cost (split-cost best-split)])
            (cond
             [(no-improvement? previous-split-impurity best-cost)
              (displayln (string-append "STOPPING CONDITION: "
                                        "no improvement in impurity: previously: "
                                        (number->string previous-split-impurity) " "
                                        "now: "
                                        (number->string best-cost)))
              (make-leaf-node subset)]
             [(insufficient-impurity? previous-split-impurity)
              (displayln "STOPPING CONDITION: not enough impurity for splitting further")
              (make-leaf-node subset)]
             [(split-has-empty-subset? best-split)
              (displayln "STOPPING CONDITION: best split leaves one subset empty")
              (make-leaf-node subset)]
             [else
              ;; Recursive calls: not tail recursive, but the recursion depth
              ;; is bounded by the depth of the tree being built.
              (let* ([children (split-subsets best-split)]
                     [subtrees
                      ;; NOTE: parallelism
                      (run-in-parallel (lambda (child)
                                         (recursive-split child
                                                          (+ current-depth 1)
                                                          best-cost))
                                       (list (car children) (cadr children)))])
                (make-node subset
                           (split-feature-index best-split)
                           (split-value best-split)
                           (car subtrees)
                           (cadr subtrees)))]))])))
    (recursive-split train-data 1 1.0)))
(define-public cross-validation-split
  (lambda* (dataset n-folds #:key (random-seed #f))
    "Shuffle DATASET (using RANDOM-SEED) and divide it into N-FOLDS folds.

The folds are consecutive slices of the shuffled data set whose sizes differ
by at most one; the larger folds come first.  If DATASET has fewer data points
than N-FOLDS, one fold per data point is returned, and an empty DATASET yields
no folds.  Raises an error condition unless N-FOLDS is a positive integer."
    (cond
     [(not (and (integer? n-folds) (positive? n-folds)))
      (raise
       (condition
        (make-error)
        (make-message-condition "number of folds must be a positive integer")
        (make-irritants-condition (list n-folds))
        (make-who-condition 'cross-validation-split)))]
     [else
      (let* ([shuffled-dataset (shuffle-dataset dataset #:seed random-seed)]
             [number-of-data-points (dataset-length shuffled-dataset)]
             ;; A fold needs at least one data point.
             [fold-count (min (inexact->exact n-folds) number-of-data-points)])
        (if (zero? fold-count)
            '()
            ;; Distribute the remainder over the first folds, so that exactly
            ;; FOLD-COUNT folds are produced.  A fixed chunk size such as
            ;; ceiling(n / k) can yield fewer folds (e.g. 9 points, 4 folds
            ;; gives 3 chunks of 3).
            (let ([base-size (quotient number-of-data-points fold-count)]
                  [oversized-folds (remainder number-of-data-points fold-count)])
              (let loop ([remaining shuffled-dataset]
                         [fold-index 0]
                         [folds '()])
                (if (= fold-index fold-count)
                    (reverse folds)
                    (let ([size (if (< fold-index oversized-folds)
                                    (+ base-size 1)
                                    base-size)])
                      (loop (list-tail remaining size)
                            (+ fold-index 1)
                            (cons (list-head remaining size) folds))))))))])))
(define-public leave-one-out-k-folds
  (lambda (folds left-out-fold)
    "Return FOLDS without LEFT-OUT-FOLD, preserving the order of the others.

Only the first fold that is `equal?' to LEFT-OUT-FOLD is removed, so folds that
happen to have identical contents are still kept for training.  If no fold
matches, FOLDS is returned unchanged (as a fresh list)."
    (let loop ([remaining folds]
               [kept-reversed '()])
      (cond
       [(null? remaining)
        (reverse kept-reversed)]
       [(equal? (car remaining) left-out-fold)
        ;; Drop this single occurrence; keep everything after it.
        (append (reverse kept-reversed) (cdr remaining))]
       [else
        (loop (cdr remaining)
              (cons (car remaining) kept-reversed))]))))
;; Cross-validates the decision tree algorithm: one accuracy value per fold.
(define-public evaluate-algorithm
  (lambda* (#:key
            dataset
            n-folds
            feature-column-indices
            label-column-index
            (max-depth 6)
            (min-data-points 12)
            (min-data-points-ratio 0.02)
            (min-impurity-split (expt 10 -7))
            (stop-at-no-impurity-improvement #t)
            (random-seed #f))
    "Return a list of accuracy values, one per fold of a cross-validation split.

DATASET is split into N-FOLDS folds with `cross-validation-split' (using
RANDOM-SEED).  For each fold, a tree is trained with `fit' on all other folds
using the given tree parameters.  It then predicts the held-out fold, whose
rows are reduced with `data-point-take-features', and the predictions are
scored with `accuracy-metric' against the fold's actual labels."
    (let* ([folds
            (cross-validation-split dataset n-folds #:random-seed random-seed)]
           [score-fold
            (lambda (held-out-fold)
              (let* ([training-data
                      (fold-right append
                                  empty-dataset
                                  (leave-one-out-k-folds folds held-out-fold))]
                     [unlabeled-test-data
                      (map (lambda (point)
                             (data-point-take-features point label-column-index))
                           held-out-fold)]
                     [expected-labels
                      (dataset-get-col held-out-fold label-column-index)]
                     [model
                      (fit #:train-data training-data
                           #:feature-column-indices feature-column-indices
                           #:label-column-index label-column-index
                           #:max-depth max-depth
                           #:min-data-points min-data-points
                           #:min-data-points-ratio min-data-points-ratio
                           #:min-impurity-split min-impurity-split
                           #:stop-at-no-impurity-improvement stop-at-no-impurity-improvement)])
                (accuracy-metric expected-labels
                                 (predict-dataset model
                                                  unlabeled-test-data
                                                  label-column-index))))])
      ;; NOTE: parallelism
      (run-in-parallel score-fold folds))))
|