my-unet.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # my-unet.py
  5. #
  6. # Copyright 2022 Stephen Stengel <stephen@cwu.edu> MIT License
  7. #
  8. #OK to start out I'll use 2017 as test data, and 2016 as evaluation data.
  9. #Also I will resize each image to a square of size 128x128; I can add more complex manipulation later.
  10. #using these as reference for some parts:
  11. #https://github.com/dexterfichuk/U-Net-Keras-Implementation
  12. # There is currently a bug in the image manipulation part causing
  13. # occasional index overflow. I might just re-write this whole thing
  14. # in pytorch anyway...
  15. # The graphs of training accuracy show that the accuracy goes down A LOT
  16. # Also, sometimes it just gets stuck marking the whole image as
  17. # background.
  18. print("Running imports...")
  19. import os
  20. import shutil
  21. import numpy as np
  22. import random
  23. from PIL import Image
  24. from tqdm import tqdm
  25. from skimage.io import imread, imshow, imsave
  26. from skimage.transform import resize
  27. from matplotlib import pyplot as plt
  28. from skimage.util import img_as_bool, img_as_float, img_as_uint, img_as_ubyte
  29. from skimage.util import invert
  30. from skimage.color import rgb2gray, gray2rgb, rgb2hsv, hsv2rgb
  31. from sklearn.metrics import auc
  32. import time
  33. import tensorflow as tf
  34. from keras.layers import Conv2D, MaxPool2D, UpSampling2D, Input, Dropout, Lambda, MaxPooling2D, Conv2DTranspose, Concatenate, Softmax
  35. from tensorflow.keras.optimizers import Adam
  36. from keras import Model, callbacks
  37. from keras import backend
  38. # ~ from rezaunet import BCDU_net_D3
  39. ## autoinit stuff
  40. # ~ from autoinit import AutoInit
  41. np.random.seed(55555)
  42. random.seed(55555)
  43. NUM_SQUARES = 100 #Reduced number of square inputs for training. 100 seems to be min for ok results.
  44. HACK_SIZE = 64 #64 is reasonably good for prototyping.
  45. GLOBAL_HACK_height, GLOBAL_HACK_width = HACK_SIZE, HACK_SIZE
  46. IMAGE_CHANNELS = 3 #This might change later for different datasets. idk.
  47. GLOBAL_EPOCHS = 15
  48. GLOBAL_BATCH_SIZE = 4 #just needs to be big enough to fill memory
  49. #64hack, 5 epoch, 16batch nearly fills 8gb on laptop. Half of 16 on other laptop.
  50. #Making batch too high seems to cause problems. 32 caused a NaN error when trying to write the output images on laptop1.
  51. GLOBAL_INITIAL_FILTERS = 16
  52. GLOBAL_SMOOTH_JACCARD = 1
  53. GLOBAL_SMOOTH_DICE = 1
  54. IS_GLOBAL_PRINTING_ON = False
  55. # ~ IS_GLOBAL_PRINTING_ON = True
  56. GLOBAL_SQUARE_TEST_SAVE = True
  57. # ~ GLOBAL_SQUARE_TEST_SAVE = False
  58. GLOBAL_MAX_TEST_SQUARE_TO_SAVE = 66
  59. HELPFILE_PATH = os.path.normpath("helpfile")
  60. OUT_TEXT_PATH = os.path.normpath("accuracies-if-error-happens-lol")
  61. print("Done!")
  62. def main(args):
  63. startTime = time.time()
  64. print("Hi!")
  65. checkArgs(args)
  66. print("Creating folders to store results...")
  67. tmpFolder, trainingFolder, checkpointFolder, savedModelFolder, \
  68. predictionsFolder, wholePredictionsFolder, outTextPath \
  69. = createFolders()
  70. print("Done!")
  71. print("Creating copy of source code and conda environment...")
  72. copySourceEnv(tmpFolder)
  73. print("Done!")
  74. trainImagePath = os.path.normpath("../DIBCO/2017/Dataset/")
  75. trainTruthPath = os.path.normpath("../DIBCO/2017/GT/")
  76. testImagePath = os.path.normpath("../DIBCO/2016/DIPCO2016_dataset/")
  77. testTruthPath = os.path.normpath("../DIBCO/2016/DIPCO2016_Dataset_GT/")
  78. # ~ testImagePath = os.path.normpath("../DIBCO/2017/Dataset/")
  79. # ~ testTruthPath = os.path.normpath("../DIBCO/2017/GT/")
  80. # ~ trainImagePath = os.path.normpath("../DIBCO/2016/DIPCO2016_dataset/")
  81. # ~ trainTruthPath = os.path.normpath("../DIBCO/2016/DIPCO2016_Dataset_GT/")
  82. print("Creating train and test sets...")
  83. trainImages, trainTruth, testImages, \
  84. testTruths, wholeOriginals, wholeTruths = \
  85. createTrainAndTestSets(trainImagePath, trainTruthPath, testImagePath, testTruthPath)
  86. print("Done!")
  87. #Images not currently called from disk. Commenting for speed testing.
  88. # ~ saveExperimentImages(trainImages, trainTruth, testImages, testTruths, trainingFolder)
  89. if IS_GLOBAL_PRINTING_ON:
  90. mainTestPrintOne(wholeOriginals, wholeTruths, trainImages, trainTruth, testImages, testTruths)
  91. trainImages, trainTruth, testImages, testTruths \
  92. = reduceInputForTesting(trainImages, trainTruth, testImages, testTruths, NUM_SQUARES)
  93. theModel, theHistory = trainUnet(trainImages, trainTruth, checkpointFolder)
  94. print("Saving model...")
  95. theModel.save(os.path.join(savedModelFolder, "saved-model.h5"))
  96. print("Done!")
  97. performEvaluation(theHistory, tmpFolder, testImages, testTruths, theModel)
  98. print("shape of testImages right before predict: " + str(np.shape(testImages)))
  99. modelOut = theModel.predict(testImages)
  100. binarizedOut = ((modelOut > 0.5).astype(np.uint8) * 255).astype(np.uint8) #######test this thing more
  101. if GLOBAL_SQUARE_TEST_SAVE:
  102. saveTestSquares(
  103. GLOBAL_MAX_TEST_SQUARE_TO_SAVE, modelOut, \
  104. binarizedOut, testImages, testTruths, predictionsFolder)
  105. else:
  106. print("Not saving test square pictures this time.")
  107. # ~ print("Calculating jaccard and dice for the test squares...")
  108. # ~ calculateJaccardDiceTestSquares(testTruths, outTextPath, binarizedOut)
  109. print("Predicting output of whole images...")
  110. #currently also does the image processing and saving.
  111. predictionsList = predictAllWholeImages(wholeOriginals, wholeTruths, theModel, HACK_SIZE)
  112. print("Creating confusion masks...")
  113. confusionImages, tpList, fpList, tnList, fnList \
  114. = createConfusionImageList(predictionsList, wholeOriginals, wholeTruths)
  115. print("Saving whole image predictions and confusion images...")
  116. saveAllWholeAndConfusion(predictionsList, wholeOriginals, wholeTruths, confusionImages, wholePredictionsFolder)
  117. print("Creating ROC graph of the whole images test...")
  118. createROC(tpList, fpList, tnList, fnList, tmpFolder)
  119. # ~ print("Evaluating jaccard and dice scores...")
  120. # ~ evaluatePredictionJaccardDice(predictionsList, wholeTruths, outTextPath)
  121. print("Done!")
  122. saveRuntimeToFile(startTime, tmpFolder)
  123. return 0
  124. def saveRuntimeToFile(startTime, tmpFolder):
  125. elapsedTime = time.time() - startTime
  126. with open(os.path.join(tmpFolder, "runtime.txt"), "w") as timeFile:
  127. timeFile.write(str(round(elapsedTime, 4)) + " seconds")
  128. def createFolders():
  129. sq = str(NUM_SQUARES)
  130. hk = str(HACK_SIZE)
  131. ep = str(GLOBAL_EPOCHS)
  132. ba = str(GLOBAL_BATCH_SIZE)
  133. tmpFolder = os.path.normpath("./tmp" + sq + "-" + hk + "-" + ep + "-" + ba + "/")
  134. trainingFolder = os.path.join(tmpFolder, "trainingstuff")
  135. checkpointFolder = os.path.join(tmpFolder, "checkpoint")
  136. savedModelFolder = os.path.join(tmpFolder, "saved-model")
  137. predictionsFolder = os.path.join(tmpFolder, "predictions")
  138. wholePredictionsFolder = os.path.join(tmpFolder, "whole-predictions")
  139. foldersToCreate = [ \
  140. tmpFolder, trainingFolder, \
  141. checkpointFolder, savedModelFolder, \
  142. predictionsFolder, wholePredictionsFolder]
  143. for folder in foldersToCreate:
  144. if not os.path.isdir(folder):
  145. os.makedirs(folder)
  146. #Lol spaghetti
  147. global OUT_TEXT_PATH
  148. OUT_TEXT_PATH = os.path.join(tmpFolder, "accuracy-jaccard-dice.txt")
  149. return tmpFolder, trainingFolder, checkpointFolder, savedModelFolder, predictionsFolder, wholePredictionsFolder, OUT_TEXT_PATH
  150. def copySourceEnv(tmpFolder):
  151. try:
  152. shutil.copy("my-unet.py", tmpFolder)
  153. shutil.copy("working-conda-config.yml", tmpFolder)
  154. except:
  155. print("copy error! Source code and conda environment file not copied!")
  156. #This would be preferred in the final product.
  157. # ~ print("Creating current copy of environment...")
  158. # ~ os.system("conda env export > " + tmpFolder + "working-conda-config-current.yml")
  159. # ~ print("Done!")
  160. def reduceInputForTesting(trainImages, trainTruth, testImages, testTruths, sizeOfSet ):
  161. #This block reduces the input for testing.
  162. highIndex = len(trainImages)
  163. if sizeOfSet > highIndex + 1: #Just in case user enters more squares than exist.
  164. sizeOfSet = highIndex + 1
  165. print("! Limiting size of squares for training to actual number of squares !")
  166. print("Number of squares to be used for training: " + str(sizeOfSet))
  167. updateGlobalNumSquares(sizeOfSet)
  168. rng = np.random.default_rng(12345)
  169. pickIndexes = rng.integers(low = 0, high = highIndex, size = sizeOfSet)
  170. trainImages = trainImages[pickIndexes]
  171. trainTruth = trainTruth[pickIndexes]
  172. sizeOfTestSet = sizeOfSet
  173. if sizeOfTestSet > len(testImages):
  174. sizeOfTestSet = len(testImages)
  175. rng = np.random.default_rng(23456)
  176. print("sizeOfTestSet: " + str(sizeOfTestSet))
  177. pickIndexes = rng.integers(low = 0, high = len(testImages), size = sizeOfTestSet)
  178. testImages = testImages[pickIndexes]
  179. testTruths = testTruths[pickIndexes]
  180. print("There are " + str(len(trainImages)) + " training images.")
  181. print("There are " + str(len(testImages)) + " testing images.")
  182. return trainImages, trainTruth, testImages, testTruths
  183. def mainTestPrintOne(wholeOriginals, wholeTruths, trainImages, trainTruth, testImages, testTruths):
  184. print("shape of wholeOriginals: " + str(np.shape(wholeOriginals)))
  185. print("shape of wholeTruths: " + str(np.shape(wholeTruths)))
  186. print("shape of trainImages: " + str(np.shape(trainImages)))
  187. print("shape of trainTruth: " + str(np.shape(trainTruth)))
  188. print("shape of testImages: " + str(np.shape(testImages)))
  189. print("shape of testTruths: " + str(np.shape(testTruths)))
  190. print("Showing Training stuff...")
  191. randomBoy = random.randint(0, len(trainImages) - 1)
  192. print("image " + str(randomBoy) + "...")
  193. imshow(trainImages[randomBoy] / 255)
  194. # ~ plt.show()
  195. print("truth " + str(randomBoy) + "...")
  196. imshow(np.squeeze(trainTruth[randomBoy]))
  197. # ~ plt.show()
  198. print("Showing Testing stuff...")
  199. randomBoy = random.randint(0, len(testImages) - 1)
  200. print("image " + str(randomBoy) + "...")
  201. imshow(testImages[randomBoy] / 255)
  202. # ~ plt.show()
  203. print("truth " + str(randomBoy) + "...")
  204. imshow(np.squeeze(testTruths[randomBoy]))
  205. # ~ plt.show()
  206. #save copies of some of the squares used in learning.
  207. def saveTestSquares(numToSave, modelOut, binarizedOut, testImages, testTruths, predictionsFolder):
  208. print("Saving random sample of figures...")
  209. rng2 = np.random.default_rng(54322)
  210. if len(modelOut) < numToSave:
  211. numToSave = len(modelOut)
  212. saveIndexes = rng2.integers(low = 0, high = len(modelOut), size = numToSave)
  213. for i in tqdm(saveIndexes):
  214. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]premask.png"), img_as_ubyte(modelOut[i]))
  215. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]predict.png"), img_as_ubyte(binarizedOut[i]))
  216. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]testimg.png"), img_as_ubyte(testImages[i]))
  217. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]truthim.png"), img_as_ubyte(testTruths[i]))
  218. print("Done!")
  219. def calculateJaccardDiceTestSquares(testTruths, outTextPath, binarizedOut):
  220. testTruthsUInt = testTruths.astype(np.uint8)
  221. #Testing the jaccard and dice functions
  222. with open(os.path.join(outTextPath), "w") as outFile:
  223. for i in tqdm(range(len(binarizedOut))):
  224. jac = jaccardIndex(testTruthsUInt[i], binarizedOut[i])
  225. dice = diceIndex(testTruthsUInt[i], binarizedOut[i])
  226. thisString = str(i) + "\tjaccard: " + str(jac) + "\tdice: " + str(dice) + "\n"
  227. outFile.write(thisString)
  228. print("Done!")
  229. #currently also does the image processing and saving.
  230. def predictAllWholeImages(wholeOriginals, wholeTruths, theModel, squareSize):
  231. if IS_GLOBAL_PRINTING_ON:
  232. print("shape of wholeOriginals: " + str(np.shape(wholeOriginals)))
  233. for i in range(len(wholeOriginals)):
  234. print(str(np.shape(wholeOriginals[i])))
  235. print("##########################################################")
  236. print("shape of wholeTruths: " + str(np.shape(wholeTruths)))
  237. predictionsList = []
  238. for i in tqdm(range(len(wholeOriginals))):
  239. # ~ wholeOriginals, wholeTruths
  240. predictedImage = predictWholeImage(wholeOriginals[i], theModel, squareSize)
  241. if IS_GLOBAL_PRINTING_ON:
  242. print("Shape of predicted image " + str(i) + ": " + str(np.shape(predictedImage)))
  243. # ~ predictedImage = ((predictedImage > 0.5).astype(np.uint8) * 255).astype(np.uint8) ## jank thing again
  244. # ~ print("Shape of predicted image " + str(i) + " after mask: " + str(np.shape(predictedImage)))
  245. predictionsList.append(predictedImage)
  246. return predictionsList
  247. #This could be split up a bit as well.
  248. #also outputs the binary masks in lists !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! change later
  249. def createConfusionImageList(predictionsList, wholeOriginals, wholeTruths):
  250. confusionList = []
  251. tpList = []
  252. fpList = []
  253. tnList = []
  254. fnList = []
  255. for i in tqdm(range(len(predictionsList))):
  256. predictedImage = predictionsList[i]
  257. truePosMask = createMaskTruePositive(wholeTruths[i], predictedImage)
  258. tpList.append(truePosMask)
  259. trueNegMask = createMaskTrueNegative(wholeTruths[i], predictedImage)
  260. tnList.append(trueNegMask)
  261. falsePosMask = createMaskFalsePositive(wholeTruths[i], predictedImage)
  262. fpList.append(falsePosMask)
  263. falseNegMask = createMaskFalseNegative(wholeTruths[i], predictedImage)
  264. fnList.append(falseNegMask)
  265. redColor = [1, 0, 0]
  266. greenColor = [0, 1, 0]
  267. blueColor = [0, 0, 1]
  268. yellowColor = [1, 1, 0]
  269. truePosColor = colorPredictionWithPredictionMask(truePosMask, predictedImage, greenColor)
  270. trueNegColor = colorPredictionWithPredictionMask(trueNegMask, predictedImage, yellowColor)
  271. falsePosColor = colorPredictionWithPredictionMask(falsePosMask, predictedImage, blueColor)
  272. falseNegColor = colorPredictionWithPredictionMask(falseNegMask, predictedImage, redColor )
  273. confusion = combinePredictionPicture(truePosMask, trueNegMask, falsePosMask, falseNegMask)
  274. confusionList.append(confusion)
  275. return confusionList, tpList, fpList, tnList, fnList
  276. def saveAllWholeAndConfusion(predictionsList, wholeOriginals, wholeTruths, confusions, wholePredictionsFolder):
  277. for i in tqdm(range(len(predictionsList))):
  278. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]predicted.png"), img_as_ubyte(predictionsList[i]))
  279. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]truth.png"), img_as_ubyte(wholeTruths[i]))
  280. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]original.png"), img_as_ubyte(wholeOriginals[i]))
  281. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]confusion.png"), img_as_ubyte(confusions[i]))
  282. def createROC(tpList, fpList, tnList, fnList, tmpFolder):
  283. #convert each list of masks into a list of percentages.
  284. tpScores = getPercentTrueFromMaskList(tpList)
  285. fpScores = getPercentTrueFromMaskList(fpList)
  286. #not currently needed
  287. # ~ tnScores = getPercentTrueFromMaskList(tnList[i])
  288. # ~ fnScores = getPercentTrueFromMaskList(fnList[i])
  289. plotROCandSave(fpScores, tpScores, tmpFolder)
  290. def plotROCandSave(fpList, tpList, tmpFolder):
  291. # ~ [x for (y,x) in sorted(zip(Y,X), key=lambda pair: pair[0])]
  292. fpArray = np.asarray(fpList)
  293. tpArray = np.asarray(tpList)
  294. indexesSorted = fpArray.argsort()
  295. fpArray = fpArray[indexesSorted]
  296. tpArray = tpArray[indexesSorted]
  297. print("fpArray: " + str(fpArray))
  298. print("tpArray: " + str(tpArray))
  299. # ~ print("fpArray: " + str(fpArray))
  300. # ~ print("tpArray: " + str(tpArray))
  301. roc_auc = auc(fpArray, tpArray)
  302. print("contents of roc_auc: " + str(roc_auc))
  303. # ~ for thing in roc_auc:
  304. # ~ print(thing)
  305. # ~ print("##")
  306. linewidth = 2
  307. plt.figure()
  308. plt.plot(
  309. fpArray,
  310. tpArray,
  311. color="darkorange",
  312. linewidth = linewidth,
  313. label="ROC curve (area = %0.2f%%)" % roc_auc,
  314. # ~ label = "ROC curve"
  315. )
  316. plt.plot([0, 1], [0, 1], color="navy", linewidth=linewidth, linestyle="--")
  317. plt.xlim([0.0, 1.0])
  318. # ~ plt.ylim([0.0, 1.0])
  319. plt.ylim([0.0, 1.05])
  320. plt.xlabel("False Positive Rate")
  321. plt.ylabel("True Positive Rate")
  322. plt.title("ROC curve !")
  323. plt.legend(loc="lower right")
  324. plt.savefig(os.path.join(tmpFolder, "roc-curve.png"))
  325. # ~ plt.show()
  326. def getPercentTrueFromMaskList(inList):
  327. scoreList = []
  328. for i in tqdm(range(len(inList))):
  329. scoreList.append( getPercentTrueFromMask(inList[i]))
  330. return scoreList
