my-unet.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # my-unet.py
  5. #
  6. # Copyright 2022 Stephen Stengel <stephen@cwu.edu> MIT License
  7. #
#To start out I use 2017 as training data and 2016 as test (evaluation) data.
#Images are split into HACK_SIZE x HACK_SIZE squares for the network (the first plan was to resize each image to 128x128); more complex manipulation can be added later.
  10. #using these as reference for some parts:
  11. #https://github.com/dexterfichuk/U-Net-Keras-Implementation
  12. print("Running imports...")
  13. import os
  14. import shutil
  15. import numpy as np
  16. import random
  17. from PIL import Image
  18. from tqdm import tqdm
  19. from skimage.io import imread, imshow, imsave
  20. from skimage.transform import resize
  21. from matplotlib import pyplot as plt
  22. from skimage.util import img_as_bool, img_as_float, img_as_uint, img_as_ubyte
  23. from skimage.util import invert
  24. from skimage.color import rgb2gray, gray2rgb, rgb2hsv, hsv2rgb
  25. from sklearn.metrics import auc
  26. import time
  27. import tensorflow as tf
  28. from keras.layers import Conv2D, MaxPool2D, UpSampling2D, Input, Dropout, Lambda, MaxPooling2D, Conv2DTranspose, Concatenate, Softmax
  29. from tensorflow.keras.optimizers import Adam
  30. from keras import Model, callbacks
  31. from keras import backend
  32. # ~ from rezaunet import BCDU_net_D3
  33. ## autoinit stuff
  34. # ~ from autoinit import AutoInit
  35. np.random.seed(55555)
  36. random.seed(55555)
  37. NUM_SQUARES = 100 #Reduced number of square inputs for training. 100 seems to be min for ok results.
  38. HACK_SIZE = 64 #64 is reasonably good for prototyping.
  39. GLOBAL_HACK_height, GLOBAL_HACK_width = HACK_SIZE, HACK_SIZE
  40. IMAGE_CHANNELS = 3 #This might change later for different datasets. idk.
  41. GLOBAL_EPOCHS = 15
  42. GLOBAL_BATCH_SIZE = 4 #just needs to be big enough to fill memory
  43. #64hack, 5 epoch, 16batch nearly fills 8gb on laptop. Half of 16 on other laptop.
  44. #Making batch too high seems to cause problems. 32 caused a NaN error when trying to write the output images on laptop1.
  45. GLOBAL_INITIAL_FILTERS = 16
  46. GLOBAL_SMOOTH_JACCARD = 1
  47. GLOBAL_SMOOTH_DICE = 1
  48. IS_GLOBAL_PRINTING_ON = False
  49. # ~ IS_GLOBAL_PRINTING_ON = True
  50. GLOBAL_SQUARE_TEST_SAVE = True
  51. # ~ GLOBAL_SQUARE_TEST_SAVE = False
  52. GLOBAL_MAX_TEST_SQUARE_TO_SAVE = 66
  53. HELPFILE_PATH = os.path.normpath("helpfile")
  54. OUT_TEXT_PATH = os.path.normpath("accuracies-if-error-happens-lol")
  55. print("Done!")
  56. def main(args):
  57. startTime = time.time()
  58. print("Hi!")
  59. checkArgs(args)
  60. print("Creating folders to store results...")
  61. tmpFolder, trainingFolder, checkpointFolder, savedModelFolder, \
  62. predictionsFolder, wholePredictionsFolder, outTextPath \
  63. = createFolders()
  64. print("Done!")
  65. print("Creating copy of source code and conda environment...")
  66. copySourceEnv(tmpFolder)
  67. print("Done!")
  68. trainImagePath = os.path.normpath("../DIBCO/2017/Dataset/")
  69. trainTruthPath = os.path.normpath("../DIBCO/2017/GT/")
  70. testImagePath = os.path.normpath("../DIBCO/2016/DIPCO2016_dataset/")
  71. testTruthPath = os.path.normpath("../DIBCO/2016/DIPCO2016_Dataset_GT/")
  72. # ~ testImagePath = os.path.normpath("../DIBCO/2017/Dataset/")
  73. # ~ testTruthPath = os.path.normpath("../DIBCO/2017/GT/")
  74. # ~ trainImagePath = os.path.normpath("../DIBCO/2016/DIPCO2016_dataset/")
  75. # ~ trainTruthPath = os.path.normpath("../DIBCO/2016/DIPCO2016_Dataset_GT/")
  76. print("Creating train and test sets...")
  77. trainImages, trainTruth, testImages, \
  78. testTruths, wholeOriginals, wholeTruths = \
  79. createTrainAndTestSets(trainImagePath, trainTruthPath, testImagePath, testTruthPath)
  80. print("Done!")
  81. #Images not currently called from disk. Commenting for speed testing.
  82. # ~ saveExperimentImages(trainImages, trainTruth, testImages, testTruths, trainingFolder)
  83. if IS_GLOBAL_PRINTING_ON:
  84. mainTestPrintOne(wholeOriginals, wholeTruths, trainImages, trainTruth, testImages, testTruths)
  85. trainImages, trainTruth, testImages, testTruths \
  86. = reduceInputForTesting(trainImages, trainTruth, testImages, testTruths, NUM_SQUARES)
  87. theModel, theHistory = trainUnet(trainImages, trainTruth, checkpointFolder)
  88. print("Saving model...")
  89. theModel.save(os.path.join(savedModelFolder, "saved-model.h5"))
  90. print("Done!")
  91. performEvaluation(theHistory, tmpFolder, testImages, testTruths, theModel)
  92. print("shape of testImages right before predict: " + str(np.shape(testImages)))
  93. modelOut = theModel.predict(testImages)
  94. binarizedOut = ((modelOut > 0.5).astype(np.uint8) * 255).astype(np.uint8) #######test this thing more
  95. if GLOBAL_SQUARE_TEST_SAVE:
  96. saveTestSquares(
  97. GLOBAL_MAX_TEST_SQUARE_TO_SAVE, modelOut, \
  98. binarizedOut, testImages, testTruths, predictionsFolder)
  99. else:
  100. print("Not saving test square pictures this time.")
  101. # ~ print("Calculating jaccard and dice for the test squares...")
  102. # ~ calculateJaccardDiceTestSquares(testTruths, outTextPath, binarizedOut)
  103. print("Predicting output of whole images...")
  104. #currently also does the image processing and saving.
  105. predictionsList = predictAllWholeImages(wholeOriginals, wholeTruths, theModel, HACK_SIZE)
  106. print("Creating confusion masks...")
  107. confusionImages, tpList, fpList, tnList, fnList \
  108. = createConfusionImageList(predictionsList, wholeOriginals, wholeTruths)
  109. print("Saving whole image predictions and confusion images...")
  110. saveAllWholeAndConfusion(predictionsList, wholeOriginals, wholeTruths, confusionImages, wholePredictionsFolder)
  111. print("Creating ROC graph of the whole images test...")
  112. createROC(tpList, fpList, tnList, fnList, tmpFolder)
  113. # ~ print("Evaluating jaccard and dice scores...")
  114. # ~ evaluatePredictionJaccardDice(predictionsList, wholeTruths, outTextPath)
  115. print("Done!")
  116. saveRuntimeToFile(startTime, tmpFolder)
  117. return 0
  118. def saveRuntimeToFile(startTime, tmpFolder):
  119. elapsedTime = time.time() - startTime
  120. with open(os.path.join(tmpFolder, "runtime.txt"), "w") as timeFile:
  121. timeFile.write(str(round(elapsedTime, 4)) + " seconds")
  122. def createFolders():
  123. sq = str(NUM_SQUARES)
  124. hk = str(HACK_SIZE)
  125. ep = str(GLOBAL_EPOCHS)
  126. ba = str(GLOBAL_BATCH_SIZE)
  127. tmpFolder = os.path.normpath("./tmp" + sq + "-" + hk + "-" + ep + "-" + ba + "/")
  128. trainingFolder = os.path.join(tmpFolder, "trainingstuff")
  129. checkpointFolder = os.path.join(tmpFolder, "checkpoint")
  130. savedModelFolder = os.path.join(tmpFolder, "saved-model")
  131. predictionsFolder = os.path.join(tmpFolder, "predictions")
  132. wholePredictionsFolder = os.path.join(tmpFolder, "whole-predictions")
  133. foldersToCreate = [ \
  134. tmpFolder, trainingFolder, \
  135. checkpointFolder, savedModelFolder, \
  136. predictionsFolder, wholePredictionsFolder]
  137. for folder in foldersToCreate:
  138. if not os.path.isdir(folder):
  139. os.makedirs(folder)
  140. #Lol spaghetti
  141. global OUT_TEXT_PATH
  142. OUT_TEXT_PATH = os.path.join(tmpFolder, "accuracy-jaccard-dice.txt")
  143. return tmpFolder, trainingFolder, checkpointFolder, savedModelFolder, predictionsFolder, wholePredictionsFolder, OUT_TEXT_PATH
  144. def copySourceEnv(tmpFolder):
  145. try:
  146. shutil.copy("my-unet.py", tmpFolder)
  147. shutil.copy("working-conda-config.yml", tmpFolder)
  148. except:
  149. print("copy error! Source code and conda environment file not copied!")
  150. #This would be preferred in the final product.
  151. # ~ print("Creating current copy of environment...")
  152. # ~ os.system("conda env export > " + tmpFolder + "working-conda-config-current.yml")
  153. # ~ print("Done!")
  154. def reduceInputForTesting(trainImages, trainTruth, testImages, testTruths, sizeOfSet ):
  155. #This block reduces the input for testing.
  156. highIndex = len(trainImages)
  157. if sizeOfSet > highIndex + 1: #Just in case user enters more squares than exist.
  158. sizeOfSet = highIndex + 1
  159. print("! Limiting size of squares for training to actual number of squares !")
  160. print("Number of squares to be used for training: " + str(sizeOfSet))
  161. updateGlobalNumSquares(sizeOfSet)
  162. rng = np.random.default_rng(12345)
  163. pickIndexes = rng.integers(low = 0, high = highIndex, size = sizeOfSet)
  164. trainImages = trainImages[pickIndexes]
  165. trainTruth = trainTruth[pickIndexes]
  166. sizeOfTestSet = sizeOfSet
  167. if sizeOfTestSet > len(testImages):
  168. sizeOfTestSet = len(testImages)
  169. rng = np.random.default_rng(23456)
  170. print("sizeOfTestSet: " + str(sizeOfTestSet))
  171. pickIndexes = rng.integers(low = 0, high = len(testImages), size = sizeOfTestSet)
  172. testImages = testImages[pickIndexes]
  173. testTruths = testTruths[pickIndexes]
  174. print("There are " + str(len(trainImages)) + " training images.")
  175. print("There are " + str(len(testImages)) + " testing images.")
  176. return trainImages, trainTruth, testImages, testTruths
  177. def mainTestPrintOne(wholeOriginals, wholeTruths, trainImages, trainTruth, testImages, testTruths):
  178. print("shape of wholeOriginals: " + str(np.shape(wholeOriginals)))
  179. print("shape of wholeTruths: " + str(np.shape(wholeTruths)))
  180. print("shape of trainImages: " + str(np.shape(trainImages)))
  181. print("shape of trainTruth: " + str(np.shape(trainTruth)))
  182. print("shape of testImages: " + str(np.shape(testImages)))
  183. print("shape of testTruths: " + str(np.shape(testTruths)))
  184. print("Showing Training stuff...")
  185. randomBoy = random.randint(0, len(trainImages) - 1)
  186. print("image " + str(randomBoy) + "...")
  187. imshow(trainImages[randomBoy] / 255)
  188. # ~ plt.show()
  189. print("truth " + str(randomBoy) + "...")
  190. imshow(np.squeeze(trainTruth[randomBoy]))
  191. # ~ plt.show()
  192. print("Showing Testing stuff...")
  193. randomBoy = random.randint(0, len(testImages) - 1)
  194. print("image " + str(randomBoy) + "...")
  195. imshow(testImages[randomBoy] / 255)
  196. # ~ plt.show()
  197. print("truth " + str(randomBoy) + "...")
  198. imshow(np.squeeze(testTruths[randomBoy]))
  199. # ~ plt.show()
  200. #save copies of some of the squares used in learning.
  201. def saveTestSquares(numToSave, modelOut, binarizedOut, testImages, testTruths, predictionsFolder):
  202. print("Saving random sample of figures...")
  203. rng2 = np.random.default_rng(54322)
  204. if len(modelOut) < numToSave:
  205. numToSave = len(modelOut)
  206. saveIndexes = rng2.integers(low = 0, high = len(modelOut), size = numToSave)
  207. for i in tqdm(saveIndexes):
  208. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]premask.png"), img_as_ubyte(modelOut[i]))
  209. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]predict.png"), img_as_ubyte(binarizedOut[i]))
  210. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]testimg.png"), img_as_ubyte(testImages[i]))
  211. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]truthim.png"), img_as_ubyte(testTruths[i]))
  212. print("Done!")
  213. def calculateJaccardDiceTestSquares(testTruths, outTextPath, binarizedOut):
  214. testTruthsUInt = testTruths.astype(np.uint8)
  215. #Testing the jaccard and dice functions
  216. with open(os.path.join(outTextPath), "w") as outFile:
  217. for i in tqdm(range(len(binarizedOut))):
  218. jac = jaccardIndex(testTruthsUInt[i], binarizedOut[i])
  219. dice = diceIndex(testTruthsUInt[i], binarizedOut[i])
  220. thisString = str(i) + "\tjaccard: " + str(jac) + "\tdice: " + str(dice) + "\n"
  221. outFile.write(thisString)
  222. print("Done!")
  223. #currently also does the image processing and saving.
  224. def predictAllWholeImages(wholeOriginals, wholeTruths, theModel, squareSize):
  225. if IS_GLOBAL_PRINTING_ON:
  226. print("shape of wholeOriginals: " + str(np.shape(wholeOriginals)))
  227. for i in range(len(wholeOriginals)):
  228. print(str(np.shape(wholeOriginals[i])))
  229. print("##########################################################")
  230. print("shape of wholeTruths: " + str(np.shape(wholeTruths)))
  231. predictionsList = []
  232. for i in tqdm(range(len(wholeOriginals))):
  233. # ~ wholeOriginals, wholeTruths
  234. predictedImage = predictWholeImage(wholeOriginals[i], theModel, squareSize)
  235. if IS_GLOBAL_PRINTING_ON:
  236. print("Shape of predicted image " + str(i) + ": " + str(np.shape(predictedImage)))
  237. # ~ predictedImage = ((predictedImage > 0.5).astype(np.uint8) * 255).astype(np.uint8) ## jank thing again
  238. # ~ print("Shape of predicted image " + str(i) + " after mask: " + str(np.shape(predictedImage)))
  239. predictionsList.append(predictedImage)
  240. return predictionsList
  241. #This could be split up a bit as well.
  242. #also outputs the binary masks in lists !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! change later
  243. def createConfusionImageList(predictionsList, wholeOriginals, wholeTruths):
  244. confusionList = []
  245. tpList = []
  246. fpList = []
  247. tnList = []
  248. fnList = []
  249. for i in tqdm(range(len(predictionsList))):
  250. predictedImage = predictionsList[i]
  251. truePosMask = createMaskTruePositive(wholeTruths[i], predictedImage)
  252. tpList.append(truePosMask)
  253. trueNegMask = createMaskTrueNegative(wholeTruths[i], predictedImage)
  254. tnList.append(trueNegMask)
  255. falsePosMask = createMaskFalsePositive(wholeTruths[i], predictedImage)
  256. fpList.append(falsePosMask)
  257. falseNegMask = createMaskFalseNegative(wholeTruths[i], predictedImage)
  258. fnList.append(falseNegMask)
  259. redColor = [1, 0, 0]
  260. greenColor = [0, 1, 0]
  261. blueColor = [0, 0, 1]
  262. yellowColor = [1, 1, 0]
  263. truePosColor = colorPredictionWithPredictionMask(truePosMask, predictedImage, greenColor)
  264. trueNegColor = colorPredictionWithPredictionMask(trueNegMask, predictedImage, yellowColor)
  265. falsePosColor = colorPredictionWithPredictionMask(falsePosMask, predictedImage, blueColor)
  266. falseNegColor = colorPredictionWithPredictionMask(falseNegMask, predictedImage, redColor )
  267. confusion = combinePredictionPicture(truePosMask, trueNegMask, falsePosMask, falseNegMask)
  268. confusionList.append(confusion)
  269. return confusionList, tpList, fpList, tnList, fnList
  270. def saveAllWholeAndConfusion(predictionsList, wholeOriginals, wholeTruths, confusions, wholePredictionsFolder):
  271. for i in tqdm(range(len(predictionsList))):
  272. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]predicted.png"), img_as_ubyte(predictionsList[i]))
  273. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]truth.png"), img_as_ubyte(wholeTruths[i]))
  274. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]original.png"), img_as_ubyte(wholeOriginals[i]))
  275. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]confusion.png"), img_as_ubyte(confusions[i]))
  276. def createROC(tpList, fpList, tnList, fnList, tmpFolder):
  277. #convert each list of masks into a list of percentages.
  278. tpScores = getPercentTrueFromMaskList(tpList)
  279. fpScores = getPercentTrueFromMaskList(fpList)
  280. #not currently needed
  281. # ~ tnScores = getPercentTrueFromMaskList(tnList[i])
  282. # ~ fnScores = getPercentTrueFromMaskList(fnList[i])
  283. plotROCandSave(fpScores, tpScores, tmpFolder)
  284. def plotROCandSave(fpList, tpList, tmpFolder):
  285. # ~ [x for (y,x) in sorted(zip(Y,X), key=lambda pair: pair[0])]
  286. fpArray = np.asarray(fpList)
  287. tpArray = np.asarray(tpList)
  288. indexesSorted = fpArray.argsort()
  289. fpArray = fpArray[indexesSorted]
  290. tpArray = tpArray[indexesSorted]
  291. print("fpArray: " + str(fpArray))
  292. print("tpArray: " + str(tpArray))
  293. # ~ print("fpArray: " + str(fpArray))
  294. # ~ print("tpArray: " + str(tpArray))
  295. roc_auc = auc(fpArray, tpArray)
  296. print("contents of roc_auc: " + str(roc_auc))
  297. # ~ for thing in roc_auc:
  298. # ~ print(thing)
  299. # ~ print("##")
  300. linewidth = 2
  301. plt.figure()
  302. plt.plot(
  303. fpArray,
  304. tpArray,
  305. color="darkorange",
  306. linewidth = linewidth,
  307. label="ROC curve (area = %0.2f%%)" % roc_auc,
  308. # ~ label = "ROC curve"
  309. )
  310. plt.plot([0, 1], [0, 1], color="navy", linewidth=linewidth, linestyle="--")
  311. plt.xlim([0.0, 1.0])
  312. # ~ plt.ylim([0.0, 1.0])
  313. plt.ylim([0.0, 1.05])
  314. plt.xlabel("False Positive Rate")
  315. plt.ylabel("True Positive Rate")
  316. plt.title("ROC curve !")
  317. plt.legend(loc="lower right")
  318. plt.savefig(os.path.join(tmpFolder, "roc-curve.png"))
  319. # ~ plt.show()
  320. def getPercentTrueFromMaskList(inList):
  321. scoreList = []
  322. for i in tqdm(range(len(inList))):
  323. scoreList.append( getPercentTrueFromMask(inList[i]))
  324. return scoreList
  325. ######################################################### I actually don't need percent. I neet TPRa and FPR (tp, fp, tn, fn) ##########################
  326. def getPercentTrueFromMask(inMask):
  327. # ~ mFlat = backend.flatten( img_as_uint(inMask) )
  328. # ~ numTrue = backend.sum(mFlat)
  329. mFlat = np.asarray(inMask).flatten()
  330. numTrue = np.sum(mFlat)
  331. # ~ mFlat = np.asarray(mFlat)
  332. totalNum = mFlat.size
  333. return numTrue / totalNum
  334. def evaluatePredictionJaccardDice(predictionsList, wholeTruths, outTextPath):
  335. print("Calculating jaccard and dice...")
  336. with open(outTextPath, "w") as outFile:
  337. for i in tqdm(range(len(predictionsList))):
  338. thisTruth = np.asarray(wholeTruths[i])
  339. thisTruth = thisTruth.astype(np.uint8)
  340. jac = jaccardIndex(thisTruth, predictionsList[i])
  341. dice = diceIndex(thisTruth, predictionsList[i])
  342. thisString = str(i) + "\tjaccard: " + str(jac) + "\tdice: " + str(dice) + "\n"
  343. outFile.write(thisString)
  344. print("Done!")
  345. def checkArgs(args):
  346. if len(args) >= 1:
  347. for a in args:
  348. if str(a) == "help" \
  349. or str(a).lower() == "-help" \
  350. or str(a).lower() == "--help" \
  351. or str(a).lower() == "--h":
  352. with open(HELPFILE_PATH, "r") as helpfile:
  353. for line in helpfile:
  354. print(line, end = "")
  355. sys.exit(0)
  356. if len(args) < 5:
  357. print("bad input");
  358. sys.exit(-1)
  359. else:
  360. global NUM_SQUARES
  361. NUM_SQUARES = int(sys.argv[1])
  362. global HACK_SIZE
  363. HACK_SIZE = int(sys.argv[2])
  364. global GLOBAL_HACK_height
  365. global GLOBAL_HACK_width
  366. GLOBAL_HACK_height, GLOBAL_HACK_width = HACK_SIZE, HACK_SIZE
  367. global GLOBAL_EPOCHS
  368. GLOBAL_EPOCHS = int(sys.argv[3])
  369. global GLOBAL_BATCH_SIZE
  370. GLOBAL_BATCH_SIZE = int(sys.argv[4])
  371. if len(args) >= 6:
  372. if str(sys.argv[5]) == "print":
  373. global IS_GLOBAL_PRINTING_ON
  374. IS_GLOBAL_PRINTING_ON = True
  375. print("Printing of debugging messages is enabled.")
  376. if NUM_SQUARES < 100:
  377. print("100 squares is really the bare minimum to get any meaningful result.")
  378. sys.exit(-1)
  379. if HACK_SIZE not in [64, 128, 256, 512]:
  380. print("Square size must be 64, 128, 256, or 512." \
  381. + " 128 is recommended for training. 64 for testing")
  382. sys.exit(-2)
  383. if GLOBAL_EPOCHS < 1:
  384. print("Yeah no.")
  385. print("You need at least one epoch, silly!")
  386. sys.exit(-3)
  387. if GLOBAL_BATCH_SIZE < 1 or GLOBAL_BATCH_SIZE > NUM_SQUARES:
  388. print("Global batch size should be between 1 and the number" \
  389. + " of training squares. Pick a better number.")
  390. sys.exit(-5)
  391. def updateGlobalNumSquares(newNumSquares):
  392. global NUM_SQUARES
  393. NUM_SQUARES = newNumSquares
  394. def performEvaluation(history, tmpFolder, testImages, testTruths, theModel):
  395. print("Performing evaluation...###############################################################")
  396. print("Calculating scores...")
  397. print("len testImages: " + str(len(testImages)))
  398. scores = theModel.evaluate(testImages, testTruths)
  399. print("Done!")
  400. print("Scores object: " + str(scores))
  401. print(str(history.history))
  402. print("%s: %.2f%%" % (theModel.metrics_names[1], scores[1]*100))
  403. print("history...")
  404. print(history)
  405. print("history.history...")
  406. print(history.history)
  407. accuracy = history.history["acc"]
  408. jaccInd = history.history["jaccardIndex"]
  409. diceInd = history.history["diceIndex"]
  410. val_accuracy = history.history["val_acc"]
  411. val_jaccInd = history.history["val_jaccardIndex"]
  412. val_diceInd = history.history["val_diceIndex"]
  413. loss = history.history["loss"]
  414. val_loss = history.history["val_loss"]
  415. epochs = range(1, len(accuracy) + 1)
  416. #make a loop for this bit?
  417. saveGraphNumbers(accuracy, epochs, "acc", tmpFolder)
  418. saveGraphNumbers(jaccInd, epochs, "jaccardIndex", tmpFolder)
  419. saveGraphNumbers(diceInd, epochs, "diceIndex", tmpFolder)
  420. saveGraphNumbers(val_accuracy, epochs, "val_acc", tmpFolder)
  421. saveGraphNumbers(val_jaccInd, epochs, "val_jaccardIndex", tmpFolder)
  422. saveGraphNumbers(val_diceInd, epochs, "val_diceIndex", tmpFolder)
  423. saveGraphNumbers(loss, epochs, "loss", tmpFolder)
  424. saveGraphNumbers(val_loss, epochs, "val_loss", tmpFolder)
  425. plt.plot(epochs, accuracy, "^", label="Training accuracy")
  426. plt.plot(epochs, val_accuracy, "2", label="Validation accuracy")
  427. plt.plot(epochs, jaccInd, "*", label="Jaccard Index")
  428. plt.plot(epochs, val_jaccInd, "p", label="Validation Jaccard Index")
  429. plt.plot(epochs, diceInd, "s", label="Dice Index")
  430. plt.plot(epochs, val_diceInd, "D", label="Validation Dice Index")
  431. plt.title("Training and validation accuracy")
  432. plt.legend()
  433. plt.savefig(os.path.join(tmpFolder, "trainvalacc.png"))
  434. plt.clf()
  435. plt.plot(epochs, loss, "^", label="Training loss")
  436. plt.plot(epochs, val_loss, "2", label="Validation loss")
  437. plt.title("Training and validation loss")
  438. plt.legend()
  439. plt.savefig(os.path.join(tmpFolder, "trainvalloss.png"))
  440. plt.clf()
  441. def saveGraphNumbers(array, epochs, nameString, tmpFolder):
  442. outFolder = os.path.join(tmpFolder, "graph-numbers")
  443. if not os.path.isdir(outFolder):
  444. os.makedirs(outFolder)
  445. with open(os.path.join(outFolder, nameString + ".txt"), "w") as statFile:
  446. statFile.write(nameString + " over " + str(epochs) + "epochs" + "\n")
  447. for thing in array:
  448. statFile.write(str(thing) + "\n")
  449. def trainUnet(trainImages, trainTruth, checkpointFolder):
  450. #print("shape of trainImages: " + str(trainImages.shape))
  451. standardUnetLol = createStandardUnet()
  452. # ~ standardUnetLol = BCDU_net_D3( (GLOBAL_HACK_height, GLOBAL_HACK_width, IMAGE_CHANNELS) )
  453. standardUnetLol.summary()
  454. checkpointer = callbacks.ModelCheckpoint(
  455. filepath = checkpointFolder,
  456. # ~ monitor = "val_acc", #current working version
  457. # ~ monitor = "val_loss", #original ##################################################
  458. # ~ monitor = "val_jaccardIndex",
  459. monitor = "jaccardIndex",
  460. # ~ monitor = "val_jaccardLoss",
  461. save_best_only = True,
  462. mode = "max")
  463. # ~ mode = "min")
  464. earlyStopper = callbacks.EarlyStopping( \
  465. monitor="val_jaccardIndex", \
  466. patience = 5, \
  467. mode = "max", \
  468. #not sure about resotre weights...
  469. restore_best_weights = True)
  470. # ~ callbacks_list = [earlyStopper, checkpointer]
  471. callbacks_list = [checkpointer]
  472. myHistory = standardUnetLol.fit(
  473. x = trainImages,
  474. y = trainTruth,
  475. epochs = GLOBAL_EPOCHS,
  476. batch_size = GLOBAL_BATCH_SIZE,
  477. callbacks = callbacks_list,
  478. validation_split = 0.33333)
  479. return standardUnetLol, myHistory
  480. def createStandardUnet():
  481. input_size=(GLOBAL_HACK_height, GLOBAL_HACK_width, IMAGE_CHANNELS)
  482. inputs = Input(input_size)
  483. conv5, conv4, conv3, conv2, conv1 = encode(inputs)
  484. output = decode(conv5, conv4, conv3, conv2, conv1)
  485. model = Model(inputs, output)
  486. # ~ autoinit test. Uncomment to add the autoinit thingy
  487. # ~ model = AutoInit().initialize_model(model)
  488. # ~ model.compile(optimizer = Adam(learning_rate=1e-4), loss='categorical_crossentropy', metrics=["acc"])
  489. model.compile(
  490. # ~ optimizer = "adam",
  491. optimizer = Adam(),
  492. loss = "binary_crossentropy",
  493. # ~ loss = jaccardLoss,
  494. metrics = ["acc", jaccardIndex, diceIndex])
  495. return model
  496. #dropout increase in middle to reduce runtime in addition to dropping out stuff.
  497. def encode(inputs):
  498. sfilter = GLOBAL_INITIAL_FILTERS
  499. conv1 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(inputs)
  500. conv1 = Dropout(0.1)(conv1)
  501. conv1 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(conv1)
  502. pool1 = MaxPooling2D((2, 2))(conv1)
  503. conv2 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(pool1)
  504. conv2 = Dropout(0.1)(conv2)
  505. conv2 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(conv2)
  506. pool2 = MaxPooling2D((2, 2))(conv2)
  507. conv3 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(pool2)
  508. conv3 = Dropout(0.2)(conv3)
  509. conv3 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(conv3)
  510. pool3 = MaxPooling2D((2, 2))(conv3)
  511. conv4 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(pool3)
  512. conv4 = Dropout(0.2)(conv4)
  513. conv4 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(conv4)
  514. pool4 = MaxPooling2D((2, 2))(conv4)
  515. conv5 = Conv2D(sfilter * 16, (3, 3), activation = 'relu', padding = "same")(pool4)
  516. conv5 = Dropout(0.3)(conv5)
  517. conv5 = Conv2D(sfilter * 16, (3, 3), activation = 'relu', padding = "same")(conv5)
  518. return conv5, conv4, conv3, conv2, conv1
  519. def decode(conv5, conv4, conv3, conv2, conv1):
  520. sfilter = GLOBAL_INITIAL_FILTERS
  521. up6 = Conv2DTranspose(sfilter * 8, (2, 2), strides = (2, 2), padding = "same")(conv5)
  522. concat6 = Concatenate()([conv4,up6])
  523. conv6 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(concat6)
  524. conv6 = Dropout(0.2)(conv6)
  525. conv6 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(conv6)
  526. up7 = Conv2DTranspose(sfilter * 4, (2, 2), strides = (2, 2), padding = "same")(conv6)
  527. concat7 = Concatenate()([conv3,up7])
  528. conv7 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(concat7)
  529. conv7 = Dropout(0.2)(conv7)
  530. conv7 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(conv7)
  531. up8 = Conv2DTranspose(sfilter * 2, (2, 2), strides = (2, 2), padding = "same")(conv7)
  532. concat8 = Concatenate()([conv2,up8])
  533. conv8 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(concat8)
  534. conv8 = Dropout(0.1)(conv8)
  535. conv8 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(conv8)
  536. up9 = Conv2DTranspose(sfilter, (2, 2), strides = (2, 2), padding = "same")(conv8)
  537. concat9 = Concatenate()([conv1,up9])
  538. conv9 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(concat9)
  539. conv9 = Dropout(0.1)(conv9)
  540. conv9 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(conv9)
  541. conv10 = Conv2D(1, (1, 1), padding = "same", activation = "sigmoid")(conv9)
  542. return conv10
  543. def saveExperimentImages(trainImages, trainTruth, testImages, testTruths, tmpFolder):
  544. if not os.path.exists(tmpFolder):
  545. print("Making a tmp folder...")
  546. os.system("mkdir tmp")
  547. print("Done!")
  548. np.save(tmpFolder + "train-images-object", trainImages)
  549. np.save(tmpFolder + "train-truth-object", trainTruth)
  550. np.save(tmpFolder + "test-images-object", testImages)
  551. np.save(tmpFolder + "test-truth-object", testTruths)
  552. #gets images from file, manipulates, returns
  553. #currently hardcoded to use 2017 as training data, and 2016 as testing
  554. #data because their names are regular!
  555. def createTrainAndTestSets(trainImagePath, trainTruthPath, testImagePath, testTruthPath):
  556. trainImageFileNames, trainTruthFileNames, \
  557. testImageFileNames, testTruthFileNames = getFileNames(trainImagePath, trainTruthPath, testImagePath, testTruthPath)
  558. trainImages, trainTruth, _, _ = getImageAndTruth(trainImageFileNames, trainTruthFileNames)
  559. trainTruth = convertImagesToGrayscale(trainTruth)
  560. testImage, testTruth, wholeOriginals, wholeTruths = getImageAndTruth(testImageFileNames, testTruthFileNames)
  561. testTruth = convertImagesToGrayscale(testTruth)
  562. wholeTruths = convertImagesToGrayscaleList(wholeTruths)
  563. #invert the imported images. Tensorflow counts white as truth
  564. #and black as false. I had been doing the inverse previously.
  565. trainImages = invertImagesInArray(trainImages)
  566. trainTruth = invertImagesInArray(trainTruth)
  567. testImage = invertImagesInArray(testImage)
  568. testTruth = invertImagesInArray(testTruth)
  569. wholeOriginals = invertImagesInArray(wholeOriginals)
  570. wholeTruths = invertImagesInArray(wholeTruths)
  571. return trainImages, trainTruth, testImage, testTruth, wholeOriginals, wholeTruths
  572. #Inverts all the images in an array. returns an array.
  573. def invertImagesInArray(imgArray):
  574. for i in range(len(imgArray)):
  575. imgArray[i] = invert(imgArray[i])
  576. return imgArray
  577. #This function gets the source image, cuts it into smaller squares, then
  578. #adds each square to an array for output. The original image squares
  579. #will correspond to the base truth squares.
  580. #Try using a method from here to avoid using lists on the arrays:
  581. #https://stackoverflow.com/questions/50226821/how-to-extend-numpy-arrray
  582. #Also returns a copy of the original uncut images as lists.
  583. def getImageAndTruth(originalFilenames, truthFilenames):
  584. outOriginals, outTruths = [], []
  585. wholeOriginals = []
  586. wholeTruths = []
  587. print("Importing " + originalFilenames[0] + " and friends...")
  588. for i in tqdm(range(len(originalFilenames))):
  589. # ~ print("\rImporting " + originalFilenames[i] + "...", end = "")
  590. myOriginal = imread(originalFilenames[i])[:, :, :3] #this is pretty arcane. research later
  591. myTruth = imread(truthFilenames[i])[:, :, :3] #this is pretty arcane. research later
  592. #save original images as list for returning to main
  593. thisOriginal = myOriginal ##Test before removing these temp vals.
  594. thisTruth = myTruth
  595. wholeOriginals.append(np.asarray(thisOriginal))
  596. wholeTruths.append(np.asarray(thisTruth))
  597. #Now make the cuts and save the results to a list. Then later convert list to array.
  598. originalCuts = cutImageIntoSmallSquares(myOriginal)
  599. truthCuts = cutImageIntoSmallSquares(myTruth)
  600. #for loop to add cuts to out lists, or I think I remember a one liner to do it?
  601. #yes-- list has the .extend() function. it adds the elements of a list to another list.
  602. outOriginals.extend(originalCuts)
  603. outTruths.extend(truthCuts)
  604. #can move to return line later maybe.
  605. outOriginals, outTruths = np.asarray(outOriginals), np.asarray(outTruths)
  606. return outOriginals, outTruths, wholeOriginals, wholeTruths
  607. #Cut an image into smaller squares, returns them as a list.
  608. #inspiration from:
  609. #https://stackoverflow.com/questions/5953373/how-to-split-image-into-multiple-pieces-in-python#7051075
  610. #Change to using numpy methods later for much speed-up?:
  611. #https://towardsdatascience.com/efficiently-splitting-an-image-into-tiles-in-python-using-numpy-d1bf0dd7b6f7?gi=2faa21fa5964
  612. #The input is in scikit-image format. It is converted to pillow to crop
  613. #more easily and for saving???. Then converted back for the output list.
  614. #Whitespace is appended to the right and bottom of the image so that the crop will include everything.
  615. def cutImageIntoSmallSquares(skImage):
  616. skOutList = []
  617. myImage = Image.fromarray(skImage)
  618. imageWidth, imageHeight = myImage.size
  619. tmpW = ((imageWidth // HACK_SIZE) + 1) * HACK_SIZE
  620. tmpH = ((imageHeight // HACK_SIZE) + 1) * HACK_SIZE
  621. #Make this next line (0,0,0) once you switch the words to white and background to black.........##############################################################################
  622. tmpImg = Image.new(myImage.mode, (tmpW, tmpH), (255, 255, 255))
  623. wHehe, hHehe = myImage.size
  624. heheHack = (0, 0, wHehe, hHehe)
  625. # ~ tmpImg.paste(myImage, myImage.getbbox())
  626. if IS_GLOBAL_PRINTING_ON:
  627. print("tmpImg.mode: " + str(tmpImg.mode))
  628. print("tmpImg.getbbox(): " + str(tmpImg.getbbox()))
  629. print("tmpImg.size: " + str(tmpImg.size))
  630. print("myImage.mode: " + str(myImage.mode))
  631. print("myImage.getbbox(): " + str(myImage.getbbox()))
  632. print("myImage width, height: " + "(" + str(imageWidth) + "," + str(imageHeight) + ")")
  633. print("myImage.size: " + str(myImage.size))
  634. print("heheHack: " + str(heheHack))
  635. tmpImg.paste(myImage, heheHack)
  636. myImage = tmpImg
  637. # ~ tmp2 = np.asarray(myImage)
  638. # ~ imshow(tmp2)
  639. # ~ plt.show()
  640. for upper in range(0, imageHeight, HACK_SIZE):
  641. lower = upper + HACK_SIZE
  642. for left in range(0, imageWidth, HACK_SIZE):
  643. right = left + HACK_SIZE
  644. cropBounds = (left, upper, right, lower)
  645. cropped = myImage.crop(cropBounds)
  646. cropped = np.asarray(cropped)
  647. skOutList.append(cropped)
  648. # ~ imshow(cropped / 255)
  649. # ~ plt.show()
  650. return skOutList
  651. #This function cuts a large input image into little squares, uses the
  652. #trained model to predict the binarization of each, then stitches each
  653. #image back into a whole for output.
  654. def predictWholeImage(inputImage, theModel, squareSize):
  655. if IS_GLOBAL_PRINTING_ON:
  656. print("squareSize: " + str(squareSize))
  657. ##get dimensions of the image
  658. height, width, _ = inputImage.shape
  659. ##get the number of squares per row of the image
  660. squaresWide = (width // squareSize) + 1
  661. widthPlusRightBuffer = squaresWide * squareSize
  662. squaresHigh = (height // squareSize) + 1
  663. heightPlusBottomBumper = squaresHigh * squareSize
  664. #Dice the image into bits
  665. if IS_GLOBAL_PRINTING_ON:
  666. print("shape of input Image right before dicing: " + str(np.shape(inputImage)))
  667. # ~ print("input Image right before dicing as string: " + str(inputImage))
  668. dicedImage = cutImageIntoSmallSquares(inputImage)
  669. # ~ print("shape of dicedImage right before hacking: " + str(np.shape(dicedImage)))
  670. # ~ #put output into list with extend then np.asaray the whole list to match elswhere.
  671. # ~ tmpList = []
  672. # ~ for i in range(len(dicedImage)):
  673. # ~ tmpList.extend(dicedImage[i])
  674. # ~ dicedImage = np.asarray(tmpList)
  675. ##Predict the outputs of each square
  676. dicedImage = np.asarray(dicedImage)
  677. # ~ print("shape of dicedImage right before predict: " + str(np.shape(dicedImage)))
  678. # ~ print("dicedImage right before predict as string: " + str(dicedImage))
  679. modelOut = theModel.predict(dicedImage)
  680. ##This is the code from main. I know it's bad now, but I'll keep it
  681. ##consistent until I create a helper function for it. ######################################################################################################
  682. binarizedOuts = ((modelOut > 0.5).astype(np.uint8) * 255).astype(np.uint8)
  683. #Stitch image using dimensions from above
  684. #combine each image row into numpy array
  685. theRowsList = []
  686. # ~ print("squaresHigh: " + str(squaresHigh))
  687. # ~ print("squaresWide: " + str(squaresWide))
  688. # ~ print("squareSize: " + str(squareSize))
  689. bigOut = np.zeros(shape = (squareSize * squaresHigh, squareSize * squaresWide, 1), dtype = np.uint8) #swap h and w?
  690. for i in range(squaresHigh):
  691. for j in range(squaresWide):
  692. # ~ print("i: " + str(i) + "\tj: " + str(j))
  693. # ~ print("sqHi: " + str(squaresHigh) + "\tsqWi: " + str(squaresWide))
  694. thisSquare = binarizedOuts[(i * squaresWide) + j] #w?
  695. iStart = i * squareSize
  696. iEnd = (i * squareSize) + squareSize
  697. jStart = j * squareSize
  698. jEnd = (j * squareSize) + squareSize
  699. bigOut[iStart : iEnd , jStart : jEnd ] = thisSquare
  700. # ~ combined = np.asarray(theRowsList)
  701. # ~ combined = combined.reshape((64,64,1))
  702. #Remove the extra padding from the edge of the image.
  703. outImage = bigOut[ :height, :width]
  704. # ~ outImage = bigOut
  705. return outImage
  706. def convertImagesToGrayscale(inputImages):
  707. outImage = []
  708. for image in inputImages:
  709. outImage.append( rgb2gray(image) )
  710. return np.asarray(outImage)
  711. #Returns a list instead of an np array
  712. def convertImagesToGrayscaleList(inputImages):
  713. outImage = []
  714. for image in inputImages:
  715. outImage.append( np.asarray(rgb2gray(image)) )
  716. return outImage
  717. #returns the filenames of the images for (trainImage, trainTruth),(testimage, testTruth)
  718. #hardcoded!
  719. #Test is currently hardcodded to 2016
  720. def getFileNames(trainImagePath, trainTruthPath, testImagePath, testTruthPath):
  721. trainTruthNamePairs = []
  722. # ~ trainImagePath = os.path.normpath("../DIBCO/2017/Dataset/")
  723. # ~ trainTruthPath = os.path.normpath("../DIBCO/2017/GT/")
  724. trainTruthNamePairs.append( (trainImagePath, trainTruthPath) )
  725. # ~ trainImageFileNames, trainTruthFileNames = \
  726. # ~ createTrainImageAndTrainTruthFileNames(trainImagePath, trainTruthPath)
  727. #need to handle non-bmp
  728. # ~ trainPath = "../DIBCO/2009/DIBC02009_Test_images-handwritten/"
  729. # ~ gtPath = "../DIBCO/2009/DIBCO2009-GT-Test-images_handwritten/"
  730. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  731. # ~ trainPath = "../DIBCO/2009/DIBCO2009_Test_images-printed/"
  732. # ~ gtPath = "../DIBCO/2009/DIBCO2009-GT-Test-images_printed/"
  733. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  734. #non-bmps
  735. # ~ trainPath = "../DIBCO/2010/DIBC02010_Test_images/"
  736. # ~ gtPath = "../DIBCO/2010/DIBC02010_Test_GT/"
  737. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  738. #2011 needs a special function to split the GTs with a wildcard or something.
  739. #2012 same
  740. # ~ trainPath = "../DIBCO/2013/OriginalImages/"
  741. # ~ gtPath = "../DIBCO/2013/GTimages/"
  742. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  743. # ~ trainPath = "../DIBCO/2014/original_images/"
  744. # ~ gtPath = "../DIBCO/2014/gt/"
  745. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  746. trainImageFileNames = []
  747. trainTruthFileNames = []
  748. for pair in trainTruthNamePairs:
  749. tImPath, gtImPath = pair
  750. trainNames, gtNames = \
  751. createTrainImageAndTrainTruthFileNames(tImPath, gtImPath)
  752. trainImageFileNames.extend(trainNames)
  753. trainTruthFileNames.extend(gtNames)
  754. #test image section
  755. # ~ testImagePath = os.path.normpath("../DIBCO/2016/DIPCO2016_dataset/")
  756. # ~ testTruthPath = os.path.normpath("../DIBCO/2016/DIPCO2016_Dataset_GT/")
  757. testImageFileNames, testTruthFileNames = \
  758. createTrainImageAndTrainTruthFileNames(testImagePath, testTruthPath)
  759. return trainImageFileNames, trainTruthFileNames, \
  760. testImageFileNames, testTruthFileNames
  761. def createTrainImageAndTrainTruthFileNames(trainImagePath, trainTruthPath):
  762. trainImageFileNames = createTrainImageFileNamesList(trainImagePath)
  763. trainTruthFileNames = createTrainTruthFileNamesList(trainImageFileNames)
  764. trainImageFileNames = appendBMP(trainImageFileNames)
  765. # ~ print(trainImageFileNames)
  766. trainTruthFileNames = appendBMP(trainTruthFileNames)
  767. # ~ print(trainTruthFileNames)
  768. trainImageFileNames = prependPath(trainImagePath, trainImageFileNames)
  769. trainTruthFileNames = prependPath(trainTruthPath, trainTruthFileNames)
  770. return trainImageFileNames, trainTruthFileNames
  771. def createTrainImageFileNamesList(trainImagePath):
  772. # ~ trainFileNames = next(os.walk(trainImagePath))[2] #this is a clever hack
  773. # ~ trainFileNames = [name.replace(".bmp", "") for name in trainFileNames]
  774. trainFileNames = os.listdir(trainImagePath)
  775. print(trainFileNames)
  776. # ~ print("pausing...")
  777. # ~ a = input()
  778. return [name.replace(".bmp", "") for name in trainFileNames]
  779. #This makes a list with the same order of the names but with _gt apended.
  780. def createTrainTruthFileNamesList(originalNames):
  781. return [name + "_gt" for name in originalNames]
  782. def appendBMP(inputList):
  783. return [name + ".bmp" for name in inputList]
  784. def prependPath(myPath, nameList):
  785. return [os.path.join(myPath, name) for name in nameList]
  786. #I'm copying the code for jaccard similarity and dice from this MIT licenced source.
  787. #https://github.com/masyagin1998/robin
  788. #jaccard is size intersection of the sets / size union of the sets
  789. #Also, I'm going to try the smoothing values suggested in robin and here:
  790. #https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
  791. #They also suggest abs()
  792. def jaccardIndex(truth, prediction):
  793. #they are tensors?! not images?!?!?!?!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
  794. # ~ truth = img_as_bool(np.asarray(truth)) #fail
  795. # ~ prediction = img_as_bool(np.asarray(prediction)) #fail
  796. smooth = GLOBAL_SMOOTH_JACCARD
  797. predictionFlat = backend.flatten(prediction)
  798. truthFlat = backend.flatten(truth)
  799. # ~ for i in range(len(predictionFlat)):
  800. # ~ if predictionFlat[i] >= 2:
  801. # ~ print("This was the bug. images have non- binary values...#############################################################################################################################################################")
  802. # ~ print("predictionFlat[" + str(i) + "]: " + str(predictionFlat[i]))
  803. # ~ intersectionImg = predictionFlat * truthFlat
  804. numberPixelsSame = backend.sum(truthFlat * predictionFlat)
  805. #I've found the function tensorflow.reduce_sum() which performs a sum by reduction
  806. #Is it better than backend.sum?? ##################################################################
  807. #the docs say it is equivalent except that numpy will change everything to int64
  808. return float((numberPixelsSame + smooth) / \
  809. ( \
  810. (backend.sum(predictionFlat) + backend.sum(truthFlat) - numberPixelsSame + smooth) \
  811. ))
  812. #loss function for use in training.
  813. def jaccardLoss(truth, prediction):
  814. return 1.0 - jaccardIndex(truth, prediction)
  815. #input must be binarized images consisting of values for pixels of either 1 or 0.
  816. def diceIndex(truth, prediction):
  817. smooth = GLOBAL_SMOOTH_DICE
  818. predictionFlat = backend.flatten(prediction)
  819. truthFlat = backend.flatten(truth)
  820. numberSamePixels = backend.sum(predictionFlat * truthFlat)
  821. return float((2 * numberSamePixels + smooth) \
  822. / (backend.sum(predictionFlat) + backend.sum(truthFlat) + smooth))
  823. #Loss function for use in training
  824. def diceLoss(truth, prediction):
  825. smooth = GLOBAL_SMOOTH_DICE
  826. return smooth - diceIndex(truth, prediction)
  827. #This creates a true positive mask from an inverted image (white pixels are text)
  828. def createMaskTruePositive(truth, prediction):
  829. pFlat = backend.flatten(prediction)
  830. pFlat = img_as_bool(pFlat)
  831. tFlat = backend.flatten(truth)
  832. tFlat = img_as_bool(tFlat)
  833. mask = pFlat * tFlat
  834. return np.reshape(mask, np.shape(prediction))
  835. #creaes true negative mask
  836. def createMaskTrueNegative(truth, prediction):
  837. pFlat = backend.flatten(prediction)
  838. pFlat = img_as_bool(pFlat)
  839. tFlat = backend.flatten(truth)
  840. tFlat = img_as_bool(tFlat)
  841. ##invert to make the 0 into ones
  842. pFlat = ~pFlat
  843. tFlat = ~tFlat
  844. ##then multiply
  845. mask = pFlat * tFlat
  846. return np.reshape(mask, np.shape(prediction))
  847. #Creates a mask for the False Positives
  848. def createMaskFalsePositive(truth, prediction):
  849. pFlat = backend.flatten(prediction)
  850. pFlat = img_as_bool(pFlat)
  851. tFlat = backend.flatten(truth)
  852. tFlat = img_as_bool(tFlat)
  853. #will I need these?
  854. # ~ falseArray = np.zeros(len(pFlat), dtype = bool)
  855. # ~ trueArray = np.ones(len(pFlat), dtype = bool)
  856. ##where is the prediction true 1, where the truth is 0 false?
  857. mask = np.where(pFlat > tFlat, True, False)
  858. # ~ mask = np.where(pFlat > tFlat, pFlat, ~pFlat)
  859. return np.reshape(mask, np.shape(prediction))
  860. #returns a mask of all the pixels that are not supposed to be false.
  861. def createMaskFalseNegative(truth, prediction):
  862. # ~ return createMaskFalsePositive(prediction, truth) #Just swapped the input!?? yes but bug happens. 3-1
  863. pFlat = backend.flatten(prediction)
  864. pFlat = img_as_bool(pFlat)
  865. tFlat = backend.flatten(truth)
  866. tFlat = img_as_bool(tFlat)
  867. mask = np.where(pFlat < tFlat, True, False)
  868. return np.reshape(mask, np.shape(prediction))
  869. #Color the prediction image with the pixels that are correct in red.
  870. def colorPredictionWithPredictionMask(predictionMask, originalPrediction, colorArray):
  871. prediction = img_as_bool(originalPrediction)
  872. prediction = np.where( predictionMask >= prediction, True, False ) #This makes the area to paint to white.
  873. prediction = img_as_float(prediction)
  874. predictionMask = np.squeeze(predictionMask, axis = 2)
  875. if IS_GLOBAL_PRINTING_ON:
  876. print("predictionMask shape: " + str(predictionMask.shape))
  877. rows, cols = predictionMask.shape
  878. colorMask = np.zeros((rows, cols, 3))
  879. colorMask[ predictionMask ] = colorArray
  880. predictionInColor = np.dstack((prediction, prediction, prediction))
  881. predictionColor_hsv = rgb2hsv(predictionInColor)
  882. colorMask_hsv = rgb2hsv(colorMask)
  883. alpha = 1.0
  884. predictionColor_hsv[..., 0] = colorMask_hsv[..., 0]
  885. predictionColor_hsv[..., 1] = colorMask_hsv[..., 1] * alpha
  886. outImg = hsv2rgb(predictionColor_hsv)
  887. return img_as_ubyte(outImg)
  888. def combinePredictionPicture(truePosMask, trueNegMask, falsePosMask, falseNegMask):
  889. redColor = [1, 0, 0]
  890. greenColor = [0, 1, 0]
  891. blueColor = [0, 0, 1]
  892. yellowColor = [1, 1, 0]
  893. truePosMask = np.squeeze(truePosMask, axis = 2)
  894. trueNegMask = np.squeeze(trueNegMask, axis = 2)
  895. falsePosMask = np.squeeze(falsePosMask, axis = 2)
  896. falseNegMask = np.squeeze(falseNegMask, axis = 2)
  897. ##make a numpy array of ONES, reshape to image size, then convert with imgtofloat
  898. # ~ print("truePosMask.shape: " + str(truePosMask.shape))
  899. rows, cols = truePosMask.shape
  900. # ~ prediction = np.ones( (rows, cols), dtype=bool )
  901. prediction = np.zeros( (rows, cols), dtype=bool )
  902. # ~ prediction = img_as_float(prediction)
  903. predictionRGB = np.dstack((prediction, prediction, prediction))
  904. # ~ predictionColor_hsv = rgb2hsv(predictionRGB)
  905. ## make the four color masks
  906. redColorMask = img_as_bool(falseNegMask)
  907. blueColorMask = img_as_bool(falsePosMask)
  908. greenColorMask = img_as_bool(truePosMask)
  909. yellowColorMask = img_as_bool(trueNegMask)
  910. predictionRGB[redColorMask, 0 ] = 1
  911. predictionRGB[greenColorMask, 1 ] = 1
  912. predictionRGB[blueColorMask, 2 ] = 1
  913. predictionRGB[yellowColorMask, 0 ] = 1
  914. predictionRGB[yellowColorMask, 1 ] = 1
  915. # ~ green_image = rgb_image.copy() # Make a copy
  916. # ~ green_image[:,:,0] = 0
  917. # ~ green_image[:,:,2] = 0
  918. # ~ alpha = 1.0
  919. # ~ predictionColor_hsv[..., 0] = redColorMask[..., 0]
  920. # ~ predictionColor_hsv[..., 1] = redColorMask[..., 1] * alpha
  921. # ~ predictionColor_hsv[..., 0] = blueColorMask[..., 0]
  922. # ~ predictionColor_hsv[..., 1] = blueColorMask[..., 1] * alpha
  923. # ~ predictionColor_hsv[..., 0] = greenColorMask[..., 0]
  924. # ~ predictionColor_hsv[..., 1] = greenColorMask[..., 1] * alpha
  925. # ~ predictionColor_hsv[..., 0] = yellowColorMask[..., 0]
  926. # ~ predictionColor_hsv[..., 1] = yellowColorMask[..., 1] * alpha
  927. # ~ outImg = hsv2rgb(predictionColor_hsv)
  928. # ~ return img_as_ubyte(outImg)
  929. return img_as_ubyte(predictionRGB)
  930. def makeThisColorMaskHsv(predictionMask, colorArray, rows, cols):
  931. thisColorMask = np.zeros((rows, cols, 3))
  932. thisColorMask[ predictionMask ] = colorArray
  933. return rgb2hsv(thisColorMask)
  934. def makeThisColorMaskRGB(predictionMask, colorArray, rows, cols):
  935. thisColorMask = np.zeros((rows, cols, 3))
  936. thisColorMask[ predictionMask ] = colorArray
  937. return thisColorMask
  938. if __name__ == '__main__':
  939. import sys
  940. sys.exit(main(sys.argv))