#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# train-model.py
#
# Copyright 2022 Stephen Stengel <stephen.stengel@cwu.edu>
#

print("Loading imports...")

import os
import sys #Needed for the sys.platform checks and sys.exit() below.
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import shutil
import time
import cv2
import math
import subprocess

from tqdm import tqdm
from sklearn.metrics import confusion_matrix, classification_report
from models import createHarlowModel, simpleModel, inceptionV3Model, mediumModel
from keras import callbacks

print("Done!")

LOADER_DIRECTORY = os.path.normpath("../animal-crossing-loader/")
TRAIN_DIRECTORY = os.path.join(LOADER_DIRECTORY, "dataset", "train")
VAL_DIRECTORY = os.path.join(LOADER_DIRECTORY, "dataset", "val")
TEST_DIRECTORY = os.path.join(LOADER_DIRECTORY, "dataset", "test")

CLASS_BOBCAT = 0
CLASS_COYOTE = 1
CLASS_DEER = 2
CLASS_ELK = 3
CLASS_HUMAN = 4
CLASS_NOT_INTERESTING = 5
CLASS_RACCOON = 6
CLASS_WEASEL = 7

CLASS_BOBCAT_STRING = "bobcat"
CLASS_COYOTE_STRING = "coyote"
CLASS_DEER_STRING = "deer"
CLASS_ELK_STRING = "elk"
CLASS_HUMAN_STRING = "human"
CLASS_RACCOON_STRING = "raccoon"
CLASS_WEASEL_STRING = "weasel"
CLASS_NOT_INTERESTING_STRING = "not"

CLASS_NAMES_LIST_INT = [CLASS_BOBCAT, CLASS_COYOTE, CLASS_DEER, CLASS_ELK, CLASS_HUMAN, CLASS_NOT_INTERESTING, CLASS_RACCOON, CLASS_WEASEL]
CLASS_NAMES_LIST_STR = [CLASS_BOBCAT_STRING, CLASS_COYOTE_STRING, CLASS_DEER_STRING, CLASS_ELK_STRING, CLASS_HUMAN_STRING, CLASS_NOT_INTERESTING_STRING, CLASS_RACCOON_STRING, CLASS_WEASEL_STRING]

TEST_PRINTING = False

# ~ IMG_WIDTH = 40
# ~ IMG_HEIGHT = 30
# ~ IMG_WIDTH = 100
# ~ IMG_HEIGHT = 100
IMG_WIDTH = 200
IMG_HEIGHT = 150
# ~ IMG_WIDTH = 400
# ~ IMG_HEIGHT = 300
# ~ IMG_WIDTH = 300
# ~ IMG_HEIGHT = 225
IMG_CHANNELS = 3
IMG_SHAPE_TUPPLE = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

# ~ BATCH_SIZE = 8 #This is also set in the image loader. They must match.
BATCH_SIZE = 32 #This is also set in the image loader. They must match.

# ~ EPOCHS = 20
# ~ EPOCHS = 100
EPOCHS = 2
PATIENCE = 10
REPEATS = 5

#How to get this programmatically?
MY_PYTHON_STRING = "python"
# ~ MY_PYTHON_STRING = "python3"
# ~ MY_PYTHON_STRING = "py"
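#One possible programmatic answer to the question above (an untested sketch,
#assuming the loader script should run under the same interpreter as this
#script): sys.executable holds the full path of the running Python interpreter.
# ~ MY_PYTHON_STRING = sys.executable
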
def main(args):
    listOfFoldersToDELETE = []
    deleteDirectories(listOfFoldersToDELETE)

    #Base folder for this run.
    ts = time.localtime()
    timeStr = "./%d-%d-%d-%d-%d-%d/" % (ts.tm_year, ts.tm_mon, ts.tm_mday, ts.tm_hour, ts.tm_min, ts.tm_sec)
    timeStr = os.path.normpath(timeStr)

    # Folders to save model tests
    simpleFolder = os.path.join(timeStr, "simple")
    harlowFolder = os.path.join(timeStr, "harlow")
    inceptionFolder = os.path.join(timeStr, "inceptionV3")
    mediumFolder = os.path.join(timeStr, "medium")
    modelBaseFolders = [simpleFolder, mediumFolder, harlowFolder, inceptionFolder] #Same order as the modelList below!
    # ~ modelBaseFolders = [mediumFolder] #Same order as the modelList below!
    makeDirectories(modelBaseFolders)

    # train_ds is for training the model.
    # val_ds is for validation during training.
    # test_ds is a dataset of unmodified images for testing the model after training.
    train_ds, val_ds, test_ds = getDatasets(TRAIN_DIRECTORY, VAL_DIRECTORY, TEST_DIRECTORY)

    if TEST_PRINTING:
        printSample(test_ds)

    imgShape = IMG_SHAPE_TUPPLE
    batchSize = BATCH_SIZE
    numEpochs = EPOCHS
    numPatience = PATIENCE

    #These contain the functions to create the models, NOT the models themselves.
    modelList = [simpleModel, mediumModel, createHarlowModel, inceptionV3Model]
    # ~ modelList = [simpleModel, mediumModel]
    # ~ modelList = [mediumModel]

    #This loop could be segmented further. We keep track of the best accuracy
    #from each type of model, then print out which model gave the best
    #accuracy overall and say where that model is saved.
    overallBestAcc = -math.inf
    overallBestModel = None
    overallBestFolder = ""
    eachModelAcc = []
    for i in range(len(modelList)):
        thisAcc, thisModel, thisFolder = \
                runManyTests(
                        modelBaseFolders[i], REPEATS, modelList[i],
                        train_ds, val_ds, test_ds, numEpochs,
                        numPatience, imgShape, batchSize, LOADER_DIRECTORY)
        eachModelAcc.append(thisAcc)
        if thisAcc > overallBestAcc:
            overallBestAcc = thisAcc
            overallBestModel = thisModel
            overallBestFolder = thisFolder
        else:
            del thisModel

    outString = "The best accuracies among the models..." + "\n"
    for thingy in eachModelAcc:
        outString += str(round(thingy, 4)) + "\n"
    outString += "The overall best saved model is in folder: " + overallBestFolder + "\n"
    outString += "It has an accuracy of: " + str(round(overallBestAcc, 4)) + "\n"
    print(outString)
    printStringToFile(os.path.join(timeStr, "overall-output.txt"), outString, "w")

    print("A winner is YOU!")

    return 0

def runManyTests(thisBaseOutFolder, numRepeats, inputModel, train_ds, val_ds, test_ds, numEpochs, numPatience, imgShapeTupple, batchSize, loaderScriptDirectory):
    saveCopyOfSourceCode(thisBaseOutFolder)

    theRunWithTheBestAccuracy = -1
    theBestAccuracy = -math.inf
    theBestModel = None
    theBestSavedModelFolder = "" #Might not need this if I use the lists.
    #Actually, if we save to disk each time we can save ram.
    eachTestAcc = []
    for jay in range(numRepeats):
        #This call could be replaced with a shuffle function. If we had one
        #big dataset file, we could shuffle that instead of reloading the
        #images every time. But this works.
        reloadImageDatasets(loaderScriptDirectory, "load-dataset.py")
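        #A minimal sketch of that shuffle alternative (untested; it assumes
        #train_ds is an ordinary tf.data.Dataset and the buffer_size value
        #here is arbitrary): tf.data can reshuffle the data on every pass
        #instead of reloading it from disk.
        # ~ train_ds = train_ds.shuffle(buffer_size=1000, reshuffle_each_iteration=True)
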
        thisInputModel = inputModel(imgShapeTupple)

        thisTestAcc, thisOutModel, thisOutputFolder = runOneTest(
                thisInputModel, os.path.join(thisBaseOutFolder, str(jay)),
                train_ds, val_ds, test_ds,
                numEpochs, numPatience, imgShapeTupple,
                batchSize)

        eachTestAcc.append(thisTestAcc)
        if thisTestAcc > theBestAccuracy:
            theBestAccuracy = thisTestAcc
            theRunWithTheBestAccuracy = jay
            theBestModel = thisOutModel
            theBestSavedModelFolder = thisOutputFolder
        else:
            del thisInputModel #To save a bit of ram faster.

    outString = "The accuracies for this run..." + "\n"
    for thingy in eachTestAcc:
        outString += str(round(thingy, 4)) + "\n"
    outString += "The best saved model is in folder: " + theBestSavedModelFolder + "\n"
    outString += "It has an accuracy of: " + str(round(theBestAccuracy, 4)) + "\n"
    print(outString)
    printStringToFile(os.path.join(thisBaseOutFolder, "repeats-output.txt"), outString, "w")

    return theBestAccuracy, theBestModel, theBestSavedModelFolder

def runOneTest(thisModel, thisOutputFolder, train_ds, val_ds, test_ds, numEpochs, numPatience, imgShapeTupple, batchSize):
    thisModel.summary()
    print("Training model: " + thisOutputFolder)

    thisCheckpointFolder = os.path.join(thisOutputFolder, "checkpoint")
    thisMissclassifiedFolder = os.path.join(thisOutputFolder, "misclassified-images")
    foldersForThisModel = [thisOutputFolder, thisCheckpointFolder, thisMissclassifiedFolder]
    makeDirectories(foldersForThisModel)

    myHistory = trainModel(thisModel, train_ds, val_ds, thisCheckpointFolder, numEpochs, numPatience)
    print("Creating graphs of training history...")
    #thisTestAcc is the same as strAcc but in unrounded float form.
    strAcc, strLoss, thisTestAcc = saveGraphs(thisModel, myHistory, test_ds, thisOutputFolder)

    #Still working on this part.
    stringToPrint = "Epochs: " + str(numEpochs) + "\n"
    stringToPrint += "Image Shape: " + str(imgShapeTupple) + "\n\n"
    stringToPrint += evaluateLabels(test_ds, thisModel, thisOutputFolder, thisMissclassifiedFolder, batchSize)
    stringToPrint += "Accuracy and loss according to tensorflow model.evaluate():\n"
    stringToPrint += strAcc + "\n"
    stringToPrint += strLoss + "\n"

    statFileName = os.path.join(thisOutputFolder, "stats.txt")
    printStringToFile(statFileName, stringToPrint, "w")
    print(stringToPrint)

    return thisTestAcc, thisModel, thisOutputFolder

#Reload the images from the dataset so that you can run another test with randomized images.
def reloadImageDatasets(loaderPath, scriptName):
    #Save the current directory.
    startDirectory = os.getcwd()
    os.chdir(loaderPath)

    loaderPID = None
    # ~ os.system(MY_PYTHON_STRING + " " + scriptName)
    if sys.platform.startswith("win"):
        os.system("powershell" + " " + MY_PYTHON_STRING + " " + scriptName)
    elif sys.platform.startswith("linux"):
        os.system(MY_PYTHON_STRING + " " + scriptName)
    else:
        print("MASSIVE ERROR LOL!")
        exit(-4)

    # ~ loaderPID = subprocess.Popen([MY_PYTHON_STRING, scriptName])
    # ~ if loaderPID is not None:
        # ~ loaderPID.wait()
    # ~ else:
        # ~ print("MASSIVE ERROR LOL!")

    os.chdir(startDirectory)
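    #A possible replacement for the os.chdir()/os.system() calls above (an
    #untested sketch): subprocess.run() can launch the loader script in its
    #own working directory and raise on failure, without changing this
    #process's working directory.
    # ~ subprocess.run([MY_PYTHON_STRING, scriptName], cwd=loaderPath, check=True)
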
#Runs a system command. Input is the string that would run on linux or inside wsl.
def runSystemCommand(inputString):
    if sys.platform.startswith("win"):
        os.system("wsl " + inputString)
    elif sys.platform.startswith("linux"):
        os.system(inputString)
    else:
        print("MASSIVE ERROR LOL!")
        exit(-4)

#Save a copy of this source file alongside the run's output.
def saveCopyOfSourceCode(thisOutputFolder):
    thisFileName = os.path.basename(__file__)
    try:
        shutil.copy(thisFileName, os.path.join(thisOutputFolder, "copy-" + thisFileName))
    except Exception:
        print("Failed to make a copy of the source code!")

# model.predict() makes an array of probabilities that a certain class is correct.
# By saving the scores from the test_ds, we can see which images
# cause false-positives, false-negatives, true-positives, and true-negatives.
def evaluateLabels(test_ds, model, outputFolder, missclassifiedFolder, batchSize):
    print("Getting predictions of test data...")
    testScores = model.predict(test_ds, verbose = True)
    actual_test_labels = extractLabels(test_ds)

    #Get the list of class predictions from the probability scores.
    p_test_labels = getPredictedLabels(testScores)

    saveMisclassified(test_ds, actual_test_labels, p_test_labels, missclassifiedFolder, batchSize)

    printLabelStuffToFile(testScores, actual_test_labels, p_test_labels, outputFolder) #Debug function.

    outString = "Confusion Matrix:\n"
    outString += "Bobcat(0), Coyote(1), Deer(2), Elk(3), Human(4), Not Interesting(5), Raccoon(6), Weasel(7)\n"
    cf = str(confusion_matrix(actual_test_labels, p_test_labels))
    cf_report = classification_report(actual_test_labels, p_test_labels, digits=4)
    outString += cf + "\n" + cf_report + "\n"

    #Make a pretty chart of these images?
    return outString
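    #If the "pretty chart" above means plotting the confusion matrix, one
    #possible approach is sklearn's display helper (an untested sketch that
    #would go just before the return above):
    # ~ from sklearn.metrics import ConfusionMatrixDisplay
    # ~ disp = ConfusionMatrixDisplay(confusion_matrix(actual_test_labels, p_test_labels), display_labels=CLASS_NAMES_LIST_STR)
    # ~ disp.plot()
    # ~ plt.savefig(os.path.join(outputFolder, "confusion-matrix.png"))
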
# Saves all misclassified images to disk for later inspection.
def saveMisclassified(dataset, labels, predicted, missClassifiedFolder, batchSize):
    cnt = 0
    for img, _ in dataset.take(-1):
        for i in range(batchSize):
            if labels[cnt] != predicted[cnt]:
                myImg = np.asarray(img)
                thisActualName = CLASS_NAMES_LIST_STR[labels[cnt]]
                thisPredictedName = CLASS_NAMES_LIST_STR[predicted[cnt]]
                thisFileString = \
                        "actual_" + thisActualName \
                        + "_predicted_" + thisPredictedName \
                        + "_" + str(cnt) + ".jpg"
                path = os.path.join(missClassifiedFolder, thisFileString)
                saveThis = np.asarray(myImg[i]) * 255
                cv2.imwrite(path, saveThis)
            if cnt < len(labels) - 1:
                cnt += 1
            else:
                return

# Creates the necessary directories.
def makeDirectories(listOfFoldersToCreate):
    for folder in listOfFoldersToCreate:
        if not os.path.isdir(folder):
            os.makedirs(folder)


def deleteDirectories(listDirsToDelete):
    for folder in listDirsToDelete:
        if os.path.isdir(folder):
            shutil.rmtree(folder, ignore_errors = True)

#Train the model, checkpointing the best weights and stopping early when
#validation accuracy stops improving.
def trainModel(model, train_ds, val_ds, checkpointFolder, numEpochs, numPatience):
    checkpointer = callbacks.ModelCheckpoint(
            filepath = checkpointFolder,
            monitor = "accuracy",
            save_best_only = True,
            mode = "max")

    earlyStopper = callbacks.EarlyStopping(
            monitor = "val_accuracy",
            mode = "max",
            patience = numPatience,
            restore_best_weights = True)

    callbacks_list = [earlyStopper, checkpointer]

    return model.fit(
            train_ds,
            # ~ steps_per_epoch = 1, #to shorten training for testing purposes. I got no gpu qq.
            callbacks = callbacks_list,
            epochs = numEpochs,
            validation_data = val_ds)
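    #To reuse the best checkpoint later, it can be reloaded from checkpointFolder.
    #A minimal sketch (assuming TF 2.x Keras, where ModelCheckpoint writes a
    #SavedModel directory when the filepath has no .h5 extension):
    # ~ restoredModel = tf.keras.models.load_model(checkpointFolder)
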
#Returns caption strings for the graphs of the accuracy and loss.
#Also returns the accuracy of the model as applied to the test dataset.
def saveGraphs(model, myHistory, test_ds, outputFolder):
    evalLoss, evalAccuracy = model.evaluate(test_ds)

    plt.clf()
    accuracy = myHistory.history["accuracy"]
    val_accuracy = myHistory.history["val_accuracy"]
    epochs = range(1, len(accuracy) + 1)

    accCap = round(evalAccuracy, 4)
    captionTextAcc = "Accuracy on test data: {}".format(accCap)
    plt.figtext(0.5, 0.01, captionTextAcc, wrap=True, horizontalalignment='center', fontsize=12)
    plt.plot(epochs, accuracy, "o", label="Training accuracy")
    plt.plot(epochs, val_accuracy, "^", label="Validation accuracy")
    plt.title("Model Accuracy vs Epochs")
    plt.ylabel("accuracy")
    plt.xlabel("epoch")
    plt.legend()
    plt.savefig(os.path.join(outputFolder, "trainvalacc.png"))

    plt.clf()
    loss = myHistory.history["loss"]
    val_loss = myHistory.history["val_loss"]

    lossCap = round(evalLoss, 4)
    captionTextLoss = "Loss on test data: {}".format(lossCap)
    plt.figtext(0.5, 0.01, captionTextLoss, wrap=True, horizontalalignment='center', fontsize=12)
    plt.plot(epochs, loss, "o", label="Training loss")
    plt.plot(epochs, val_loss, "^", label="Validation loss")
    plt.title("Training and validation loss vs Epochs")
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.legend()
    plt.savefig(os.path.join(outputFolder, "trainvalloss.png"))
    plt.clf()

    return captionTextAcc, captionTextLoss, evalAccuracy

def getDatasets(trainDir, valDir, testDir):
    train = tf.data.experimental.load(trainDir)
    val = tf.data.experimental.load(valDir)
    test = tf.data.experimental.load(testDir)

    return train, val, test
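    #Note: newer TensorFlow releases deprecate tf.data.experimental.load() in
    #favor of the non-experimental API. An equivalent call, if the installed
    #version supports it (an untested sketch):
    # ~ train = tf.data.Dataset.load(trainDir)
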
# Prints first nine images from the first batch of the dataset.
# It's random as long as you shuffle the dataset! ;)
def printSample(in_ds):
    plt.figure(figsize=(10, 10))
    for img, label in in_ds.take(1):
        # ~ for i in tqdm.tqdm(range(9)):
        for i in tqdm(range(9)):
            ax = plt.subplot(3, 3, i + 1)
            myImg = np.asarray(img)
            plt.imshow(np.asarray(myImg[i]), cmap="gray")
            plt.title( CLASS_NAMES_LIST_STR[ np.asarray(label[i]) ] )
            plt.axis("off")
    plt.show()
    plt.clf()

# Extract the labels from the tensorflow dataset structure.
def extractLabels(in_ds):
    print("Trying to get list out of test dataset...")
    lablist = []
    for batch in tqdm(in_ds):
        lablist.extend( np.asarray(batch[1]) )

    return np.asarray(lablist)


def printStringToFile(fileName, textString, openMode):
    with open(fileName, openMode) as myFile:
        myFile.write(textString)

#Write the raw prediction scores plus the actual and predicted labels to a file for debugging.
def printLabelStuffToFile(predictedScores, originalLabels, predictedLabels, outputFolder):
    with open(os.path.join(outputFolder, "predictionlists.txt"), "w") as outFile:
        for i in range(len(predictedScores)):
            thisScores = predictedScores[i]
            thisString = "predicted scores: ["
            for animalClass in CLASS_NAMES_LIST_INT:
                thisString += str(round(thisScores[animalClass], 4))
                if len(thisScores) - 1 != animalClass:
                    thisString += ", "
            thisString += "]" \
                    + "\tactual label: " + str(originalLabels[i]) \
                    + "\tpredicted label: " + str(predictedLabels[i]) \
                    + "\n"
            outFile.write(thisString)

#Convert each row of probability scores into the index of the most likely class.
def getPredictedLabels(testScores):
    outList = []
    for score in testScores:
        outList.append(np.argmax(score))

    return np.asarray(outList)


if __name__ == '__main__':
    sys.exit(main(sys.argv))