;; decision-tree.scm
  1. (define-module (decision-tree))
  2. (use-modules
  3. ;; SRFI-1 for list procedures
  4. ((srfi srfi-1) #:prefix srfi1:)
  5. ;; SRFI-8 for `receive` form
  6. (srfi srfi-8)
  7. (utils csv)
  8. (utils display)
  9. (utils math)
  10. (utils string)
  11. (utils list)
  12. (dataset)
  13. (data-point)
  14. (tree)
  15. (metrics)
  16. (prediction)
  17. (split-quality-measure))
;; Path of the CSV file the banknote data set is read from.
;; NOTE(review): presumably the UCI "banknote authentication" data set,
;; judging by the file name — confirm.
(define FILE-PATH
  "data_banknote_authentication.csv")
  20. ;; For each column we define a column converter, which converts the string,
  21. ;; which is read in from the CSV, to an appropriate data type for the data set
  22. ;; in the program.
  23. (define COLUMN-CONVERTERS
  24. (list (list string->number)
  25. (list string->number)
  26. (list string->number)
  27. (list string->number)
  28. (list string->number)))
  29. ;; Using the defined column converters, we define the data set.
  30. (define banking-dataset
  31. (all-rows "data_banknote_authentication.csv" #:converters COLUMN-CONVERTERS))
;; This is an artefact from development. It serves as an example to test things
;; with interactively or in a shorter time than with a whole larger data set.
;; Each data point is a vector of two numeric feature columns followed by a
;; 0/1 value in the last column — presumably the class label; confirm against
;; the label-column-index used by callers.
(define dev-dataset
  (list #(2.771244718 1.784783929 0)
        #(1.728571309 1.169761413 0)
        #(3.678319846 2.81281357 0)
        #(3.961043357 2.61995032 0)
        #(2.999208922 2.209014212 0)
        #(7.497545867 3.162953546 1)
        #(9.00220326 3.339047188 1)
        #(7.444542326 0.476683375 1)
        #(10.12493903 3.234550982 1)
        #(6.642287351 3.319983761 1)))
  45. ;; =======================
  46. ;; DECISION TREE ALGORITHM
  47. ;; =======================
  48. (define-public split-data
  49. (lambda (data index value)
  50. (receive (part1 part2)
  51. (dataset-partition (lambda (data-point)
  52. (< (data-point-get-col data-point index) value))
  53. data)
  54. (list part1 part2))))
  55. (define-public select-min-cost-split
  56. (lambda (split-a split-b)
  57. (if (< (split-cost split-a) (split-cost split-b))
  58. split-a
  59. split-b)))
  60. (define-public get-best-split-for-column
  61. (lambda* (data
  62. label-column-index
  63. column-index
  64. #:key
  65. (split-quality-proc gini-index))
  66. "Calculate the best split value for the column of the data at the given
  67. index."
  68. (let ([initial-placeholder-split
  69. (make-split 0 +inf.0 (list '() '()) +inf.0)])
  70. ;; TODO: Parallelism: This is a place, where parallelism could be made use
  71. ;; of. Instead of going through all the split values of the column
  72. ;; sequentially, the split values can be processed in parallel.
  73. (let iter-col-vals ([column-data (dataset-get-col data column-index)]
  74. [previous-best-split initial-placeholder-split])
  75. (cond
  76. [(dataset-column-empty? column-data) previous-best-split]
  77. [else
  78. (let* ([current-value (dataset-column-first column-data)]
  79. [current-subsets (split-data data
  80. column-index
  81. current-value)]
  82. [current-cost (split-quality-proc current-subsets label-column-index)])
  83. (iter-col-vals
  84. (dataset-column-rest column-data)
  85. (select-min-cost-split
  86. previous-best-split
  87. ;; FUTURE TODO: Here we are creating a Split record, which might
  88. ;; not be needed and thrown away after this iteration. An
  89. ;; optimization might be to not even create it, if the current
  90. ;; cost is higher than the cost of the previously best
  91. ;; split. However, always handling multiple values bloates the
  92. ;; code a little and the current implementation seems more
  93. ;; readable.
  94. (make-split column-index
  95. current-value
  96. current-subsets
  97. current-cost))))])))))
  98. (define-public get-best-split
  99. (lambda* (data
  100. feature-column-indices
  101. label-column-index
  102. #:key
  103. (split-quality-proc gini-index))
  104. (let ([max-col-index (- (data-point-length (dataset-first data)) 1)]
  105. [start-column-index 0]
  106. [initial-placeholder-split (make-split 0 +inf.0 (list '() '()) +inf.0)])
  107. ;; iterate over columns -- which column is best for splitting?
  108. ;; TODO: Parallelism: Here we could use multiple cores to calculate the
  109. ;; best split for different columns in parallel.
  110. (let iter-col-ind ([col-index start-column-index]
  111. [best-split-so-far initial-placeholder-split])
  112. (cond
  113. [(> col-index max-col-index) best-split-so-far]
  114. [(= col-index label-column-index)
  115. (iter-col-ind (+ col-index 1) best-split-so-far)]
  116. [else
  117. ;; iterate over values in 1 column -- which value is the best split
  118. ;; value?
  119. (iter-col-ind (+ col-index 1)
  120. (select-min-cost-split
  121. best-split-so-far
  122. (get-best-split-for-column
  123. data
  124. label-column-index
  125. col-index
  126. #:split-quality-proc split-quality-proc)))])))))
  127. (define-public fit
  128. (lambda* (#:key
  129. train-data
  130. (feature-column-indices '())
  131. label-column-index
  132. (max-depth 6)
  133. (min-data-points 12)
  134. (min-data-points-ratio 0.02)
  135. (min-impurity-split (expt 10 -7))
  136. (stop-at-no-impurity-improvement #t))
  137. (define all-data-length (dataset-length train-data))
  138. (define current-depth 1)
  139. #|
  140. STOP CRITERIA:
  141. - only one class in a subset (cannot be split any further and does not need to be split)
  142. - maximum tree depth reached
  143. - minimum number of data points in a subset
  144. - minimum ratio of data points in this subset
  145. |#
  146. (define all-same-label?
  147. (lambda (subset)
  148. ;; FUTURE TODO: Do no longer assume, that the label column is always an
  149. ;; integer or a number.
  150. (column-uniform? (dataset-get-col subset label-column-index) =)))
  151. (define insufficient-data-points-for-split?
  152. (lambda (subset)
  153. (let ([number-of-data-points (dataset-length subset)])
  154. (or (<= number-of-data-points min-data-points)
  155. (< number-of-data-points 2)))))
  156. (define max-depth-reached?
  157. (lambda (current-depth)
  158. (>= current-depth max-depth)))
  159. (define insufficient-data-points-ratio-for-split?
  160. (lambda (subset)
  161. (<= (/ (dataset-length subset) all-data-length) min-data-points-ratio)))
  162. (define no-improvement?
  163. (lambda (previous-split-impurity split-impurity)
  164. (and (<= previous-split-impurity split-impurity)
  165. stop-at-no-impurity-improvement)))
  166. (define insufficient-impurity?
  167. (lambda (impurity)
  168. (< impurity min-impurity-split)))
  169. #|
  170. Here we do the recursive splitting.
  171. |#
  172. (define recursive-split
  173. (lambda (subset current-depth previous-split-impurity)
  174. (display "recursive split on depth: ") (displayln current-depth)
  175. ;; Before splitting further, we check for stopping early conditions.
  176. ;; TODO: Refactor this part. This cond form is way to big. Think of
  177. ;; something clever. TODO: Parallelism: This might be a place to use
  178. ;; parallelism at, to check for the stopping criteria in
  179. ;; parallel. However, I think they might not take that long to calculate
  180. ;; anyway and the question is, whether the overhead is worth it.
  181. (cond
  182. [(max-depth-reached? current-depth)
  183. (displayln "STOPPING CONDITION: maximum depth")
  184. (displayln (string-append "INFO: still got "
  185. (number->string (dataset-length subset))
  186. " data points"))
  187. (make-leaf-node subset)]
  188. [(insufficient-data-points-for-split? subset)
  189. (displayln "STOPPING CONDITION: insuficient number of data points")
  190. (displayln (string-append "INFO: still got "
  191. (number->string (dataset-length subset))
  192. " data points"))
  193. (make-leaf-node subset)]
  194. [(insufficient-data-points-ratio-for-split? subset)
  195. (displayln "STOPPING CONDITION: insuficient ratio of data points")
  196. (displayln (string-append "INFO: still got "
  197. (number->string (dataset-length subset))
  198. " data points"))
  199. (make-leaf-node subset)]
  200. [(all-same-label? subset)
  201. (displayln "STOPPING CONDITION: all same label")
  202. (displayln (string-append "INFO: still got "
  203. (number->string (dataset-length subset))
  204. " data points"))
  205. (make-leaf-node subset)]
  206. [else
  207. (displayln (string-append "INFO: CONTINUING SPLITT: still got "
  208. (number->string (dataset-length subset))
  209. " data points"))
  210. ;; (display "input data for searching best split:") (displayln subset)
  211. (let* ([best-split
  212. (get-best-split subset
  213. feature-column-indices
  214. label-column-index
  215. #:split-quality-proc gini-index)])
  216. (cond
  217. [(no-improvement? previous-split-impurity (split-cost best-split))
  218. (displayln (string-append "STOPPING CONDITION: "
  219. "no improvement in impurity: previously: "
  220. (number->string previous-split-impurity) " "
  221. "now: "
  222. (number->string (split-cost best-split))))
  223. (make-leaf-node subset)]
  224. [(insufficient-impurity? previous-split-impurity)
  225. (displayln "STOPPING CONDITION: not enough impurity for splitting further")
  226. (make-leaf-node subset)]
  227. [else
  228. ;; Here are the recursive calls. This is not tail recursive, but
  229. ;; since the data structure itself is recursive and we only have
  230. ;; as many procedure calls as there are branches in the tree, it
  231. ;; is OK to not be tail recursive here.
  232. ;; TODO: Parallelism: Here is an obvious place to introduce
  233. ;; parallelism. The recursive calls to ~recursive-split~ can run
  234. ;; in parallel.
  235. (make-node subset
  236. (split-feature-index best-split)
  237. (split-value best-split)
  238. (recursive-split (car (split-subsets best-split))
  239. (+ current-depth 1)
  240. (split-cost best-split))
  241. (recursive-split (cadr (split-subsets best-split))
  242. (+ current-depth 1)
  243. (split-cost best-split)))]))])))
  244. (recursive-split train-data 1 1.0)))
  245. (define-public cross-validation-split
  246. (lambda* (dataset n-folds #:key (random-seed #f))
  247. (let* ([shuffled-dataset (shuffle-dataset dataset #:seed random-seed)]
  248. [number-of-data-points (dataset-length shuffled-dataset)]
  249. [fold-size
  250. (exact-floor (/ number-of-data-points n-folds))])
  251. (split-into-chunks-of-size-n shuffled-dataset
  252. (exact-ceiling
  253. (/ number-of-data-points n-folds))))))
  254. (define-public leave-one-out-k-folds
  255. (lambda (folds left-out-fold)
  256. (define leave-one-out-filter-procedure
  257. (lambda (fold)
  258. (not (equal? fold left-out-fold))))
  259. (filter leave-one-out-filter-procedure
  260. folds)))
  261. ;; evaluates the algorithm using cross validation split with n folds
  262. (define-public evaluate-algorithm
  263. (lambda* (#:key
  264. dataset
  265. n-folds
  266. feature-column-indices
  267. label-column-index
  268. (max-depth 6)
  269. (min-data-points 12)
  270. (min-data-points-ratio 0.02)
  271. (min-impurity-split (expt 10 -7))
  272. (stop-at-no-impurity-improvement #t)
  273. (random-seed #f))
  274. "Calculate a list of accuracy values, one value for each fold of a
  275. cross-validation split."
  276. ;; FUTURE TODO: Parallelism: This is up for multicore optimization, instead
  277. ;; of sequentially going through the folds in order. It should be relatively
  278. ;; simple to calculate the accuracy for each fold in a separate job.
  279. (let ([folds
  280. (cross-validation-split dataset
  281. n-folds
  282. #:random-seed random-seed)])
  283. (let iter ([remaining-folds folds])
  284. (cond
  285. [(null? remaining-folds) '()]
  286. [else
  287. (let ([fold (car remaining-folds)])
  288. (cons (let* ([train-set
  289. (fold-right append
  290. empty-dataset
  291. (leave-one-out-k-folds folds fold))]
  292. [test-set
  293. (map (lambda (data-point)
  294. (data-point-take-features data-point
  295. label-column-index))
  296. fold)]
  297. [actual-labels (dataset-get-col fold label-column-index)]
  298. [tree
  299. (fit #:train-data train-set
  300. #:feature-column-indices feature-column-indices
  301. #:label-column-index label-column-index
  302. #:max-depth max-depth
  303. #:min-data-points min-data-points
  304. #:min-data-points-ratio min-data-points-ratio
  305. #:min-impurity-split min-impurity-split
  306. #:stop-at-no-impurity-improvement stop-at-no-impurity-improvement)]
  307. [predicted-labels
  308. (predict-dataset tree test-set label-column-index)])
  309. (accuracy-metric actual-labels predicted-labels))
  310. (iter (cdr remaining-folds))))])))))