# NOTE: a plain percentage is not actually what is needed here. The ROC needs the TPR and FPR, which come from the tp, fp, tn and fn counts.
  332. def getPercentTrueFromMask(inMask):
  333. # ~ mFlat = backend.flatten( img_as_uint(inMask) )
  334. # ~ numTrue = backend.sum(mFlat)
  335. mFlat = np.asarray(inMask).flatten()
  336. numTrue = np.sum(mFlat)
  337. # ~ mFlat = np.asarray(mFlat)
  338. totalNum = mFlat.size
  339. return numTrue / totalNum
  340. def evaluatePredictionJaccardDice(predictionsList, wholeTruths, outTextPath):
  341. print("Calculating jaccard and dice...")
  342. with open(outTextPath, "w") as outFile:
  343. for i in tqdm(range(len(predictionsList))):
  344. thisTruth = np.asarray(wholeTruths[i])
  345. thisTruth = thisTruth.astype(np.uint8)
  346. jac = jaccardIndex(thisTruth, predictionsList[i])
  347. dice = diceIndex(thisTruth, predictionsList[i])
  348. thisString = str(i) + "\tjaccard: " + str(jac) + "\tdice: " + str(dice) + "\n"
  349. outFile.write(thisString)
  350. print("Done!")
  351. def checkArgs(args):
  352. if len(args) >= 1:
  353. for a in args:
  354. if str(a) == "help" \
  355. or str(a).lower() == "-help" \
  356. or str(a).lower() == "--help" \
  357. or str(a).lower() == "--h":
  358. with open(HELPFILE_PATH, "r") as helpfile:
  359. for line in helpfile:
  360. print(line, end = "")
  361. sys.exit(0)
  362. if len(args) < 5:
  363. print("bad input");
  364. sys.exit(-1)
  365. else:
  366. global NUM_SQUARES
  367. NUM_SQUARES = int(sys.argv[1])
  368. global HACK_SIZE
  369. HACK_SIZE = int(sys.argv[2])
  370. global GLOBAL_HACK_height
  371. global GLOBAL_HACK_width
  372. GLOBAL_HACK_height, GLOBAL_HACK_width = HACK_SIZE, HACK_SIZE
  373. global GLOBAL_EPOCHS
  374. GLOBAL_EPOCHS = int(sys.argv[3])
  375. global GLOBAL_BATCH_SIZE
  376. GLOBAL_BATCH_SIZE = int(sys.argv[4])
  377. if len(args) >= 6:
  378. if str(sys.argv[5]) == "print":
  379. global IS_GLOBAL_PRINTING_ON
  380. IS_GLOBAL_PRINTING_ON = True
  381. print("Printing of debugging messages is enabled.")
  382. if NUM_SQUARES < 100:
  383. print("100 squares is really the bare minimum to get any meaningful result.")
  384. sys.exit(-1)
  385. if HACK_SIZE not in [64, 128, 256, 512]:
  386. print("Square size must be 64, 128, 256, or 512." \
  387. + " 128 is recommended for training. 64 for testing")
  388. sys.exit(-2)
  389. if GLOBAL_EPOCHS < 1:
  390. print("Yeah no.")
  391. print("You need at least one epoch, silly!")
  392. sys.exit(-3)
  393. if GLOBAL_BATCH_SIZE < 1 or GLOBAL_BATCH_SIZE > NUM_SQUARES:
  394. print("Global batch size should be between 1 and the number" \
  395. + " of training squares. Pick a better number.")
  396. sys.exit(-5)
  397. def updateGlobalNumSquares(newNumSquares):
  398. global NUM_SQUARES
  399. NUM_SQUARES = newNumSquares
  400. def performEvaluation(history, tmpFolder, testImages, testTruths, theModel):
  401. print("Performing evaluation...###############################################################")
  402. print("Calculating scores...")
  403. print("len testImages: " + str(len(testImages)))
  404. scores = theModel.evaluate(testImages, testTruths)
  405. print("Done!")
  406. print("Scores object: " + str(scores))
  407. print(str(history.history))
  408. print("%s: %.2f%%" % (theModel.metrics_names[1], scores[1]*100))
  409. print("history...")
  410. print(history)
  411. print("history.history...")
  412. print(history.history)
  413. accuracy = history.history["acc"]
  414. jaccInd = history.history["jaccardIndex"]
  415. diceInd = history.history["diceIndex"]
  416. val_accuracy = history.history["val_acc"]
  417. val_jaccInd = history.history["val_jaccardIndex"]
  418. val_diceInd = history.history["val_diceIndex"]
  419. loss = history.history["loss"]
  420. val_loss = history.history["val_loss"]
  421. epochs = range(1, len(accuracy) + 1)
  422. #make a loop for this bit?
  423. saveGraphNumbers(accuracy, epochs, "acc", tmpFolder)
  424. saveGraphNumbers(jaccInd, epochs, "jaccardIndex", tmpFolder)
  425. saveGraphNumbers(diceInd, epochs, "diceIndex", tmpFolder)
  426. saveGraphNumbers(val_accuracy, epochs, "val_acc", tmpFolder)
  427. saveGraphNumbers(val_jaccInd, epochs, "val_jaccardIndex", tmpFolder)
  428. saveGraphNumbers(val_diceInd, epochs, "val_diceIndex", tmpFolder)
  429. saveGraphNumbers(loss, epochs, "loss", tmpFolder)
  430. saveGraphNumbers(val_loss, epochs, "val_loss", tmpFolder)
  431. plt.plot(epochs, accuracy, "^", label="Training accuracy")
  432. plt.plot(epochs, val_accuracy, "2", label="Validation accuracy")
  433. plt.plot(epochs, jaccInd, "*", label="Jaccard Index")
  434. plt.plot(epochs, val_jaccInd, "p", label="Validation Jaccard Index")
  435. plt.plot(epochs, diceInd, "s", label="Dice Index")
  436. plt.plot(epochs, val_diceInd, "D", label="Validation Dice Index")
  437. plt.title("Training and validation accuracy")
  438. plt.legend()
  439. plt.savefig(os.path.join(tmpFolder, "trainvalacc.png"))
  440. plt.clf()
  441. plt.plot(epochs, loss, "^", label="Training loss")
  442. plt.plot(epochs, val_loss, "2", label="Validation loss")
  443. plt.title("Training and validation loss")
  444. plt.legend()
  445. plt.savefig(os.path.join(tmpFolder, "trainvalloss.png"))
  446. plt.clf()
  447. def saveGraphNumbers(array, epochs, nameString, tmpFolder):
  448. outFolder = os.path.join(tmpFolder, "graph-numbers")
  449. if not os.path.isdir(outFolder):
  450. os.makedirs(outFolder)
  451. with open(os.path.join(outFolder, nameString + ".txt"), "w") as statFile:
  452. statFile.write(nameString + " over " + str(epochs) + "epochs" + "\n")
  453. for thing in array:
  454. statFile.write(str(thing) + "\n")
  455. def trainUnet(trainImages, trainTruth, checkpointFolder):
  456. #print("shape of trainImages: " + str(trainImages.shape))
  457. standardUnetLol = createStandardUnet()
  458. # ~ standardUnetLol = BCDU_net_D3( (GLOBAL_HACK_height, GLOBAL_HACK_width, IMAGE_CHANNELS) )
  459. standardUnetLol.summary()
  460. checkpointer = callbacks.ModelCheckpoint(
  461. filepath = checkpointFolder,
  462. monitor = "val_jaccardIndex",
  463. save_best_only = True,
  464. mode = "max")
  465. earlyStopper = callbacks.EarlyStopping( \
  466. monitor="val_jaccardIndex", \
  467. patience = 5, \
  468. mode = "max", \
  469. min_delta = 0.01, \
  470. #restore best weights restores the weights at the epoch
  471. #with the best results at the end of training.
  472. restore_best_weights = True)
  473. callbacks_list = [earlyStopper, checkpointer]
  474. myHistory = standardUnetLol.fit(
  475. x = trainImages,
  476. y = trainTruth,
  477. epochs = GLOBAL_EPOCHS,
  478. batch_size = GLOBAL_BATCH_SIZE,
  479. callbacks = callbacks_list,
  480. validation_split = 0.33333)
  481. return standardUnetLol, myHistory
  482. def createStandardUnet():
  483. input_size=(GLOBAL_HACK_height, GLOBAL_HACK_width, IMAGE_CHANNELS)
  484. inputs = Input(input_size)
  485. conv5, conv4, conv3, conv2, conv1 = encode(inputs)
  486. output = decode(conv5, conv4, conv3, conv2, conv1)
  487. model = Model(inputs, output)
  488. # ~ autoinit test. Uncomment to add the autoinit thingy
  489. # ~ model = AutoInit().initialize_model(model)
  490. # ~ model.compile(optimizer = Adam(learning_rate=1e-4), loss='categorical_crossentropy', metrics=["acc"])
  491. model.compile(
  492. # ~ optimizer = "adam",
  493. optimizer = Adam(),
  494. loss = "binary_crossentropy",
  495. # ~ loss = jaccardLoss,
  496. metrics = ["acc", jaccardIndex, diceIndex])
  497. return model
  498. #dropout increase in middle to reduce runtime in addition to dropping out stuff.
  499. def encode(inputs):
  500. sfilter = GLOBAL_INITIAL_FILTERS
  501. conv1 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(inputs)
  502. conv1 = Dropout(0.1)(conv1)
  503. conv1 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(conv1)
  504. pool1 = MaxPooling2D((2, 2))(conv1)
  505. conv2 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(pool1)
  506. conv2 = Dropout(0.1)(conv2)
  507. conv2 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(conv2)
  508. pool2 = MaxPooling2D((2, 2))(conv2)
  509. conv3 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(pool2)
  510. conv3 = Dropout(0.2)(conv3)
  511. conv3 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(conv3)
  512. pool3 = MaxPooling2D((2, 2))(conv3)
  513. conv4 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(pool3)
  514. conv4 = Dropout(0.2)(conv4)
  515. conv4 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(conv4)
  516. pool4 = MaxPooling2D((2, 2))(conv4)
  517. conv5 = Conv2D(sfilter * 16, (3, 3), activation = 'relu', padding = "same")(pool4)
  518. conv5 = Dropout(0.3)(conv5)
  519. conv5 = Conv2D(sfilter * 16, (3, 3), activation = 'relu', padding = "same")(conv5)
  520. return conv5, conv4, conv3, conv2, conv1
  521. def decode(conv5, conv4, conv3, conv2, conv1):
  522. sfilter = GLOBAL_INITIAL_FILTERS
  523. up6 = Conv2DTranspose(sfilter * 8, (2, 2), strides = (2, 2), padding = "same")(conv5)
  524. concat6 = Concatenate()([conv4,up6])
  525. conv6 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(concat6)
  526. conv6 = Dropout(0.2)(conv6)
  527. conv6 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(conv6)
  528. up7 = Conv2DTranspose(sfilter * 4, (2, 2), strides = (2, 2), padding = "same")(conv6)
  529. concat7 = Concatenate()([conv3,up7])
  530. conv7 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(concat7)
  531. conv7 = Dropout(0.2)(conv7)
  532. conv7 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(conv7)
  533. up8 = Conv2DTranspose(sfilter * 2, (2, 2), strides = (2, 2), padding = "same")(conv7)
  534. concat8 = Concatenate()([conv2,up8])
  535. conv8 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(concat8)
  536. conv8 = Dropout(0.1)(conv8)
  537. conv8 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(conv8)
  538. up9 = Conv2DTranspose(sfilter, (2, 2), strides = (2, 2), padding = "same")(conv8)
  539. concat9 = Concatenate()([conv1,up9])
  540. conv9 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(concat9)
  541. conv9 = Dropout(0.1)(conv9)
  542. conv9 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(conv9)
  543. conv10 = Conv2D(1, (1, 1), padding = "same", activation = "sigmoid")(conv9)
  544. return conv10
  545. def saveExperimentImages(trainImages, trainTruth, testImages, testTruths, tmpFolder):
  546. if not os.path.exists(tmpFolder):
  547. print("Making a tmp folder...")
  548. os.system("mkdir tmp")
  549. print("Done!")
  550. np.save(tmpFolder + "train-images-object", trainImages)
  551. np.save(tmpFolder + "train-truth-object", trainTruth)
  552. np.save(tmpFolder + "test-images-object", testImages)
  553. np.save(tmpFolder + "test-truth-object", testTruths)
  554. #gets images from file, manipulates, returns
  555. #currently hardcoded to use 2017 as training data, and 2016 as testing
  556. #data because their names are regular!
  557. def createTrainAndTestSets(trainImagePath, trainTruthPath, testImagePath, testTruthPath):
  558. trainImageFileNames, trainTruthFileNames, \
  559. testImageFileNames, testTruthFileNames = getFileNames(trainImagePath, trainTruthPath, testImagePath, testTruthPath)
  560. trainImages, trainTruth, _, _ = getImageAndTruth(trainImageFileNames, trainTruthFileNames)
  561. trainTruth = convertImagesToGrayscale(trainTruth)
  562. testImage, testTruth, wholeOriginals, wholeTruths = getImageAndTruth(testImageFileNames, testTruthFileNames)
  563. testTruth = convertImagesToGrayscale(testTruth)
  564. wholeTruths = convertImagesToGrayscaleList(wholeTruths)
  565. #invert the imported images. Tensorflow counts white as truth
  566. #and black as false. I had been doing the inverse previously.
  567. trainImages = invertImagesInArray(trainImages)
  568. trainTruth = invertImagesInArray(trainTruth)
  569. testImage = invertImagesInArray(testImage)
  570. testTruth = invertImagesInArray(testTruth)
  571. wholeOriginals = invertImagesInArray(wholeOriginals)
  572. wholeTruths = invertImagesInArray(wholeTruths)
  573. return trainImages, trainTruth, testImage, testTruth, wholeOriginals, wholeTruths
  574. #Inverts all the images in an array. returns an array.
  575. def invertImagesInArray(imgArray):
  576. for i in range(len(imgArray)):
  577. imgArray[i] = invert(imgArray[i])
  578. return imgArray
  579. #This function gets the source image, cuts it into smaller squares, then
  580. #adds each square to an array for output. The original image squares
  581. #will correspond to the base truth squares.
  582. #Try using a method from here to avoid using lists on the arrays:
  583. #https://stackoverflow.com/questions/50226821/how-to-extend-numpy-arrray
  584. #Also returns a copy of the original uncut images as lists.
  585. def getImageAndTruth(originalFilenames, truthFilenames):
  586. outOriginals, outTruths = [], []
  587. wholeOriginals = []
  588. wholeTruths = []
  589. print("Importing " + originalFilenames[0] + " and friends...")
  590. for i in tqdm(range(len(originalFilenames))):
  591. # ~ print("\rImporting " + originalFilenames[i] + "...", end = "")
  592. myOriginal = imread(originalFilenames[i])[:, :, :3] #this is pretty arcane. research later
  593. myTruth = imread(truthFilenames[i])[:, :, :3] #this is pretty arcane. research later
  594. #save original images as list for returning to main
  595. thisOriginal = myOriginal ##Test before removing these temp vals.
  596. thisTruth = myTruth
  597. wholeOriginals.append(np.asarray(thisOriginal))
  598. wholeTruths.append(np.asarray(thisTruth))
  599. #Now make the cuts and save the results to a list. Then later convert list to array.
  600. originalCuts = cutImageIntoSmallSquares(myOriginal)
  601. truthCuts = cutImageIntoSmallSquares(myTruth)
  602. #for loop to add cuts to out lists, or I think I remember a one liner to do it?
  603. #yes-- list has the .extend() function. it adds the elements of a list to another list.
  604. outOriginals.extend(originalCuts)
  605. outTruths.extend(truthCuts)
  606. #can move to return line later maybe.
  607. outOriginals, outTruths = np.asarray(outOriginals), np.asarray(outTruths)
  608. return outOriginals, outTruths, wholeOriginals, wholeTruths
  609. #Cut an image into smaller squares, returns them as a list.
  610. #inspiration from:
  611. #https://stackoverflow.com/questions/5953373/how-to-split-image-into-multiple-pieces-in-python#7051075
  612. #Change to using numpy methods later for much speed-up?:
  613. #https://towardsdatascience.com/efficiently-splitting-an-image-into-tiles-in-python-using-numpy-d1bf0dd7b6f7?gi=2faa21fa5964
  614. #The input is in scikit-image format. It is converted to pillow to crop
  615. #more easily and for saving???. Then converted back for the output list.
  616. #Whitespace is appended to the right and bottom of the image so that the crop will include everything.
  617. def cutImageIntoSmallSquares(skImage):
  618. skOutList = []
  619. myImage = Image.fromarray(skImage)
  620. imageWidth, imageHeight = myImage.size
  621. tmpW = ((imageWidth // HACK_SIZE) + 1) * HACK_SIZE
  622. tmpH = ((imageHeight // HACK_SIZE) + 1) * HACK_SIZE
  623. #Make this next line (0,0,0) once you switch the words to white and background to black.........##############################################################################
  624. tmpImg = Image.new(myImage.mode, (tmpW, tmpH), (255, 255, 255))
  625. wHehe, hHehe = myImage.size
  626. heheHack = (0, 0, wHehe, hHehe)
  627. # ~ tmpImg.paste(myImage, myImage.getbbox())
  628. if IS_GLOBAL_PRINTING_ON:
  629. print("tmpImg.mode: " + str(tmpImg.mode))
  630. print("tmpImg.getbbox(): " + str(tmpImg.getbbox()))
  631. print("tmpImg.size: " + str(tmpImg.size))
  632. print("myImage.mode: " + str(myImage.mode))
  633. print("myImage.getbbox(): " + str(myImage.getbbox()))
  634. print("myImage width, height: " + "(" + str(imageWidth) + "," + str(imageHeight) + ")")
  635. print("myImage.size: " + str(myImage.size))
  636. print("heheHack: " + str(heheHack))
  637. tmpImg.paste(myImage, heheHack)
  638. myImage = tmpImg
  639. # ~ tmp2 = np.asarray(myImage)
  640. # ~ imshow(tmp2)
  641. # ~ plt.show()
  642. for upper in range(0, imageHeight, HACK_SIZE):
  643. lower = upper + HACK_SIZE
  644. for left in range(0, imageWidth, HACK_SIZE):
  645. right = left + HACK_SIZE
  646. cropBounds = (left, upper, right, lower)
  647. cropped = myImage.crop(cropBounds)
  648. cropped = np.asarray(cropped)
  649. skOutList.append(cropped)
  650. # ~ imshow(cropped / 255)
  651. # ~ plt.show()
  652. return skOutList
  653. #This function cuts a large input image into little squares, uses the
  654. #trained model to predict the binarization of each, then stitches each
  655. #image back into a whole for output.
  656. def predictWholeImage(inputImage, theModel, squareSize):
  657. if IS_GLOBAL_PRINTING_ON:
  658. print("squareSize: " + str(squareSize))
  659. ##get dimensions of the image
  660. height, width, _ = inputImage.shape
  661. ##get the number of squares per row of the image
  662. squaresWide = (width // squareSize) + 1
  663. widthPlusRightBuffer = squaresWide * squareSize
  664. squaresHigh = (height // squareSize) + 1
  665. heightPlusBottomBumper = squaresHigh * squareSize
  666. #Dice the image into bits
  667. if IS_GLOBAL_PRINTING_ON:
  668. print("shape of input Image right before dicing: " + str(np.shape(inputImage)))
  669. # ~ print("input Image right before dicing as string: " + str(inputImage))
  670. dicedImage = cutImageIntoSmallSquares(inputImage)
  671. # ~ print("shape of dicedImage right before hacking: " + str(np.shape(dicedImage)))
  672. # ~ #put output into list with extend then np.asaray the whole list to match elswhere.
  673. # ~ tmpList = []
  674. # ~ for i in range(len(dicedImage)):
  675. # ~ tmpList.extend(dicedImage[i])
  676. # ~ dicedImage = np.asarray(tmpList)
  677. ##Predict the outputs of each square
  678. dicedImage = np.asarray(dicedImage)
  679. # ~ print("shape of dicedImage right before predict: " + str(np.shape(dicedImage)))
  680. # ~ print("dicedImage right before predict as string: " + str(dicedImage))
  681. modelOut = theModel.predict(dicedImage)
  682. ##This is the code from main. I know it's bad now, but I'll keep it
  683. ##consistent until I create a helper function for it. ######################################################################################################
  684. binarizedOuts = ((modelOut > 0.5).astype(np.uint8) * 255).astype(np.uint8)
  685. #Stitch image using dimensions from above
  686. #combine each image row into numpy array
  687. theRowsList = []
  688. # ~ print("squaresHigh: " + str(squaresHigh))
  689. # ~ print("squaresWide: " + str(squaresWide))
  690. # ~ print("squareSize: " + str(squareSize))
  691. bigOut = np.zeros(shape = (squareSize * squaresHigh, squareSize * squaresWide, 1), dtype = np.uint8) #swap h and w?
  692. for i in range(squaresHigh):
  693. for j in range(squaresWide):
  694. # ~ print("i: " + str(i) + "\tj: " + str(j))
  695. # ~ print("sqHi: " + str(squaresHigh) + "\tsqWi: " + str(squaresWide))
  696. thisSquare = binarizedOuts[(i * squaresWide) + j] #w? I got an index out of range here.
  697. iStart = i * squareSize
  698. iEnd = (i * squareSize) + squareSize
  699. jStart = j * squareSize
  700. jEnd = (j * squareSize) + squareSize
  701. bigOut[iStart : iEnd , jStart : jEnd ] = thisSquare
  702. # ~ combined = np.asarray(theRowsList)
  703. # ~ combined = combined.reshape((64,64,1))
  704. #Remove the extra padding from the edge of the image.
  705. outImage = bigOut[ :height, :width]
  706. # ~ outImage = bigOut
  707. return outImage
  708. def convertImagesToGrayscale(inputImages):
  709. outImage = []
  710. for image in inputImages:
  711. outImage.append( rgb2gray(image) )
  712. return np.asarray(outImage)
  713. #Returns a list instead of an np array
  714. def convertImagesToGrayscaleList(inputImages):
  715. outImage = []
  716. for image in inputImages:
  717. outImage.append( np.asarray(rgb2gray(image)) )
  718. return outImage
  719. #returns the filenames of the images for (trainImage, trainTruth),(testimage, testTruth)
  720. #hardcoded!
  721. #Test is currently hardcodded to 2016
  722. def getFileNames(trainImagePath, trainTruthPath, testImagePath, testTruthPath):
  723. trainTruthNamePairs = []
  724. # ~ trainImagePath = os.path.normpath("../DIBCO/2017/Dataset/")
  725. # ~ trainTruthPath = os.path.normpath("../DIBCO/2017/GT/")
  726. trainTruthNamePairs.append( (trainImagePath, trainTruthPath) )
  727. # ~ trainImageFileNames, trainTruthFileNames = \
  728. # ~ createTrainImageAndTrainTruthFileNames(trainImagePath, trainTruthPath)
  729. #need to handle non-bmp
  730. # ~ trainPath = "../DIBCO/2009/DIBC02009_Test_images-handwritten/"
  731. # ~ gtPath = "../DIBCO/2009/DIBCO2009-GT-Test-images_handwritten/"
  732. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  733. # ~ trainPath = "../DIBCO/2009/DIBCO2009_Test_images-printed/"
  734. # ~ gtPath = "../DIBCO/2009/DIBCO2009-GT-Test-images_printed/"
  735. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  736. #non-bmps
  737. # ~ trainPath = "../DIBCO/2010/DIBC02010_Test_images/"
  738. # ~ gtPath = "../DIBCO/2010/DIBC02010_Test_GT/"
  739. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  740. #2011 needs a special function to split the GTs with a wildcard or something.
  741. #2012 same
  742. # ~ trainPath = "../DIBCO/2013/OriginalImages/"
  743. # ~ gtPath = "../DIBCO/2013/GTimages/"
  744. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  745. # ~ trainPath = "../DIBCO/2014/original_images/"
  746. # ~ gtPath = "../DIBCO/2014/gt/"
  747. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  748. trainImageFileNames = []
  749. trainTruthFileNames = []
  750. for pair in trainTruthNamePairs:
  751. tImPath, gtImPath = pair
  752. trainNames, gtNames = \
  753. createTrainImageAndTrainTruthFileNames(tImPath, gtImPath)
  754. trainImageFileNames.extend(trainNames)
  755. trainTruthFileNames.extend(gtNames)
  756. #test image section
  757. # ~ testImagePath = os.path.normpath("../DIBCO/2016/DIPCO2016_dataset/")
  758. # ~ testTruthPath = os.path.normpath("../DIBCO/2016/DIPCO2016_Dataset_GT/")
  759. testImageFileNames, testTruthFileNames = \
  760. createTrainImageAndTrainTruthFileNames(testImagePath, testTruthPath)
  761. return trainImageFileNames, trainTruthFileNames, \
  762. testImageFileNames, testTruthFileNames
  763. def createTrainImageAndTrainTruthFileNames(trainImagePath, trainTruthPath):
  764. trainImageFileNames = createTrainImageFileNamesList(trainImagePath)
  765. trainTruthFileNames = createTrainTruthFileNamesList(trainImageFileNames)
  766. trainImageFileNames = appendBMP(trainImageFileNames)
  767. # ~ print(trainImageFileNames)
  768. trainTruthFileNames = appendBMP(trainTruthFileNames)
  769. # ~ print(trainTruthFileNames)
  770. trainImageFileNames = prependPath(trainImagePath, trainImageFileNames)
  771. trainTruthFileNames = prependPath(trainTruthPath, trainTruthFileNames)
  772. return trainImageFileNames, trainTruthFileNames
  773. def createTrainImageFileNamesList(trainImagePath):
  774. # ~ trainFileNames = next(os.walk(trainImagePath))[2] #this is a clever hack
  775. # ~ trainFileNames = [name.replace(".bmp", "") for name in trainFileNames]
  776. trainFileNames = os.listdir(trainImagePath)
  777. print(trainFileNames)
  778. # ~ print("pausing...")
  779. # ~ a = input()
  780. return [name.replace(".bmp", "") for name in trainFileNames]
  781. #This makes a list with the same order of the names but with _gt apended.
  782. def createTrainTruthFileNamesList(originalNames):
  783. return [name + "_gt" for name in originalNames]
  784. def appendBMP(inputList):
  785. return [name + ".bmp" for name in inputList]
  786. def prependPath(myPath, nameList):
  787. return [os.path.join(myPath, name) for name in nameList]
  788. #I'm copying the code for jaccard similarity and dice from this MIT licenced source.
  789. #https://github.com/masyagin1998/robin
  790. #jaccard is size intersection of the sets / size union of the sets
  791. #Also, I'm going to try the smoothing values suggested in robin and here:
  792. #https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
  793. #They also suggest abs()
  794. def jaccardIndex(truth, prediction):
  795. #they are tensors?! not images?!?!?!?!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
  796. # ~ truth = img_as_bool(np.asarray(truth)) #fail
  797. # ~ prediction = img_as_bool(np.asarray(prediction)) #fail
  798. smooth = GLOBAL_SMOOTH_JACCARD
  799. predictionFlat = backend.flatten(prediction)
  800. truthFlat = backend.flatten(truth)
  801. # ~ for i in range(len(predictionFlat)):
  802. # ~ if predictionFlat[i] >= 2:
  803. # ~ print("This was the bug. images have non- binary values...#############################################################################################################################################################")
  804. # ~ print("predictionFlat[" + str(i) + "]: " + str(predictionFlat[i]))
  805. # ~ intersectionImg = predictionFlat * truthFlat
  806. numberPixelsSame = backend.sum(truthFlat * predictionFlat)
  807. #I've found the function tensorflow.reduce_sum() which performs a sum by reduction
  808. #Is it better than backend.sum?? ##################################################################
  809. #the docs say it is equivalent except that numpy will change everything to int64
  810. return float((numberPixelsSame + smooth) / \
  811. ( \
  812. (backend.sum(predictionFlat) + backend.sum(truthFlat) - numberPixelsSame + smooth) \
  813. ))
  814. #loss function for use in training.
  815. def jaccardLoss(truth, prediction):
  816. return 1.0 - jaccardIndex(truth, prediction)
  817. #input must be binarized images consisting of values for pixels of either 1 or 0.
  818. def diceIndex(truth, prediction):
  819. smooth = GLOBAL_SMOOTH_DICE
  820. predictionFlat = backend.flatten(prediction)
  821. truthFlat = backend.flatten(truth)
  822. numberSamePixels = backend.sum(predictionFlat * truthFlat)
  823. return float((2 * numberSamePixels + smooth) \
  824. / (backend.sum(predictionFlat) + backend.sum(truthFlat) + smooth))
  825. #Loss function for use in training
  826. def diceLoss(truth, prediction):
  827. smooth = GLOBAL_SMOOTH_DICE
  828. return smooth - diceIndex(truth, prediction)
  829. #This creates a true positive mask from an inverted image (white pixels are text)
  830. def createMaskTruePositive(truth, prediction):
  831. pFlat = backend.flatten(prediction)
  832. pFlat = img_as_bool(pFlat)
  833. tFlat = backend.flatten(truth)
  834. tFlat = img_as_bool(tFlat)
  835. mask = pFlat * tFlat
  836. return np.reshape(mask, np.shape(prediction))
  837. #creaes true negative mask
  838. def createMaskTrueNegative(truth, prediction):
  839. pFlat = backend.flatten(prediction)
  840. pFlat = img_as_bool(pFlat)
  841. tFlat = backend.flatten(truth)
  842. tFlat = img_as_bool(tFlat)
  843. ##invert to make the 0 into ones
  844. pFlat = ~pFlat
  845. tFlat = ~tFlat
  846. ##then multiply
  847. mask = pFlat * tFlat
  848. return np.reshape(mask, np.shape(prediction))
  849. #Creates a mask for the False Positives
  850. def createMaskFalsePositive(truth, prediction):
  851. pFlat = backend.flatten(prediction)
  852. pFlat = img_as_bool(pFlat)
  853. tFlat = backend.flatten(truth)
  854. tFlat = img_as_bool(tFlat)
  855. #will I need these?
  856. # ~ falseArray = np.zeros(len(pFlat), dtype = bool)
  857. # ~ trueArray = np.ones(len(pFlat), dtype = bool)
  858. ##where is the prediction true 1, where the truth is 0 false?
  859. mask = np.where(pFlat > tFlat, True, False)
  860. # ~ mask = np.where(pFlat > tFlat, pFlat, ~pFlat)
  861. return np.reshape(mask, np.shape(prediction))
  862. #returns a mask of all the pixels that are not supposed to be false.
  863. def createMaskFalseNegative(truth, prediction):
  864. # ~ return createMaskFalsePositive(prediction, truth) #Just swapped the input!?? yes but bug happens. 3-1
  865. pFlat = backend.flatten(prediction)
  866. pFlat = img_as_bool(pFlat)
  867. tFlat = backend.flatten(truth)
  868. tFlat = img_as_bool(tFlat)
  869. mask = np.where(pFlat < tFlat, True, False)
  870. return np.reshape(mask, np.shape(prediction))
  871. #Color the prediction image with the pixels that are correct in red.
  872. def colorPredictionWithPredictionMask(predictionMask, originalPrediction, colorArray):
  873. prediction = img_as_bool(originalPrediction)
  874. prediction = np.where( predictionMask >= prediction, True, False ) #This makes the area to paint to white.
  875. prediction = img_as_float(prediction)
  876. predictionMask = np.squeeze(predictionMask, axis = 2)
  877. if IS_GLOBAL_PRINTING_ON:
  878. print("predictionMask shape: " + str(predictionMask.shape))
  879. rows, cols = predictionMask.shape
  880. colorMask = np.zeros((rows, cols, 3))
  881. colorMask[ predictionMask ] = colorArray
  882. predictionInColor = np.dstack((prediction, prediction, prediction))
  883. predictionColor_hsv = rgb2hsv(predictionInColor)
  884. colorMask_hsv = rgb2hsv(colorMask)
  885. alpha = 1.0
  886. predictionColor_hsv[..., 0] = colorMask_hsv[..., 0]
  887. predictionColor_hsv[..., 1] = colorMask_hsv[..., 1] * alpha
  888. outImg = hsv2rgb(predictionColor_hsv)
  889. return img_as_ubyte(outImg)
  890. def combinePredictionPicture(truePosMask, trueNegMask, falsePosMask, falseNegMask):
  891. redColor = [1, 0, 0]
  892. greenColor = [0, 1, 0]
  893. blueColor = [0, 0, 1]
  894. yellowColor = [1, 1, 0]
  895. truePosMask = np.squeeze(truePosMask, axis = 2)
  896. trueNegMask = np.squeeze(trueNegMask, axis = 2)
  897. falsePosMask = np.squeeze(falsePosMask, axis = 2)
  898. falseNegMask = np.squeeze(falseNegMask, axis = 2)
  899. ##make a numpy array of ONES, reshape to image size, then convert with imgtofloat
  900. # ~ print("truePosMask.shape: " + str(truePosMask.shape))
  901. rows, cols = truePosMask.shape
  902. # ~ prediction = np.ones( (rows, cols), dtype=bool )
  903. prediction = np.zeros( (rows, cols), dtype=bool )
  904. # ~ prediction = img_as_float(prediction)
  905. predictionRGB = np.dstack((prediction, prediction, prediction))
  906. # ~ predictionColor_hsv = rgb2hsv(predictionRGB)
  907. ## make the four color masks
  908. redColorMask = img_as_bool(falseNegMask)
  909. blueColorMask = img_as_bool(falsePosMask)
  910. greenColorMask = img_as_bool(truePosMask)
  911. yellowColorMask = img_as_bool(trueNegMask)
  912. predictionRGB[redColorMask, 0 ] = 1
  913. predictionRGB[greenColorMask, 1 ] = 1
  914. predictionRGB[blueColorMask, 2 ] = 1
  915. predictionRGB[yellowColorMask, 0 ] = 1
  916. predictionRGB[yellowColorMask, 1 ] = 1
  917. # ~ green_image = rgb_image.copy() # Make a copy
  918. # ~ green_image[:,:,0] = 0
  919. # ~ green_image[:,:,2] = 0
  920. # ~ alpha = 1.0
  921. # ~ predictionColor_hsv[..., 0] = redColorMask[..., 0]
  922. # ~ predictionColor_hsv[..., 1] = redColorMask[..., 1] * alpha
  923. # ~ predictionColor_hsv[..., 0] = blueColorMask[..., 0]
  924. # ~ predictionColor_hsv[..., 1] = blueColorMask[..., 1] * alpha
  925. # ~ predictionColor_hsv[..., 0] = greenColorMask[..., 0]
  926. # ~ predictionColor_hsv[..., 1] = greenColorMask[..., 1] * alpha
  927. # ~ predictionColor_hsv[..., 0] = yellowColorMask[..., 0]
  928. # ~ predictionColor_hsv[..., 1] = yellowColorMask[..., 1] * alpha
  929. # ~ outImg = hsv2rgb(predictionColor_hsv)
  930. # ~ return img_as_ubyte(outImg)
  931. return img_as_ubyte(predictionRGB)
  932. def makeThisColorMaskHsv(predictionMask, colorArray, rows, cols):
  933. return rgb2hsv(makeThisColorMaskRGB(predictionMask, colorArray, rows, cols))
  934. def makeThisColorMaskRGB(predictionMask, colorArray, rows, cols):
  935. thisColorMask = np.zeros((rows, cols, 3))
  936. thisColorMask[ predictionMask ] = colorArray
  937. return thisColorMask
  938. if __name__ == '__main__':
  939. import sys
  940. sys.exit(main(sys.argv))