my-unet.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # my-unet.py
  5. #
  6. # Copyright 2022 Stephen Stengel <stephen@cwu.edu> MIT License
  7. #
# OK, to start out I'll use 2017 as training data, and 2016 as testing data.
  9. #Also I will resize each image to a square of size 128x128; I can add more complex manipulation later.
  10. #using these as reference for some parts:
  11. #https://github.com/dexterfichuk/U-Net-Keras-Implementation
  12. print("Running imports...")
  13. import os
  14. import shutil
  15. import numpy as np
  16. import random
  17. from PIL import Image
  18. from tqdm import tqdm
  19. from skimage.io import imread, imshow, imsave
  20. from skimage.transform import resize
  21. from matplotlib import pyplot as plt
  22. from skimage.util import img_as_bool, img_as_float, img_as_uint, img_as_ubyte
  23. from skimage.util import invert
  24. from skimage.color import rgb2gray, gray2rgb, rgb2hsv, hsv2rgb
  25. from sklearn.metrics import auc
  26. import tensorflow as tf
  27. from keras.layers import Conv2D, MaxPool2D, UpSampling2D, Input, Dropout, Lambda, MaxPooling2D, Conv2DTranspose, Concatenate, Softmax
  28. from tensorflow.keras.optimizers import Adam
  29. from keras import Model, callbacks
  30. from keras import backend
  31. # ~ from rezaunet import BCDU_net_D3
  32. ## autoinit stuff
  33. # ~ from autoinit import AutoInit
  34. np.random.seed(55555)
  35. random.seed(55555)
  36. NUM_SQUARES = 100 #Reduced number of square inputs for training. 100 seems to be min for ok results.
  37. HACK_SIZE = 64 #64 is reasonably good for prototyping.
  38. GLOBAL_HACK_height, GLOBAL_HACK_width = HACK_SIZE, HACK_SIZE
  39. IMAGE_CHANNELS = 3 #This might change later for different datasets. idk.
  40. GLOBAL_EPOCHS = 15
  41. GLOBAL_BATCH_SIZE = 4 #just needs to be big enough to fill memory
  42. #64hack, 5 epoch, 16batch nearly fills 8gb on laptop. Half of 16 on other laptop.
  43. #Making batch too high seems to cause problems. 32 caused a NaN error when trying to write the output images on laptop1.
  44. GLOBAL_INITIAL_FILTERS = 16
  45. GLOBAL_SMOOTH_JACCARD = 1
  46. GLOBAL_SMOOTH_DICE = 1
  47. IS_GLOBAL_PRINTING_ON = False
  48. # ~ IS_GLOBAL_PRINTING_ON = True
  49. GLOBAL_SQUARE_TEST_SAVE = True
  50. # ~ GLOBAL_SQUARE_TEST_SAVE = False
  51. GLOBAL_MAX_TEST_SQUARE_TO_SAVE = 66
  52. HELPFILE_PATH = os.path.normpath("helpfile")
  53. OUT_TEXT_PATH = os.path.normpath("accuracies-if-error-happens-lol")
  54. print("Done!")
  55. def main(args):
  56. print("Hi!")
  57. checkArgs(args)
  58. print("Creating folders to store results...")
  59. tmpFolder, trainingFolder, checkpointFolder, savedModelFolder, \
  60. predictionsFolder, wholePredictionsFolder, outTextPath \
  61. = createFolders()
  62. print("Done!")
  63. print("Creating copy of source code and conda environment...")
  64. copySourceEnv(tmpFolder)
  65. print("Done!")
  66. print("Creating train and test sets...")
  67. trainImages, trainTruth, testImages, testTruths, wholeOriginals, wholeTruths = createTrainAndTestSets()
  68. print("Done!")
  69. #Images not currently called from disk. Commenting for speed testing.
  70. # ~ saveExperimentImages(trainImages, trainTruth, testImages, testTruths, trainingFolder)
  71. if IS_GLOBAL_PRINTING_ON:
  72. mainTestPrintOne(wholeOriginals, wholeTruths, trainImages, trainTruth, testImages, testTruths)
  73. trainImages, trainTruth, testImages, testTruths \
  74. = reduceInputForTesting(trainImages, trainTruth, testImages, testTruths, NUM_SQUARES)
  75. theModel, theHistory = trainUnet(trainImages, trainTruth, checkpointFolder)
  76. print("Saving model...")
  77. theModel.save(os.path.join(savedModelFolder, "saved-model.h5"))
  78. print("Done!")
  79. performEvaluation(theHistory, tmpFolder, testImages, testTruths, theModel)
  80. print("shape of testImages right before predict: " + str(np.shape(testImages)))
  81. modelOut = theModel.predict(testImages)
  82. binarizedOut = ((modelOut > 0.5).astype(np.uint8) * 255).astype(np.uint8) #######test this thing more
  83. if GLOBAL_SQUARE_TEST_SAVE:
  84. saveTestSquares(
  85. GLOBAL_MAX_TEST_SQUARE_TO_SAVE, modelOut, \
  86. binarizedOut, testImages, testTruths, predictionsFolder)
  87. else:
  88. print("Not saving test square pictures this time.")
  89. print("Calculating jaccard and dice for the test squares...")
  90. calculateJaccardDiceTestSquares(testTruths, outTextPath, binarizedOut)
  91. print("Predicting output of whole images...")
  92. #currently also does the image processing and saving.
  93. predictionsList = predictAllWholeImages(wholeOriginals, wholeTruths, theModel, HACK_SIZE)
  94. print("Creating confusion masks...")
  95. confusionImages, tpList, fpList, tnList, fnList \
  96. = createConfusionImageList(predictionsList, wholeOriginals, wholeTruths)
  97. print("Saving whole image predictions and confusion images...")
  98. saveAllWholeAndConfusion(predictionsList, wholeOriginals, wholeTruths, confusionImages, wholePredictionsFolder)
  99. print("Creating ROC graph of the whole images test...")
  100. createROC(tpList, fpList, tnList, fnList, tmpFolder)
  101. print("Evaluating jaccard and dice scores...")
  102. evaluatePredictionJaccardDice(predictionsList, wholeTruths, outTextPath)
  103. print("Done!")
  104. return 0
  105. def createFolders():
  106. sq = str(NUM_SQUARES)
  107. hk = str(HACK_SIZE)
  108. ep = str(GLOBAL_EPOCHS)
  109. ba = str(GLOBAL_BATCH_SIZE)
  110. tmpFolder = os.path.normpath("./tmp" + sq + "-" + hk + "-" + ep + "-" + ba + "/")
  111. trainingFolder = os.path.join(tmpFolder, "trainingstuff")
  112. checkpointFolder = os.path.join(tmpFolder, "checkpoint")
  113. savedModelFolder = os.path.join(tmpFolder, "saved-model")
  114. predictionsFolder = os.path.join(tmpFolder, "predictions")
  115. wholePredictionsFolder = os.path.join(tmpFolder, "whole-predictions")
  116. foldersToCreate = [ \
  117. tmpFolder, trainingFolder, \
  118. checkpointFolder, savedModelFolder, \
  119. predictionsFolder, wholePredictionsFolder]
  120. for folder in foldersToCreate:
  121. if not os.path.isdir(folder):
  122. os.makedirs(folder)
  123. #Lol spaghetti
  124. global OUT_TEXT_PATH
  125. OUT_TEXT_PATH = os.path.join(tmpFolder, "accuracy-jaccard-dice.txt")
  126. return tmpFolder, trainingFolder, checkpointFolder, savedModelFolder, predictionsFolder, wholePredictionsFolder, OUT_TEXT_PATH
  127. def copySourceEnv(tmpFolder):
  128. try:
  129. shutil.copy("my-unet.py", tmpFolder)
  130. shutil.copy("working-conda-config.yml", tmpFolder)
  131. except:
  132. print("copy error! Source code and conda environment file not copied!")
  133. #This would be preferred in the final product.
  134. # ~ print("Creating current copy of environment...")
  135. # ~ os.system("conda env export > " + tmpFolder + "working-conda-config-current.yml")
  136. # ~ print("Done!")
  137. def reduceInputForTesting(trainImages, trainTruth, testImages, testTruths, sizeOfSet ):
  138. #This block reduces the input for testing.
  139. highIndex = len(trainImages)
  140. if sizeOfSet > highIndex + 1: #Just in case user enters more squares than exist.
  141. sizeOfSet = highIndex + 1
  142. print("! Limiting size of squares for training to actual number of squares !")
  143. print("Number of squares to be used for training: " + str(sizeOfSet))
  144. updateGlobalNumSquares(sizeOfSet)
  145. rng = np.random.default_rng(12345)
  146. pickIndexes = rng.integers(low = 0, high = highIndex, size = sizeOfSet)
  147. trainImages = trainImages[pickIndexes]
  148. trainTruth = trainTruth[pickIndexes]
  149. sizeOfTestSet = sizeOfSet
  150. if sizeOfTestSet > len(testImages):
  151. sizeOfTestSet = len(testImages)
  152. rng = np.random.default_rng(23456)
  153. print("sizeOfTestSet: " + str(sizeOfTestSet))
  154. pickIndexes = rng.integers(low = 0, high = len(testImages), size = sizeOfTestSet)
  155. testImages = testImages[pickIndexes]
  156. testTruths = testTruths[pickIndexes]
  157. print("There are " + str(len(trainImages)) + " training images.")
  158. print("There are " + str(len(testImages)) + " testing images.")
  159. return trainImages, trainTruth, testImages, testTruths
  160. def mainTestPrintOne(wholeOriginals, wholeTruths, trainImages, trainTruth, testImages, testTruths):
  161. print("shape of wholeOriginals: " + str(np.shape(wholeOriginals)))
  162. print("shape of wholeTruths: " + str(np.shape(wholeTruths)))
  163. print("shape of trainImages: " + str(np.shape(trainImages)))
  164. print("shape of trainTruth: " + str(np.shape(trainTruth)))
  165. print("shape of testImages: " + str(np.shape(testImages)))
  166. print("shape of testTruths: " + str(np.shape(testTruths)))
  167. print("Showing Training stuff...")
  168. randomBoy = random.randint(0, len(trainImages) - 1)
  169. print("image " + str(randomBoy) + "...")
  170. imshow(trainImages[randomBoy] / 255)
  171. plt.show()
  172. print("truth " + str(randomBoy) + "...")
  173. imshow(np.squeeze(trainTruth[randomBoy]))
  174. plt.show()
  175. print("Showing Testing stuff...")
  176. randomBoy = random.randint(0, len(testImages) - 1)
  177. print("image " + str(randomBoy) + "...")
  178. imshow(testImages[randomBoy] / 255)
  179. plt.show()
  180. print("truth " + str(randomBoy) + "...")
  181. imshow(np.squeeze(testTruths[randomBoy]))
  182. plt.show()
  183. #save copies of some of the squares used in learning.
  184. def saveTestSquares(numToSave, modelOut, binarizedOut, testImages, testTruths, predictionsFolder):
  185. print("Saving random sample of figures...")
  186. rng2 = np.random.default_rng(54322)
  187. if len(modelOut) < numToSave:
  188. numToSave = len(modelOut)
  189. saveIndexes = rng2.integers(low = 0, high = len(modelOut), size = numToSave)
  190. for i in tqdm(saveIndexes):
  191. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]premask.png"), img_as_ubyte(modelOut[i]))
  192. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]predict.png"), img_as_ubyte(binarizedOut[i]))
  193. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]testimg.png"), img_as_ubyte(testImages[i]))
  194. imsave(os.path.join(predictionsFolder, "fig[" + str(i) + "]truthim.png"), img_as_ubyte(testTruths[i]))
  195. print("Done!")
  196. def calculateJaccardDiceTestSquares(testTruths, outTextPath, binarizedOut):
  197. testTruthsUInt = testTruths.astype(np.uint8)
  198. #Testing the jaccard and dice functions
  199. with open(os.path.join(outTextPath), "w") as outFile:
  200. for i in tqdm(range(len(binarizedOut))):
  201. jac = jaccardIndex(testTruthsUInt[i], binarizedOut[i])
  202. dice = diceIndex(testTruthsUInt[i], binarizedOut[i])
  203. thisString = str(i) + "\tjaccard: " + str(jac) + "\tdice: " + str(dice) + "\n"
  204. outFile.write(thisString)
  205. print("Done!")
  206. #currently also does the image processing and saving.
  207. def predictAllWholeImages(wholeOriginals, wholeTruths, theModel, squareSize):
  208. if IS_GLOBAL_PRINTING_ON:
  209. print("shape of wholeOriginals: " + str(np.shape(wholeOriginals)))
  210. for i in range(len(wholeOriginals)):
  211. print(str(np.shape(wholeOriginals[i])))
  212. print("##########################################################")
  213. print("shape of wholeTruths: " + str(np.shape(wholeTruths)))
  214. predictionsList = []
  215. for i in tqdm(range(len(wholeOriginals))):
  216. # ~ wholeOriginals, wholeTruths
  217. predictedImage = predictWholeImage(wholeOriginals[i], theModel, squareSize)
  218. if IS_GLOBAL_PRINTING_ON:
  219. print("Shape of predicted image " + str(i) + ": " + str(np.shape(predictedImage)))
  220. # ~ predictedImage = ((predictedImage > 0.5).astype(np.uint8) * 255).astype(np.uint8) ## jank thing again
  221. # ~ print("Shape of predicted image " + str(i) + " after mask: " + str(np.shape(predictedImage)))
  222. predictionsList.append(predictedImage)
  223. return predictionsList
  224. #This could be split up a bit as well.
  225. #also outputs the binary masks in lists !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! change later
  226. def createConfusionImageList(predictionsList, wholeOriginals, wholeTruths):
  227. confusionList = []
  228. tpList = []
  229. fpList = []
  230. tnList = []
  231. fnList = []
  232. for i in tqdm(range(len(predictionsList))):
  233. predictedImage = predictionsList[i]
  234. truePosMask = createMaskTruePositive(wholeTruths[i], predictedImage)
  235. tpList.append(truePosMask)
  236. trueNegMask = createMaskTrueNegative(wholeTruths[i], predictedImage)
  237. tnList.append(trueNegMask)
  238. falsePosMask = createMaskFalsePositive(wholeTruths[i], predictedImage)
  239. fpList.append(falsePosMask)
  240. falseNegMask = createMaskFalseNegative(wholeTruths[i], predictedImage)
  241. fnList.append(falseNegMask)
  242. redColor = [1, 0, 0]
  243. greenColor = [0, 1, 0]
  244. blueColor = [0, 0, 1]
  245. yellowColor = [1, 1, 0]
  246. truePosColor = colorPredictionWithPredictionMask(truePosMask, predictedImage, greenColor)
  247. trueNegColor = colorPredictionWithPredictionMask(trueNegMask, predictedImage, yellowColor)
  248. falsePosColor = colorPredictionWithPredictionMask(falsePosMask, predictedImage, blueColor)
  249. falseNegColor = colorPredictionWithPredictionMask(falseNegMask, predictedImage, redColor )
  250. confusion = combinePredictionPicture(truePosMask, trueNegMask, falsePosMask, falseNegMask)
  251. confusionList.append(confusion)
  252. return confusionList, tpList, fpList, tnList, fnList
  253. def saveAllWholeAndConfusion(predictionsList, wholeOriginals, wholeTruths, confusions, wholePredictionsFolder):
  254. for i in tqdm(range(len(predictionsList))):
  255. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]predicted.png"), img_as_ubyte(predictionsList[i]))
  256. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]truth.png"), img_as_ubyte(wholeTruths[i]))
  257. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]original.png"), img_as_ubyte(wholeOriginals[i]))
  258. imsave(os.path.join(wholePredictionsFolder, "img[" + str(i) + "]confusion.png"), img_as_ubyte(confusions[i]))
  259. def createROC(tpList, fpList, tnList, fnList, tmpFolder):
  260. #convert each list of masks into a list of percentages.
  261. tpScores = getPercentTrueFromMaskList(tpList)
  262. fpScores = getPercentTrueFromMaskList(fpList)
  263. #not currently needed
  264. # ~ tnScores = getPercentTrueFromMaskList(tnList[i])
  265. # ~ fnScores = getPercentTrueFromMaskList(fnList[i])
  266. plotROCandSave(fpScores, tpScores, tmpFolder)
  267. def plotROCandSave(fpList, tpList, tmpFolder):
  268. # ~ [x for (y,x) in sorted(zip(Y,X), key=lambda pair: pair[0])]
  269. fpArray = np.asarray(fpList)
  270. tpArray = np.asarray(tpList)
  271. indexesSorted = fpArray.argsort()
  272. fpArray = fpArray[indexesSorted]
  273. tpArray = tpArray[indexesSorted]
  274. print("fpArray: " + str(fpArray))
  275. print("tpArray: " + str(tpArray))
  276. # ~ print("fpArray: " + str(fpArray))
  277. # ~ print("tpArray: " + str(tpArray))
  278. roc_auc = auc(fpArray, tpArray)
  279. print("contents of roc_auc: " + str(roc_auc))
  280. # ~ for thing in roc_auc:
  281. # ~ print(thing)
  282. # ~ print("##")
  283. linewidth = 2
  284. plt.figure()
  285. plt.plot(
  286. fpArray,
  287. tpArray,
  288. color="darkorange",
  289. linewidth = linewidth,
  290. label="ROC curve (area = %0.2f%%)" % roc_auc,
  291. # ~ label = "ROC curve"
  292. )
  293. plt.plot([0, 1], [0, 1], color="navy", linewidth=linewidth, linestyle="--")
  294. plt.xlim([0.0, 1.0])
  295. # ~ plt.ylim([0.0, 1.0])
  296. plt.ylim([0.0, 1.05])
  297. plt.xlabel("False Positive Rate")
  298. plt.ylabel("True Positive Rate")
  299. plt.title("ROC curve !")
  300. plt.legend(loc="lower right")
  301. plt.savefig(os.path.join(tmpFolder, "roc-curve.png"))
  302. plt.show()
  303. def getPercentTrueFromMaskList(inList):
  304. scoreList = []
  305. for i in tqdm(range(len(inList))):
  306. scoreList.append( getPercentTrueFromMask(inList[i]))
  307. return scoreList
  308. ######################################################### I actually don't need percent. I neet TPRa and FPR (tp, fp, tn, fn) ##########################
  309. def getPercentTrueFromMask(inMask):
  310. # ~ mFlat = backend.flatten( img_as_uint(inMask) )
  311. # ~ numTrue = backend.sum(mFlat)
  312. mFlat = np.asarray(inMask).flatten()
  313. numTrue = np.sum(mFlat)
  314. # ~ mFlat = np.asarray(mFlat)
  315. totalNum = mFlat.size
  316. return numTrue / totalNum
  317. def evaluatePredictionJaccardDice(predictionsList, wholeTruths, outTextPath):
  318. print("Calculating jaccard and dice...")
  319. with open(outTextPath, "w") as outFile:
  320. for i in tqdm(range(len(predictionsList))):
  321. thisTruth = np.asarray(wholeTruths[i])
  322. thisTruth = thisTruth.astype(np.uint8)
  323. jac = jaccardIndex(thisTruth, predictionsList[i])
  324. dice = diceIndex(thisTruth, predictionsList[i])
  325. thisString = str(i) + "\tjaccard: " + str(jac) + "\tdice: " + str(dice) + "\n"
  326. outFile.write(thisString)
  327. print("Done!")
  328. def checkArgs(args):
  329. if len(args) >= 1:
  330. for a in args:
  331. if str(a) == "help" \
  332. or str(a).lower() == "-help" \
  333. or str(a).lower() == "--help" \
  334. or str(a).lower() == "--h":
  335. with open(HELPFILE_PATH, "r") as helpfile:
  336. for line in helpfile:
  337. print(line, end = "")
  338. sys.exit(0)
  339. if len(args) < 5:
  340. print("bad input");
  341. sys.exit(-1)
  342. else:
  343. global NUM_SQUARES
  344. NUM_SQUARES = int(sys.argv[1])
  345. global HACK_SIZE
  346. HACK_SIZE = int(sys.argv[2])
  347. global GLOBAL_HACK_height
  348. global GLOBAL_HACK_width
  349. GLOBAL_HACK_height, GLOBAL_HACK_width = HACK_SIZE, HACK_SIZE
  350. global GLOBAL_EPOCHS
  351. GLOBAL_EPOCHS = int(sys.argv[3])
  352. global GLOBAL_BATCH_SIZE
  353. GLOBAL_BATCH_SIZE = int(sys.argv[4])
  354. if len(args) >= 6:
  355. if str(sys.argv[5]) == "print":
  356. global IS_GLOBAL_PRINTING_ON
  357. IS_GLOBAL_PRINTING_ON = True
  358. print("Printing of debugging messages is enabled.")
  359. if NUM_SQUARES < 100:
  360. print("100 squares is really the bare minimum to get any meaningful result.")
  361. sys.exit(-1)
  362. if HACK_SIZE not in [64, 128, 256, 512]:
  363. print("Square size must be 64, 128, 256, or 512." \
  364. + " 128 is recommended for training. 64 for testing")
  365. sys.exit(-2)
  366. if GLOBAL_EPOCHS < 1:
  367. print("Yeah no.")
  368. print("You need at least one epoch, silly!")
  369. sys.exit(-3)
  370. if GLOBAL_BATCH_SIZE < 1 or GLOBAL_BATCH_SIZE > NUM_SQUARES:
  371. print("Global batch size should be between 1 and the number" \
  372. + " of training squares. Pick a better number.")
  373. sys.exit(-5)
  374. def updateGlobalNumSquares(newNumSquares):
  375. global NUM_SQUARES
  376. NUM_SQUARES = newNumSquares
  377. def performEvaluation(history, tmpFolder, testImages, testTruths, theModel):
  378. print("Performing evaluation...###############################################################")
  379. print("Calculating scores...")
  380. print("len testImages: " + str(len(testImages)))
  381. scores = theModel.evaluate(testImages, testTruths)
  382. print("Done!")
  383. print("Scores object: " + str(scores))
  384. print(str(history.history))
  385. print("%s: %.2f%%" % (theModel.metrics_names[1], scores[1]*100))
  386. print("history...")
  387. print(history)
  388. print("history.history...")
  389. print(history.history)
  390. accuracy = history.history["acc"]
  391. jaccInd = history.history["jaccardIndex"]
  392. diceInd = history.history["diceIndex"]
  393. val_accuracy = history.history["val_acc"]
  394. val_jaccInd = history.history["val_jaccardIndex"]
  395. val_diceInd = history.history["val_diceIndex"]
  396. loss = history.history["loss"]
  397. val_loss = history.history["val_loss"]
  398. epochs = range(1, len(accuracy) + 1)
  399. plt.plot(epochs, accuracy, "^", label="Training accuracy")
  400. plt.plot(epochs, val_accuracy, "2", label="Validation accuracy")
  401. plt.plot(epochs, jaccInd, "*", label="Jaccard Index")
  402. plt.plot(epochs, val_jaccInd, "p", label="Validation Jaccard Index")
  403. plt.plot(epochs, diceInd, "s", label="Dice Index")
  404. plt.plot(epochs, val_diceInd, "D", label="Validation Dice Index")
  405. plt.title("Training and validation accuracy")
  406. plt.legend()
  407. plt.savefig(os.path.join(tmpFolder, "trainvalacc.png"))
  408. plt.clf()
  409. plt.plot(epochs, loss, "^", label="Training loss")
  410. plt.plot(epochs, val_loss, "2", label="Validation loss")
  411. plt.title("Training and validation loss")
  412. plt.legend()
  413. plt.savefig(os.path.join(tmpFolder, "trainvalloss.png"))
  414. plt.clf()
  415. def trainUnet(trainImages, trainTruth, checkpointFolder):
  416. #print("shape of trainImages: " + str(trainImages.shape))
  417. standardUnetLol = createStandardUnet()
  418. # ~ standardUnetLol = BCDU_net_D3( (GLOBAL_HACK_height, GLOBAL_HACK_width, IMAGE_CHANNELS) )
  419. standardUnetLol.summary()
  420. checkpointer = callbacks.ModelCheckpoint(
  421. filepath = checkpointFolder,
  422. # ~ monitor = "val_acc", #current working version
  423. # ~ monitor = "val_loss", #original ##################################################
  424. # ~ monitor = "val_jaccardIndex",
  425. monitor = "jaccardIndex",
  426. # ~ monitor = "val_jaccardLoss",
  427. save_best_only = True,
  428. mode = "max")
  429. # ~ mode = "min")
  430. earlyStopper = callbacks.EarlyStopping( \
  431. monitor="val_jaccardIndex", \
  432. patience = 5, \
  433. mode = "max", \
  434. #not sure about resotre weights...
  435. restore_best_weights = True)
  436. callbacks_list = [earlyStopper, checkpointer]
  437. myHistory = standardUnetLol.fit(
  438. x = trainImages,
  439. y = trainTruth,
  440. epochs = GLOBAL_EPOCHS,
  441. batch_size = GLOBAL_BATCH_SIZE,
  442. callbacks = callbacks_list,
  443. validation_split = 0.33333)
  444. return standardUnetLol, myHistory
  445. def createStandardUnet():
  446. input_size=(GLOBAL_HACK_height, GLOBAL_HACK_width, IMAGE_CHANNELS)
  447. inputs = Input(input_size)
  448. conv5, conv4, conv3, conv2, conv1 = encode(inputs)
  449. output = decode(conv5, conv4, conv3, conv2, conv1)
  450. model = Model(inputs, output)
  451. # ~ autoinit test. Uncomment to add the autoinit thingy
  452. # ~ model = AutoInit().initialize_model(model)
  453. # ~ model.compile(optimizer = Adam(learning_rate=1e-4), loss='categorical_crossentropy', metrics=["acc"])
  454. model.compile(
  455. # ~ optimizer = "adam",
  456. optimizer = Adam(),
  457. loss = "binary_crossentropy",
  458. # ~ loss = jaccardLoss,
  459. metrics = ["acc", jaccardIndex, diceIndex])
  460. return model
  461. #dropout increase in middle to reduce runtime in addition to dropping out stuff.
  462. def encode(inputs):
  463. sfilter = GLOBAL_INITIAL_FILTERS
  464. conv1 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(inputs)
  465. conv1 = Dropout(0.1)(conv1)
  466. conv1 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(conv1)
  467. pool1 = MaxPooling2D((2, 2))(conv1)
  468. conv2 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(pool1)
  469. conv2 = Dropout(0.1)(conv2)
  470. conv2 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(conv2)
  471. pool2 = MaxPooling2D((2, 2))(conv2)
  472. conv3 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(pool2)
  473. conv3 = Dropout(0.2)(conv3)
  474. conv3 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(conv3)
  475. pool3 = MaxPooling2D((2, 2))(conv3)
  476. conv4 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(pool3)
  477. conv4 = Dropout(0.2)(conv4)
  478. conv4 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(conv4)
  479. pool4 = MaxPooling2D((2, 2))(conv4)
  480. conv5 = Conv2D(sfilter * 16, (3, 3), activation = 'relu', padding = "same")(pool4)
  481. conv5 = Dropout(0.3)(conv5)
  482. conv5 = Conv2D(sfilter * 16, (3, 3), activation = 'relu', padding = "same")(conv5)
  483. return conv5, conv4, conv3, conv2, conv1
  484. def decode(conv5, conv4, conv3, conv2, conv1):
  485. sfilter = GLOBAL_INITIAL_FILTERS
  486. up6 = Conv2DTranspose(sfilter * 8, (2, 2), strides = (2, 2), padding = "same")(conv5)
  487. concat6 = Concatenate()([conv4,up6])
  488. conv6 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(concat6)
  489. conv6 = Dropout(0.2)(conv6)
  490. conv6 = Conv2D(sfilter * 8, (3, 3), activation = 'relu', padding = "same")(conv6)
  491. up7 = Conv2DTranspose(sfilter * 4, (2, 2), strides = (2, 2), padding = "same")(conv6)
  492. concat7 = Concatenate()([conv3,up7])
  493. conv7 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(concat7)
  494. conv7 = Dropout(0.2)(conv7)
  495. conv7 = Conv2D(sfilter * 4, (3, 3), activation = 'relu', padding = "same")(conv7)
  496. up8 = Conv2DTranspose(sfilter * 2, (2, 2), strides = (2, 2), padding = "same")(conv7)
  497. concat8 = Concatenate()([conv2,up8])
  498. conv8 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(concat8)
  499. conv8 = Dropout(0.1)(conv8)
  500. conv8 = Conv2D(sfilter * 2, (3, 3), activation = 'relu', padding = "same")(conv8)
  501. up9 = Conv2DTranspose(sfilter, (2, 2), strides = (2, 2), padding = "same")(conv8)
  502. concat9 = Concatenate()([conv1,up9])
  503. conv9 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(concat9)
  504. conv9 = Dropout(0.1)(conv9)
  505. conv9 = Conv2D(sfilter, (3, 3), activation = 'relu', padding = "same")(conv9)
  506. conv10 = Conv2D(1, (1, 1), padding = "same", activation = "sigmoid")(conv9)
  507. return conv10
  508. def saveExperimentImages(trainImages, trainTruth, testImages, testTruths, tmpFolder):
  509. if not os.path.exists(tmpFolder):
  510. print("Making a tmp folder...")
  511. os.system("mkdir tmp")
  512. print("Done!")
  513. np.save(tmpFolder + "train-images-object", trainImages)
  514. np.save(tmpFolder + "train-truth-object", trainTruth)
  515. np.save(tmpFolder + "test-images-object", testImages)
  516. np.save(tmpFolder + "test-truth-object", testTruths)
  517. #gets images from file, manipulates, returns
  518. #currently hardcoded to use 2017 as training data, and 2016 as testing
  519. #data because their names are regular!
  520. def createTrainAndTestSets():
  521. trainImageFileNames, trainTruthFileNames, \
  522. testImageFileNames, testTruthFileNames = getFileNames()
  523. trainImages, trainTruth, _, _ = getImageAndTruth(trainImageFileNames, trainTruthFileNames)
  524. trainTruth = convertImagesToGrayscale(trainTruth)
  525. testImage, testTruth, wholeOriginals, wholeTruths = getImageAndTruth(testImageFileNames, testTruthFileNames)
  526. testTruth = convertImagesToGrayscale(testTruth)
  527. wholeTruths = convertImagesToGrayscaleList(wholeTruths)
  528. #invert the imported images. Tensorflow counts white as truth
  529. #and black as false. I had been doing the inverse previously.
  530. trainImages = invertImagesInArray(trainImages)
  531. trainTruth = invertImagesInArray(trainTruth)
  532. testImage = invertImagesInArray(testImage)
  533. testTruth = invertImagesInArray(testTruth)
  534. wholeOriginals = invertImagesInArray(wholeOriginals)
  535. wholeTruths = invertImagesInArray(wholeTruths)
  536. return trainImages, trainTruth, testImage, testTruth, wholeOriginals, wholeTruths
  537. #Inverts all the images in an array. returns an array.
  538. def invertImagesInArray(imgArray):
  539. for i in range(len(imgArray)):
  540. imgArray[i] = invert(imgArray[i])
  541. return imgArray
  542. #This function gets the source image, cuts it into smaller squares, then
  543. #adds each square to an array for output. The original image squares
  544. #will correspond to the base truth squares.
  545. #Try using a method from here to avoid using lists on the arrays:
  546. #https://stackoverflow.com/questions/50226821/how-to-extend-numpy-arrray
  547. #Also returns a copy of the original uncut images as lists.
  548. def getImageAndTruth(originalFilenames, truthFilenames):
  549. outOriginals, outTruths = [], []
  550. wholeOriginals = []
  551. wholeTruths = []
  552. print("Importing " + originalFilenames[0] + " and friends...")
  553. for i in tqdm(range(len(originalFilenames))):
  554. # ~ print("\rImporting " + originalFilenames[i] + "...", end = "")
  555. myOriginal = imread(originalFilenames[i])[:, :, :3] #this is pretty arcane. research later
  556. myTruth = imread(truthFilenames[i])[:, :, :3] #this is pretty arcane. research later
  557. #save original images as list for returning to main
  558. thisOriginal = myOriginal ##Test before removing these temp vals.
  559. thisTruth = myTruth
  560. wholeOriginals.append(np.asarray(thisOriginal))
  561. wholeTruths.append(np.asarray(thisTruth))
  562. #Now make the cuts and save the results to a list. Then later convert list to array.
  563. originalCuts = cutImageIntoSmallSquares(myOriginal)
  564. truthCuts = cutImageIntoSmallSquares(myTruth)
  565. #for loop to add cuts to out lists, or I think I remember a one liner to do it?
  566. #yes-- list has the .extend() function. it adds the elements of a list to another list.
  567. outOriginals.extend(originalCuts)
  568. outTruths.extend(truthCuts)
  569. #can move to return line later maybe.
  570. outOriginals, outTruths = np.asarray(outOriginals), np.asarray(outTruths)
  571. return outOriginals, outTruths, wholeOriginals, wholeTruths
  572. #Cut an image into smaller squares, returns them as a list.
  573. #inspiration from:
  574. #https://stackoverflow.com/questions/5953373/how-to-split-image-into-multiple-pieces-in-python#7051075
  575. #Change to using numpy methods later for much speed-up?:
  576. #https://towardsdatascience.com/efficiently-splitting-an-image-into-tiles-in-python-using-numpy-d1bf0dd7b6f7?gi=2faa21fa5964
  577. #The input is in scikit-image format. It is converted to pillow to crop
  578. #more easily and for saving???. Then converted back for the output list.
  579. #Whitespace is appended to the right and bottom of the image so that the crop will include everything.
  580. def cutImageIntoSmallSquares(skImage):
  581. skOutList = []
  582. myImage = Image.fromarray(skImage)
  583. imageWidth, imageHeight = myImage.size
  584. tmpW = ((imageWidth // HACK_SIZE) + 1) * HACK_SIZE
  585. tmpH = ((imageHeight // HACK_SIZE) + 1) * HACK_SIZE
  586. #Make this next line (0,0,0) once you switch the words to white and background to black.........##############################################################################
  587. tmpImg = Image.new(myImage.mode, (tmpW, tmpH), (255, 255, 255))
  588. wHehe, hHehe = myImage.size
  589. heheHack = (0, 0, wHehe, hHehe)
  590. # ~ tmpImg.paste(myImage, myImage.getbbox())
  591. if IS_GLOBAL_PRINTING_ON:
  592. print("tmpImg.mode: " + str(tmpImg.mode))
  593. print("tmpImg.getbbox(): " + str(tmpImg.getbbox()))
  594. print("tmpImg.size: " + str(tmpImg.size))
  595. print("myImage.mode: " + str(myImage.mode))
  596. print("myImage.getbbox(): " + str(myImage.getbbox()))
  597. print("myImage width, height: " + "(" + str(imageWidth) + "," + str(imageHeight) + ")")
  598. print("myImage.size: " + str(myImage.size))
  599. print("heheHack: " + str(heheHack))
  600. tmpImg.paste(myImage, heheHack)
  601. myImage = tmpImg
  602. # ~ tmp2 = np.asarray(myImage)
  603. # ~ imshow(tmp2)
  604. # ~ plt.show()
  605. for upper in range(0, imageHeight, HACK_SIZE):
  606. lower = upper + HACK_SIZE
  607. for left in range(0, imageWidth, HACK_SIZE):
  608. right = left + HACK_SIZE
  609. cropBounds = (left, upper, right, lower)
  610. cropped = myImage.crop(cropBounds)
  611. cropped = np.asarray(cropped)
  612. skOutList.append(cropped)
  613. # ~ imshow(cropped / 255)
  614. # ~ plt.show()
  615. return skOutList
  616. #This function cuts a large input image into little squares, uses the
  617. #trained model to predict the binarization of each, then stitches each
  618. #image back into a whole for output.
  619. def predictWholeImage(inputImage, theModel, squareSize):
  620. if IS_GLOBAL_PRINTING_ON:
  621. print("squareSize: " + str(squareSize))
  622. ##get dimensions of the image
  623. height, width, _ = inputImage.shape
  624. ##get the number of squares per row of the image
  625. squaresWide = (width // squareSize) + 1
  626. widthPlusRightBuffer = squaresWide * squareSize
  627. squaresHigh = (height // squareSize) + 1
  628. heightPlusBottomBumper = squaresHigh * squareSize
  629. #Dice the image into bits
  630. if IS_GLOBAL_PRINTING_ON:
  631. print("shape of input Image right before dicing: " + str(np.shape(inputImage)))
  632. # ~ print("input Image right before dicing as string: " + str(inputImage))
  633. dicedImage = cutImageIntoSmallSquares(inputImage)
  634. # ~ print("shape of dicedImage right before hacking: " + str(np.shape(dicedImage)))
  635. # ~ #put output into list with extend then np.asaray the whole list to match elswhere.
  636. # ~ tmpList = []
  637. # ~ for i in range(len(dicedImage)):
  638. # ~ tmpList.extend(dicedImage[i])
  639. # ~ dicedImage = np.asarray(tmpList)
  640. ##Predict the outputs of each square
  641. dicedImage = np.asarray(dicedImage)
  642. # ~ print("shape of dicedImage right before predict: " + str(np.shape(dicedImage)))
  643. # ~ print("dicedImage right before predict as string: " + str(dicedImage))
  644. modelOut = theModel.predict(dicedImage)
  645. ##This is the code from main. I know it's bad now, but I'll keep it
  646. ##consistent until I create a helper function for it. ######################################################################################################
  647. binarizedOuts = ((modelOut > 0.5).astype(np.uint8) * 255).astype(np.uint8)
  648. #Stitch image using dimensions from above
  649. #combine each image row into numpy array
  650. theRowsList = []
  651. # ~ print("squaresHigh: " + str(squaresHigh))
  652. # ~ print("squaresWide: " + str(squaresWide))
  653. # ~ print("squareSize: " + str(squareSize))
  654. bigOut = np.zeros(shape = (squareSize * squaresHigh, squareSize * squaresWide, 1), dtype = np.uint8) #swap h and w?
  655. for i in range(squaresHigh):
  656. for j in range(squaresWide):
  657. # ~ print("i: " + str(i) + "\tj: " + str(j))
  658. # ~ print("sqHi: " + str(squaresHigh) + "\tsqWi: " + str(squaresWide))
  659. thisSquare = binarizedOuts[(i * squaresWide) + j] #w?
  660. iStart = i * squareSize
  661. iEnd = (i * squareSize) + squareSize
  662. jStart = j * squareSize
  663. jEnd = (j * squareSize) + squareSize
  664. bigOut[iStart : iEnd , jStart : jEnd ] = thisSquare
  665. # ~ combined = np.asarray(theRowsList)
  666. # ~ combined = combined.reshape((64,64,1))
  667. #Remove the extra padding from the edge of the image.
  668. outImage = bigOut[ :height, :width]
  669. # ~ outImage = bigOut
  670. return outImage
  671. def convertImagesToGrayscale(inputImages):
  672. outImage = []
  673. for image in inputImages:
  674. outImage.append( rgb2gray(image) )
  675. return np.asarray(outImage)
  676. #Returns a list instead of an np array
  677. def convertImagesToGrayscaleList(inputImages):
  678. outImage = []
  679. for image in inputImages:
  680. outImage.append( np.asarray(rgb2gray(image)) )
  681. return outImage
  682. #returns the filenames of the images for (trainImage, trainTruth),(testimage, testTruth)
  683. #hardcoded!
  684. #Test is currently hardcodded to 2016
  685. def getFileNames():
  686. trainTruthNamePairs = []
  687. trainImagePath = os.path.normpath("../DIBCO/2017/Dataset/")
  688. trainTruthPath = os.path.normpath("../DIBCO/2017/GT/")
  689. trainTruthNamePairs.append( (trainImagePath, trainTruthPath) )
  690. # ~ trainImageFileNames, trainTruthFileNames = \
  691. # ~ createTrainImageAndTrainTruthFileNames(trainImagePath, trainTruthPath)
  692. #need to handle non-bmp
  693. # ~ trainPath = "../DIBCO/2009/DIBC02009_Test_images-handwritten/"
  694. # ~ gtPath = "../DIBCO/2009/DIBCO2009-GT-Test-images_handwritten/"
  695. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  696. # ~ trainPath = "../DIBCO/2009/DIBCO2009_Test_images-printed/"
  697. # ~ gtPath = "../DIBCO/2009/DIBCO2009-GT-Test-images_printed/"
  698. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  699. #non-bmps
  700. # ~ trainPath = "../DIBCO/2010/DIBC02010_Test_images/"
  701. # ~ gtPath = "../DIBCO/2010/DIBC02010_Test_GT/"
  702. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  703. #2011 needs a special function to split the GTs with a wildcard or something.
  704. #2012 same
  705. # ~ trainPath = "../DIBCO/2013/OriginalImages/"
  706. # ~ gtPath = "../DIBCO/2013/GTimages/"
  707. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  708. # ~ trainPath = "../DIBCO/2014/original_images/"
  709. # ~ gtPath = "../DIBCO/2014/gt/"
  710. # ~ trainTruthNamePairs.append( (trainPath, gtPath) )
  711. trainImageFileNames = []
  712. trainTruthFileNames = []
  713. for pair in trainTruthNamePairs:
  714. tImPath, gtImPath = pair
  715. trainNames, gtNames = \
  716. createTrainImageAndTrainTruthFileNames(tImPath, gtImPath)
  717. trainImageFileNames.extend(trainNames)
  718. trainTruthFileNames.extend(gtNames)
  719. #test image section
  720. testImagePath = os.path.normpath("../DIBCO/2016/DIPCO2016_dataset/")
  721. testTruthPath = os.path.normpath("../DIBCO/2016/DIPCO2016_Dataset_GT/")
  722. testImageFileNames, testTruthFileNames = \
  723. createTrainImageAndTrainTruthFileNames(testImagePath, testTruthPath)
  724. return trainImageFileNames, trainTruthFileNames, \
  725. testImageFileNames, testTruthFileNames
  726. def createTrainImageAndTrainTruthFileNames(trainImagePath, trainTruthPath):
  727. trainImageFileNames = createTrainImageFileNamesList(trainImagePath)
  728. trainTruthFileNames = createTrainTruthFileNamesList(trainImageFileNames)
  729. trainImageFileNames = appendBMP(trainImageFileNames)
  730. # ~ print(trainImageFileNames)
  731. trainTruthFileNames = appendBMP(trainTruthFileNames)
  732. # ~ print(trainTruthFileNames)
  733. trainImageFileNames = prependPath(trainImagePath, trainImageFileNames)
  734. trainTruthFileNames = prependPath(trainTruthPath, trainTruthFileNames)
  735. return trainImageFileNames, trainTruthFileNames
  736. def createTrainImageFileNamesList(trainImagePath):
  737. # ~ trainFileNames = next(os.walk(trainImagePath))[2] #this is a clever hack
  738. # ~ trainFileNames = [name.replace(".bmp", "") for name in trainFileNames]
  739. trainFileNames = os.listdir(trainImagePath)
  740. print(trainFileNames)
  741. # ~ print("pausing...")
  742. # ~ a = input()
  743. return [name.replace(".bmp", "") for name in trainFileNames]
  744. #This makes a list with the same order of the names but with _gt apended.
  745. def createTrainTruthFileNamesList(originalNames):
  746. return [name + "_gt" for name in originalNames]
  747. def appendBMP(inputList):
  748. return [name + ".bmp" for name in inputList]
  749. def prependPath(myPath, nameList):
  750. return [os.path.join(myPath, name) for name in nameList]
  751. #I'm copying the code for jaccard similarity and dice from this MIT licenced source.
  752. #https://github.com/masyagin1998/robin
  753. #jaccard is size intersection of the sets / size union of the sets
  754. #Also, I'm going to try the smoothing values suggested in robin and here:
  755. #https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
  756. #They also suggest abs()
  757. def jaccardIndex(truth, prediction):
  758. #they are tensors?! not images?!?!?!?!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
  759. # ~ truth = img_as_bool(np.asarray(truth)) #fail
  760. # ~ prediction = img_as_bool(np.asarray(prediction)) #fail
  761. smooth = GLOBAL_SMOOTH_JACCARD
  762. predictionFlat = backend.flatten(prediction)
  763. truthFlat = backend.flatten(truth)
  764. # ~ for i in range(len(predictionFlat)):
  765. # ~ if predictionFlat[i] >= 2:
  766. # ~ print("This was the bug. images have non- binary values...#############################################################################################################################################################")
  767. # ~ print("predictionFlat[" + str(i) + "]: " + str(predictionFlat[i]))
  768. # ~ intersectionImg = predictionFlat * truthFlat
  769. numberPixelsSame = backend.sum(truthFlat * predictionFlat)
  770. #I've found the function tensorflow.reduce_sum() which performs a sum by reduction
  771. #Is it better than backend.sum?? ##################################################################
  772. #the docs say it is equivalent except that numpy will change everything to int64
  773. return float((numberPixelsSame + smooth) / \
  774. ( \
  775. (backend.sum(predictionFlat) + backend.sum(truthFlat) - numberPixelsSame + smooth) \
  776. ))
  777. #loss function for use in training.
  778. def jaccardLoss(truth, prediction):
  779. return 1.0 - jaccardIndex(truth, prediction)
  780. #input must be binarized images consisting of values for pixels of either 1 or 0.
  781. def diceIndex(truth, prediction):
  782. smooth = GLOBAL_SMOOTH_DICE
  783. predictionFlat = backend.flatten(prediction)
  784. truthFlat = backend.flatten(truth)
  785. numberSamePixels = backend.sum(predictionFlat * truthFlat)
  786. return float((2 * numberSamePixels + smooth) \
  787. / (backend.sum(predictionFlat) + backend.sum(truthFlat) + smooth))
  788. #Loss function for use in training
  789. def diceLoss(truth, prediction):
  790. smooth = GLOBAL_SMOOTH_DICE
  791. return smooth - diceIndex(truth, prediction)
  792. #This creates a true positive mask from an inverted image (white pixels are text)
  793. def createMaskTruePositive(truth, prediction):
  794. pFlat = backend.flatten(prediction)
  795. pFlat = img_as_bool(pFlat)
  796. tFlat = backend.flatten(truth)
  797. tFlat = img_as_bool(tFlat)
  798. mask = pFlat * tFlat
  799. return np.reshape(mask, np.shape(prediction))
  800. #creaes true negative mask
  801. def createMaskTrueNegative(truth, prediction):
  802. pFlat = backend.flatten(prediction)
  803. pFlat = img_as_bool(pFlat)
  804. tFlat = backend.flatten(truth)
  805. tFlat = img_as_bool(tFlat)
  806. ##invert to make the 0 into ones
  807. pFlat = ~pFlat
  808. tFlat = ~tFlat
  809. ##then multiply
  810. mask = pFlat * tFlat
  811. return np.reshape(mask, np.shape(prediction))
  812. #Creates a mask for the False Positives
  813. def createMaskFalsePositive(truth, prediction):
  814. pFlat = backend.flatten(prediction)
  815. pFlat = img_as_bool(pFlat)
  816. tFlat = backend.flatten(truth)
  817. tFlat = img_as_bool(tFlat)
  818. #will I need these?
  819. # ~ falseArray = np.zeros(len(pFlat), dtype = bool)
  820. # ~ trueArray = np.ones(len(pFlat), dtype = bool)
  821. ##where is the prediction true 1, where the truth is 0 false?
  822. mask = np.where(pFlat > tFlat, True, False)
  823. # ~ mask = np.where(pFlat > tFlat, pFlat, ~pFlat)
  824. return np.reshape(mask, np.shape(prediction))
  825. #returns a mask of all the pixels that are not supposed to be false.
  826. def createMaskFalseNegative(truth, prediction):
  827. # ~ return createMaskFalsePositive(prediction, truth) #Just swapped the input!?? yes but bug happens. 3-1
  828. pFlat = backend.flatten(prediction)
  829. pFlat = img_as_bool(pFlat)
  830. tFlat = backend.flatten(truth)
  831. tFlat = img_as_bool(tFlat)
  832. mask = np.where(pFlat < tFlat, True, False)
  833. return np.reshape(mask, np.shape(prediction))
  834. #Color the prediction image with the pixels that are correct in red.
  835. def colorPredictionWithPredictionMask(predictionMask, originalPrediction, colorArray):
  836. prediction = img_as_bool(originalPrediction)
  837. prediction = np.where( predictionMask >= prediction, True, False ) #This makes the area to paint to white.
  838. prediction = img_as_float(prediction)
  839. predictionMask = np.squeeze(predictionMask, axis = 2)
  840. if IS_GLOBAL_PRINTING_ON:
  841. print("predictionMask shape: " + str(predictionMask.shape))
  842. rows, cols = predictionMask.shape
  843. colorMask = np.zeros((rows, cols, 3))
  844. colorMask[ predictionMask ] = colorArray
  845. predictionInColor = np.dstack((prediction, prediction, prediction))
  846. predictionColor_hsv = rgb2hsv(predictionInColor)
  847. colorMask_hsv = rgb2hsv(colorMask)
  848. alpha = 1.0
  849. predictionColor_hsv[..., 0] = colorMask_hsv[..., 0]
  850. predictionColor_hsv[..., 1] = colorMask_hsv[..., 1] * alpha
  851. outImg = hsv2rgb(predictionColor_hsv)
  852. return img_as_ubyte(outImg)
  853. def combinePredictionPicture(truePosMask, trueNegMask, falsePosMask, falseNegMask):
  854. redColor = [1, 0, 0]
  855. greenColor = [0, 1, 0]
  856. blueColor = [0, 0, 1]
  857. yellowColor = [1, 1, 0]
  858. truePosMask = np.squeeze(truePosMask, axis = 2)
  859. trueNegMask = np.squeeze(trueNegMask, axis = 2)
  860. falsePosMask = np.squeeze(falsePosMask, axis = 2)
  861. falseNegMask = np.squeeze(falseNegMask, axis = 2)
  862. ##make a numpy array of ONES, reshape to image size, then convert with imgtofloat
  863. # ~ print("truePosMask.shape: " + str(truePosMask.shape))
  864. rows, cols = truePosMask.shape
  865. # ~ prediction = np.ones( (rows, cols), dtype=bool )
  866. prediction = np.zeros( (rows, cols), dtype=bool )
  867. # ~ prediction = img_as_float(prediction)
  868. predictionRGB = np.dstack((prediction, prediction, prediction))
  869. # ~ predictionColor_hsv = rgb2hsv(predictionRGB)
  870. ## make the four color masks
  871. redColorMask = img_as_bool(falseNegMask)
  872. blueColorMask = img_as_bool(falsePosMask)
  873. greenColorMask = img_as_bool(truePosMask)
  874. yellowColorMask = img_as_bool(trueNegMask)
  875. predictionRGB[redColorMask, 0 ] = 1
  876. predictionRGB[greenColorMask, 1 ] = 1
  877. predictionRGB[blueColorMask, 2 ] = 1
  878. predictionRGB[yellowColorMask, 0 ] = 1
  879. predictionRGB[yellowColorMask, 1 ] = 1
  880. # ~ green_image = rgb_image.copy() # Make a copy
  881. # ~ green_image[:,:,0] = 0
  882. # ~ green_image[:,:,2] = 0
  883. # ~ alpha = 1.0
  884. # ~ predictionColor_hsv[..., 0] = redColorMask[..., 0]
  885. # ~ predictionColor_hsv[..., 1] = redColorMask[..., 1] * alpha
  886. # ~ predictionColor_hsv[..., 0] = blueColorMask[..., 0]
  887. # ~ predictionColor_hsv[..., 1] = blueColorMask[..., 1] * alpha
  888. # ~ predictionColor_hsv[..., 0] = greenColorMask[..., 0]
  889. # ~ predictionColor_hsv[..., 1] = greenColorMask[..., 1] * alpha
  890. # ~ predictionColor_hsv[..., 0] = yellowColorMask[..., 0]
  891. # ~ predictionColor_hsv[..., 1] = yellowColorMask[..., 1] * alpha
  892. # ~ outImg = hsv2rgb(predictionColor_hsv)
  893. # ~ return img_as_ubyte(outImg)
  894. return img_as_ubyte(predictionRGB)
  895. def makeThisColorMaskHsv(predictionMask, colorArray, rows, cols):
  896. thisColorMask = np.zeros((rows, cols, 3))
  897. thisColorMask[ predictionMask ] = colorArray
  898. return rgb2hsv(thisColorMask)
  899. def makeThisColorMaskRGB(predictionMask, colorArray, rows, cols):
  900. thisColorMask = np.zeros((rows, cols, 3))
  901. thisColorMask[ predictionMask ] = colorArray
  902. return thisColorMask
  903. if __name__ == '__main__':
  904. import sys
  905. sys.exit(main(sys.argv))