123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- (define-module (pruning))
- (use-modules
- ((srfi srfi-1) #:prefix srfi1:)
- (tree)
- (metrics)
- (dataset)
- (prediction)
- (utils list)
- (utils display))
(define-public (count-leaves tree)
  "Return how many leaf nodes TREE contains.

A leaf node counts as one; any other node contributes the leaf counts of
its left and right subtrees."
  ;; Walk the tree with an explicit work list instead of recursing on both
  ;; children; TOTAL accumulates one per leaf encountered.
  (let walk ([pending (list tree)]
             [total 0])
    (if (null? pending)
        total
        (let ([current (car pending)]
              [rest (cdr pending)])
          (if (leaf-node? current)
              (walk rest (+ total 1))
              (walk (cons (node-left current)
                          (cons (node-right current) rest))
                    total))))))
(define-public traverse-collect-last-split-nodes
  (lambda (subtree)
    "Return a flat list of the nodes in SUBTREE for which last-split-node?
holds, in left-to-right order.

A leaf contributes nothing.  A last split node is returned as a single
element and is not searched further.  Any other node contributes the
results of both of its subtrees."
    ;; The result is a list of tree nodes, so the neutral element is the
    ;; empty list (previously the unrelated `empty-dataset` constant).
    ;; Since a leaf yields '(), no special cases are needed for nodes that
    ;; have exactly one leaf child: appending '() is a no-op.
    (cond
     [(leaf-node? subtree) '()]
     [(last-split-node? subtree) (list subtree)]
     [else
      (append (traverse-collect-last-split-nodes (node-left subtree))
              (traverse-collect-last-split-nodes (node-right subtree)))])))
(define-public (get-last-split-nodes tree)
  "Return the last split nodes of TREE as produced by
traverse-collect-last-split-nodes, passed through flatten."
  (let ([collected (traverse-collect-last-split-nodes tree)])
    (flatten collected)))
(define-public select-better-tree
  (lambda (tree
           pruned-tree
           pruning-set
           feature-column-indices
           label-column-index
           accuracy-tolerance)
    "Choose between TREE and its pruned variant PRUNED-TREE.

Both trees predict PRUNING-SET, whose true labels are read from column
LABEL-COLUMN-INDEX, and their accuracies are computed with
accuracy-metric.  PRUNED-TREE is returned when its accuracy is lower than
TREE's by at most ACCURACY-TOLERANCE.  With a non-negative tolerance this
means equal or better accuracy always qualifies.  Otherwise TREE is
returned.

FEATURE-COLUMN-INDICES is accepted for interface compatibility and is not
used here."
    (let* ([actual-labels
            (dataset-get-col pruning-set label-column-index)]
           [tree-accuracy
            (accuracy-metric actual-labels
                             (predict-dataset tree pruning-set label-column-index))]
           [pruned-tree-accuracy
            (accuracy-metric actual-labels
                             (predict-dataset pruned-tree pruning-set label-column-index))]
           ;; Positive when pruning costs accuracy; zero or negative when
           ;; pruning keeps or improves it.
           [accuracy-loss (- tree-accuracy pruned-tree-accuracy)])
      ;; Prefer the smaller tree unless it loses more accuracy than tolerated.
      ;; The previous test, (< (abs loss) tolerance), could never hold for a
      ;; tolerance of 0.0 (prune-with-pruning-set's default), so nothing was
      ;; ever pruned.  Because of `abs` it also rejected prunings that
      ;; improved accuracy.
      (if (<= accuracy-loss accuracy-tolerance)
          pruned-tree
          tree))))
(define-public (prune-node-from-tree tree split-node)
  "Return TREE with every subtree that is equal? to SPLIT-NODE collapsed into
a leaf via make-leaf-node-from-split-node.

Leaves are returned unchanged.  Every other inner node is rebuilt with
make-node, keeping its data, split feature index and split value.  A
collapsed subtree is not searched any further."
  (let rebuild ([current tree])
    (if (leaf-node? current)
        current
        (if (equal? current split-node)
            (make-leaf-node-from-split-node current)
            ;; FUTURE TODO: the two child rebuilds are independent and could
            ;; run as separate jobs on multiple cores.
            (make-node (node-data current)
                       (node-split-feature-index current)
                       (node-split-value current)
                       (rebuild (node-left current))
                       (rebuild (node-right current)))))))
(define-public prune-with-pruning-set
  (lambda* (tree
            pruning-set
            feature-column-indices
            label-column-index
            #:key
            (tolerance 0.0))
    "Prune TREE against PRUNING-SET and return the resulting tree.

Pruning runs in passes.  Each pass collects the current last split nodes
and tries to collapse them into leaves one at a time.  A collapse is kept
only when select-better-tree accepts it for TOLERANCE.  Passes repeat
until a pass removes no leaves.  Progress messages are printed with
displayln."
    ;; Try to collapse each node of REMAINING-SPLIT-NODES in turn.  Each
    ;; attempt starts from the tree produced by the previous one, so
    ;; accepted prunings accumulate within a pass.  prune-node-from-tree
    ;; matches with equal?, so nodes collected before earlier collapses in
    ;; this pass are still found in the rebuilt tree.
    (define iter-split-nodes
      (lambda (tree remaining-split-nodes)
        (if (null? remaining-split-nodes)
            tree
            (iter-split-nodes
             (select-better-tree tree
                                 (prune-node-from-tree tree (car remaining-split-nodes))
                                 pruning-set
                                 feature-column-indices
                                 label-column-index
                                 tolerance)
             (cdr remaining-split-nodes)))))
    ;; Run pruning passes.  TREE-LEAVES# is the leaf count of TREE, carried
    ;; along so it is not recounted.
    (define iter-trees
      (lambda (tree tree-leaves#)
        (let* ([pruned-tree (iter-split-nodes tree (get-last-split-nodes tree))]
               [pruned-tree-leaves# (count-leaves pruned-tree)])
          (cond
           ;; This pass removed no leaves: none of the last split nodes could
           ;; be pruned within the tolerance, so pruning is finished.
           [(= pruned-tree-leaves# tree-leaves#)
            (displayln "STOPPING CONDITION (PRUNING): pruning further would decrease accuracy beyond tolerance")
            tree]
           ;; This pass removed leaves, so new last split nodes may exist;
           ;; try to prune those in another pass.
           [else
            (displayln "CONTINUING PRUNING: tree lost nodes in previous iteration of pruning")
            (iter-trees pruned-tree pruned-tree-leaves#)]))))
    (iter-trees tree (count-leaves tree))))
|