
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# mnist-transfer-learning.py
#
# 2022 Stephen Stengel <stephen.stengel@cwu.edu>

#A test program to see how transfer learning works.
#Tutorial sources:
# ~ https://keras.io/examples/vision/image_classification_from_scratch/
# ~ https://keras.io/guides/transfer_learning/
# ~ https://towardsdatascience.com/transfer-learning-using-pre-trained-alexnet-model-and-fashion-mnist-43898c2966fb?gi=1f9cc1728578
# ~ https://www.kaggle.com/code/muerbingsha/mnist-vgg19/notebook

## ! Training on mnist and then using those weights in a fashion_mnist
## ! model would be a way easier tutorial. Or cifar10.
print("Running imports...")
import os
import sys
import time
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.datasets import mnist #images of digits
from keras.datasets import fashion_mnist #images of clothes
# ~ from keras.datasets import cifar10 #small images
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.transform import resize
from skimage.color import gray2rgb
from sklearn.model_selection import train_test_split
print("Done!")
IS_GLOBAL_PRINTING_ON = False
HELPFILE_PATH = "helpfile"
GLOBAL_DEFAULT_SLICE = 2000
GLOBAL_DEFAULT_EPOCHS = 100
def main(args):
    print("Hello!")
    #If I checked args before the imports, the help text would print faster.
    #Just save the returns to their globals.
    sliceNum, epochsMnist = checkArgs(args)
    if IS_GLOBAL_PRINTING_ON:
        print("global printing true")
        print("sliceNum: %d" % sliceNum)
        print("epochsMnist: %d" % epochsMnist)
    dFolder = "./digits/"
    os.makedirs(dFolder, exist_ok=True)
    fFolder = "./fashion/"
    os.makedirs(fFolder, exist_ok=True)
    # ~ preamble()
    # ~ epochsCat = 10
    # ~ xceptCatDog(epochsCat)
    xceptionOnMnistExample(epochsMnist, mnist, dFolder, sliceNum)
    xceptionOnMnistExample(epochsMnist, fashion_mnist, fFolder, sliceNum)
    # ~ xceptionOnMnistExample(epochsMnist, cifar10, cFolder, sliceNum)
    return 0
#Check if the user wants help, then parse the slice and epoch counts.
def checkArgs(args):
    helpList = ["help", "-help", "--help", "-h", "--h", "wtf", "-wtf", "--wtf"]
    argLen = len(args)
    if argLen >= 1:
        for a in args:
            if str(a).lower() in helpList:
                printFile(HELPFILE_PATH)
                sys.exit(0)
    if argLen == 1:
        return GLOBAL_DEFAULT_SLICE, GLOBAL_DEFAULT_EPOCHS
    if argLen == 2:
        print("bad input")
        printFile(HELPFILE_PATH)
        sys.exit(-1)
    if argLen == 3:
        theSlice = int(args[1])
        theEpochs = int(args[2])
        return theSlice, theEpochs
    if argLen == 4:
        theSlice = int(args[1])
        theEpochs = int(args[2])
        global IS_GLOBAL_PRINTING_ON
        IS_GLOBAL_PRINTING_ON = True
        return theSlice, theEpochs
    if argLen > 4:
        print("bad input")
        printFile(HELPFILE_PATH)
        sys.exit(-1)
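#Example invocations, inferred from the argument checks above (the real
#usage text lives in the helpfile, which is not shown here):
#  python3 mnist-transfer-learning.py            -> default slice and epochs
#  python3 mnist-transfer-learning.py 5000 30    -> slice 5000, 30 epochs
#  python3 mnist-transfer-learning.py 5000 30 x  -> same, plus verbose printing
#  (any third argument turns IS_GLOBAL_PRINTING_ON on)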
#Prints a text file to screen.
def printFile(myFilePath):
    with open(myFilePath, "r") as helpfile:
        for line in helpfile:
            print(line, end="")
#Example of using the Xception network with ImageNet weights to do
#transfer learning on MNIST. I cut the classes in MNIST down to just two
#to more closely match the tutorial I'm following, and to match the
#problem that we will be solving for the project. HEHE
def xceptionOnMnistExample(epochsMnist, myDataset, tmpFolder, sliceNum):
    print("Creating datasets...")
    train_ds, validation_ds, test_ds = readDataset(myDataset, sliceNum)
    print("Done!")
    if IS_GLOBAL_PRINTING_ON:
        print("Showing some images from the dataset...")
        printSomeOfDataset(train_ds)
        printSomeOfDataset(validation_ds)
        printSomeOfDataset(test_ds)
        print("Done!")
    print("Making the model...")
    model = makeTheModel()
    print("Done!")
    print("Training top layer...")
    epochs = epochsMnist
    if IS_GLOBAL_PRINTING_ON:
        print("train_ds: " + str(train_ds))
    myHistory = model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
    print("Done!")
    predictStuff(model, test_ds)
    performEvaluation(myHistory, tmpFolder, model, test_ds)
    print("HORAAAAAYYYYY")
def predictStuff(model, the_ds):
    predictions = model.predict(the_ds)
    score = predictions[0]
    if IS_GLOBAL_PRINTING_ON:
        print("predictions: " + str(predictions))
        print("score: " + str(score))
    #The model outputs a raw logit (Dense(1) with no activation, trained
    #with from_logits=True), so squash it through a sigmoid before
    #reporting it as a percentage.
    probability = float(tf.math.sigmoid(score))
    print(
        "This image is %.2f percent one class and %.2f percent the other."
        % (100 * (1 - probability), 100 * probability)
    )
def makeTheModel():
    base_model = keras.applications.Xception(
        weights="imagenet", # Load weights pre-trained on ImageNet.
        input_shape=(150, 150, 3),
        include_top=False,
    ) # Do not include the ImageNet classifier at the top.
    # Freeze the base_model
    base_model.trainable = False
    # Create new layers on top of the old model
    inputs = keras.Input(shape=(150, 150, 3))
    x = layers.experimental.preprocessing.RandomFlip("horizontal")(inputs)
    # Pre-trained Xception weights require that the input be scaled
    # from (0, 255) to a range of (-1., +1.); the rescaling layer
    # outputs: `(inputs * scale) + offset`
    scale_layer = keras.layers.experimental.preprocessing.Rescaling(scale=1 / 127.5, offset=-1)
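    # Sanity check on the constants (my note): 0 * (1/127.5) - 1 = -1.0 and
    # 255 * (1/127.5) - 1 = +1.0, so the layer maps [0, 255] onto [-1, +1].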
    x = scale_layer(x)
    # The base model contains batchnorm layers. We want to keep them in inference mode
    # when we unfreeze the base model for fine-tuning, so we make sure that the
    # base_model is running in inference mode here.
    x = base_model(x, training=False)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(0.2)(x) # Regularize with dropout
    outputs = keras.layers.Dense(1)(x)
    model = keras.Model(inputs, outputs)
    model.summary()
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=[keras.metrics.BinaryAccuracy()],
    )
    return model
#Prints a few images from a loaded dataset.
def printSomeOfDataset(myDataset):
    print("Looking at a few of the images...")
    plt.figure(figsize=(10, 10))
    for images, labels in myDataset.take(1):
        for i in tqdm(range(9)):
            ax = plt.subplot(3, 3, i + 1)
            plt.imshow(images[i].numpy().astype("float32"))
            plt.title(int(labels[i]))
            plt.axis("off")
    plt.show()
#This reads the specified dataset into memory.
#Try: mnist, fashion_mnist, cifar10
def readDataset(myDataset, sliceNum):
    print("Start reading the data ...")
    # Get the start time
    StartTime = time.time()
    (x_train, y_train), (x_test, y_test) = myDataset.load_data()
    #I remember shuffling pairs of numpy arrays once using a random order list and zip() or something like that.
    #Can't shuffle here because the arrays are nonwriteable. (Shuffling x and y
    #independently, as below, would also break the image/label pairing.)
    # ~ rng = np.random.default_rng()
    # ~ rng.shuffle(x_train)
    # ~ rng.shuffle(y_train)
    # ~ rng.shuffle(x_test)
    # ~ rng.shuffle(y_test)
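    #A sketch of an in-unison shuffle (my assumption about the intent, not
    #part of the original pipeline): indexing both arrays with one shared
    #permutation keeps image/label pairs aligned and yields writeable copies.
    # ~ perm = np.random.default_rng().permutation(len(x_train))
    # ~ x_train, y_train = x_train[perm], y_train[perm]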
    #The full set is too much memory with these LISTS coming up below.
    #Currently a slice of 30000 is a bit under 15GB.
    cutNum = sliceNum
    x_train = x_train[:cutNum]
    y_train = y_train[:cutNum]
    x_test = x_test[:cutNum]
    y_test = y_test[:cutNum]
    #######
    #Make a validation set from the training data.
    splitDecimal = 0.8
    x_train, x_val = valTrainSplit(x_train, splitDecimal)
    y_train, y_val = valTrainSplit(y_train, splitDecimal)
    if IS_GLOBAL_PRINTING_ON:
        print("SHAPES:...")
        print(x_train.shape, y_train.shape)
        print(x_test.shape, y_test.shape)
        print(x_val.shape, y_val.shape)
    #I just picked these two at random. (The names are fashion_mnist labels;
    #for mnist they are simply the digits 9 and 0.)
    ANKLE_BOOT = 9
    T_SHIRT = 0
    firstClass = ANKLE_BOOT
    secondClass = T_SHIRT
    #This is very slow because I put the images in lists before converting to numpy arrays.
    #TODO: remember how to do that numpy extend thing. (See the masked sketch after keepTwoClasses below.)
    size = (150, 150)
    x_train, y_train = keepTwoClasses(x_train, y_train, firstClass, secondClass, size)
    x_test, y_test = keepTwoClasses(x_test, y_test, firstClass, secondClass, size)
    x_val, y_val = keepTwoClasses(x_val, y_val, firstClass, secondClass, size)
    if IS_GLOBAL_PRINTING_ON:
        print("SHAPES after selection of two classes...")
        print(x_train.shape, y_train.shape)
        print(x_test.shape, y_test.shape)
        print(x_val.shape, y_val.shape)
    ## ! NOTE THE NAME CHANGE ! ##
    # x_test, y_test becomes val_ds
    # x_val, y_val becomes realval_ds
    # I can change these names to be more consistent with the return order later.
    train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    val_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    realval_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    BATCH_SIZE = 64
    train_ds = train_ds.batch(BATCH_SIZE, drop_remainder=False)
    val_ds = val_ds.batch(BATCH_SIZE, drop_remainder=False)
    realval_ds = realval_ds.batch(BATCH_SIZE, drop_remainder=False)
    train_ds = train_ds.prefetch(buffer_size=32)
    val_ds = val_ds.prefetch(buffer_size=32)
    realval_ds = realval_ds.prefetch(buffer_size=32)
    #Get the end time
    EndTime = time.time()
    print("Elapsed time to read dataset into memory: ", EndTime - StartTime)
    print("WORKSOFAR !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    #realval_ds is the dataset for validation; val_ds is actually the test set. Change the names later.
    return train_ds, realval_ds, val_ds
#Keeps only two classes out of a total dataset.
#In need of optimization.
def keepTwoClasses(x_train, y_train, firstClass, secondClass, size):
    newxtrain = []
    newytrain = []
    for i in tqdm(range(len(x_train))):
        if y_train[i] == firstClass or y_train[i] == secondClass:
            # ~ print("ankle boot lol")
            #preserve_range=True keeps the 0-255 scale that the model's
            #Rescaling layer expects; without it, skimage's resize
            #converts uint8 input to floats in [0, 1].
            newxtrain.append(gray2rgb(resize(x_train[i], size, preserve_range=True)))
            newytrain.append(y_train[i])
    x_train = np.asarray(newxtrain)
    y_train = np.asarray(newytrain)
    return x_train, y_train
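#A vectorized sketch of the class-selection step (my addition, not part of
#the original pipeline, and untested here): boolean masking replaces the
#per-element appends, so only the resize/gray2rgb conversion still needs a
#Python-level loop. Behavior is otherwise meant to match keepTwoClasses.
def keepTwoClassesMasked(x_train, y_train, firstClass, secondClass, size):
    mask = (y_train == firstClass) | (y_train == secondClass)
    x_kept, y_kept = x_train[mask], y_train[mask]
    x_out = np.asarray([gray2rgb(resize(img, size, preserve_range=True)) for img in x_kept])
    return x_out, y_kept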
#Split a validation set from the training set.
#There is already a test set created by the mnist loading function.
def valTrainSplit(x_train, splitDecimal):
    sliceidx = int(len(x_train) * splitDecimal)
    x_val = x_train[sliceidx:]
    x_train = x_train[:sliceidx]
    return x_train, x_val
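#An equivalent split using sklearn's train_test_split (imported above but
#otherwise unused). A sketch for comparison, not wired into the pipeline:
#with shuffle=False it reproduces the plain slicing in valTrainSplit.
def valTrainSplitSklearn(x_train, splitDecimal):
    return train_test_split(x_train, train_size=splitDecimal, shuffle=False)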
#Example of using a pretrained network with weights, adding a bit to the
#model, and training on a new dataset. This example loads ImageNet
#weights into an Xception network, then adds a few layers, changes the
#output to binary, and runs against the cats_vs_dogs dataset.
def xceptCatDog(epochsCat):
    #Start by importing the dataset manually, because the first tutorial just doesn't compile.
    #Using this other tutorial: https://keras.io/examples/vision/image_classification_from_scratch/
    # ~ imagesTopFolderName = "PetImages"
    imagesTopFolderName = "shorter-pet-images"
    print("Cleaning out images that the tutorial doesn't like.")
    num_skipped = 0
    for folder_name in ("Cat", "Dog"):
        folder_path = os.path.join(imagesTopFolderName, folder_name)
        for fname in tqdm(os.listdir(folder_path)):
            fpath = os.path.join(folder_path, fname)
            try:
                fobj = open(fpath, "rb")
                is_jfif = tf.compat.as_bytes("JFIF") in fobj.peek(10)
            finally:
                fobj.close()
            if not is_jfif:
                num_skipped += 1
                # Delete the corrupted image
                os.remove(fpath)
    print("Deleted %d images" % num_skipped)
    print("Creating datasets...")
    image_size = (180, 180)
    batch_size = 32
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        imagesTopFolderName,
        validation_split=0.2,
        subset="training",
        seed=1337,
        image_size=image_size,
        batch_size=batch_size,
    )
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        imagesTopFolderName,
        validation_split=0.2,
        subset="validation",
        seed=1337,
        image_size=image_size,
        batch_size=batch_size,
    )
    print("Done!")
    print("train_ds: " + str(train_ds) + " ##################################################")
    # ~ print("Looking at a few of the images...")
    # ~ plt.figure(figsize=(10, 10))
    # ~ for images, labels in train_ds.take(1):
    # ~     for i in tqdm(range(9)):
    # ~         ax = plt.subplot(3, 3, i + 1)
    # ~         plt.imshow(images[i].numpy().astype("uint8"))
    # ~         plt.title(int(labels[i]))
    # ~         plt.axis("off")
    # ~ plt.show() #This line is needed for the pictures to actually show.
    # ~ print("Done!")
    print("Viewing and Augmenting data...")
    #Need to use the experimental tag because my conda installed tensorflow 2.4.
    # ~ data_augmentation = keras.Sequential(
    # ~     [
    # ~         layers.experimental.preprocessing.RandomFlip("horizontal"),
    # ~         layers.experimental.preprocessing.RandomRotation(0.1),
    # ~     ]
    # ~ )
    #Commented out to run faster.
    # ~ plt.figure(figsize=(10, 10))
    # ~ for images, _ in train_ds.take(1):
    # ~     for i in tqdm(range(9)):
    # ~         augmented_images = data_augmentation(images)
    # ~         ax = plt.subplot(3, 3, i + 1)
    # ~         plt.imshow(augmented_images[0].numpy().astype("uint8"))
    # ~         plt.axis("off")
    # ~ plt.show()
    print("Done!")
    ##### Back to the original tutorial ######
    print("Preprocessing...")
    # ~ train_ds = train_ds.prefetch(buffer_size=32)
    # ~ val_ds = val_ds.prefetch(buffer_size=32)
    size = (150, 150)
    print("train_ds before resize: " + str(train_ds))
    train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
    validation_ds = val_ds.map(lambda x, y: (tf.image.resize(x, size), y)) #RENAME val_ds
    ####### This makes the shapes wonky and causes a crash
    ####### (image_dataset_from_directory already batches, so batching again nests the batch dimension):
    # ~ batch_size = 32
    # ~ train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
    # ~ validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
    #This simpler way seems to work fine. (I haven't gotten past one epoch yet; I still need to pare down the data.)
    train_ds = train_ds.prefetch(buffer_size=32)
    validation_ds = validation_ds.prefetch(buffer_size=32)
    #test
    # ~ data_augmentation = keras.Sequential(
    # ~     [layers.experimental.preprocessing.RandomFlip("horizontal"), layers.experimental.preprocessing.RandomRotation(0.1),]
    # ~ )
    # ~ for images, labels in train_ds.take(1):
    # ~     plt.figure(figsize=(10, 10))
    # ~     first_image = images[0]
    # ~     for i in range(9):
    # ~         ax = plt.subplot(3, 3, i + 1)
    # ~         augmented_image = data_augmentation(
    # ~             tf.expand_dims(first_image, 0), training=True
    # ~         )
    # ~         plt.imshow(augmented_image[0].numpy().astype("int32"))
    # ~         plt.title(int(labels[0]))
    # ~         plt.axis("off")
    # ~ plt.show()
    print("Done!")
    print("Making the model...")
    base_model = keras.applications.Xception(
        weights="imagenet", # Load weights pre-trained on ImageNet.
        input_shape=(150, 150, 3),
        include_top=False,
    ) # Do not include the ImageNet classifier at the top.
    # Freeze the base_model
    base_model.trainable = False
    # Create new model on top
    inputs = keras.Input(shape=(150, 150, 3))
    # ~ x = data_augmentation(inputs) # Apply random data augmentation
    # ~ x = inputs
    x = layers.experimental.preprocessing.RandomFlip("horizontal")(inputs)
    # Pre-trained Xception weights require that the input be scaled
    # from (0, 255) to a range of (-1., +1.); the rescaling layer
    # outputs: `(inputs * scale) + offset`
    scale_layer = keras.layers.experimental.preprocessing.Rescaling(scale=1 / 127.5, offset=-1)
    x = scale_layer(x)
    # The base model contains batchnorm layers. We want to keep them in inference mode
    # when we unfreeze the base model for fine-tuning, so we make sure that the
    # base_model is running in inference mode here.
    x = base_model(x, training=False)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(0.2)(x) # Regularize with dropout
    outputs = keras.layers.Dense(1)(x)
    model = keras.Model(inputs, outputs)
    model.summary()
    print("Done!")
    #train
    print("Training top layer...")
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=[keras.metrics.BinaryAccuracy()],
    )
    # ~ epochs = 20
    epochs = epochsCat
    print("train_ds: " + str(train_ds))
    # ~ print("train_ds shape: " + str(train_ds.shape))
    model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
    print("Done!")
    image_size = (150, 150)
    img = keras.preprocessing.image.load_img(
        "PetImages/Cat/6779.jpg", target_size=image_size
    )
    img_array = keras.preprocessing.image.img_to_array(img)
    img_array = tf.expand_dims(img_array, 0) # Create batch axis
    predictions = model.predict(img_array)
    score = predictions[0]
    print("predictions: " + str(predictions))
    print("score: " + str(score))
    #The raw output is a logit, so convert it to a probability first.
    probability = float(tf.math.sigmoid(score))
    print(
        "This image is %.2f percent cat and %.2f percent dog."
        % (100 * (1 - probability), 100 * probability)
    )
#An Xception-style model from a tutorial -- UNUSED
def make_modelUNUSED(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    # Image augmentation block
    # ~ x = data_augmentation(inputs) #BUGGY
    #Convert to functional style or whatever.
    x = layers.experimental.preprocessing.RandomFlip("horizontal")(inputs)
    # ~ x = layers.experimental.preprocessing.RandomRotation(0.1)(x) #Buggy
    # Entry block
    x = layers.experimental.preprocessing.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    previous_block_activation = x # Set aside residual
    for size in [128, 256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual]) # Add back residual
        previous_block_activation = x # Set aside next residual
    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        activation = "sigmoid"
        units = 1
    else:
        activation = "softmax"
        units = num_classes
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(units, activation=activation)(x)
    return keras.Model(inputs, outputs)
#Script containing stuff from a tutorial.
def preamble():
    print("Hello lol!")
    layer = keras.layers.Dense(3)
    layer.build((None, 4)) # Create the weights
    print("weights:", len(layer.weights))
    print("trainable_weights:", len(layer.trainable_weights))
    print("non_trainable_weights:", len(layer.non_trainable_weights))
    layer = keras.layers.BatchNormalization()
    layer.build((None, 4)) # Create the weights
    print("weights:", len(layer.weights))
    print("trainable_weights:", len(layer.trainable_weights))
    print("non_trainable_weights:", len(layer.non_trainable_weights))
    layer = keras.layers.Dense(3)
    layer.build((None, 4)) # Create the weights
    layer.trainable = False # Freeze the layer
    print("weights:", len(layer.weights))
    print("trainable_weights:", len(layer.trainable_weights))
    print("non_trainable_weights:", len(layer.non_trainable_weights))
    print("\n\n\n\n\n#####################")
    # Make a model with 2 layers
    layer1 = keras.layers.Dense(3, activation="relu")
    layer2 = keras.layers.Dense(3, activation="sigmoid")
    model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
    # Freeze the first layer
    layer1.trainable = False
    # Keep a copy of the weights of layer1 for later reference
    initial_layer1_weights_values = layer1.get_weights()
    # Train the model
    model.compile(optimizer="adam", loss="mse")
    model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
    # Check that the weights of layer1 have not changed during training
    final_layer1_weights_values = layer1.get_weights()
    np.testing.assert_allclose(
        initial_layer1_weights_values[0], final_layer1_weights_values[0]
    )
    np.testing.assert_allclose(
        initial_layer1_weights_values[1], final_layer1_weights_values[1]
    )
    print("\n\n\n\n###############\n\n\n")
    print("Asserting that trainable status propagates recursively.")
    inner_model = keras.Sequential(
        [
            keras.Input(shape=(3,)),
            keras.layers.Dense(3, activation="relu"),
            keras.layers.Dense(3, activation="relu"),
        ]
    )
    model = keras.Sequential(
        [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
    )
    model.trainable = False # Freeze the outer model
    assert inner_model.trainable == False # All layers in `model` are now frozen
    assert inner_model.layers[0].trainable == False # `trainable` is propagated recursively
    print("\n\n\n\n#########################################################\n\n\n")
    print("Now for the actual example using pretrained weights.\n\n")
    base_model = keras.applications.Xception(
        weights='imagenet', # Load weights pre-trained on ImageNet.
        input_shape=(150, 150, 3),
        include_top=False) # Do not include the ImageNet classifier at the top.
    base_model.trainable = False
    #New model to go on top of base_model
    inputs = keras.Input(shape=(150, 150, 3))
    # We make sure that the base_model is running in inference mode here,
    # by passing `training=False`. This is important for fine-tuning, as you will
    # learn in a few paragraphs.
    x = base_model(inputs, training=False)
    # Convert features of shape `base_model.output_shape[1:]` to vectors
    x = keras.layers.GlobalAveragePooling2D()(x)
    # A Dense classifier with a single unit (binary classification)
    outputs = keras.layers.Dense(1)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=[keras.metrics.BinaryAccuracy()])
    #This part needs a new dataset loaded. I'll look at the rest of the tutorial first.
    # ~ model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
    print("\n\n\n\n#########################################################")
    print("Fine tuning...")
    # Unfreeze the base model
    base_model.trainable = True
    # It's important to recompile your model after you make any changes
    # to the `trainable` attribute of any inner layer, so that your changes
    # are taken into account
    model.compile(
        optimizer=keras.optimizers.Adam(1e-5), # Very low learning rate
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=[keras.metrics.BinaryAccuracy()])
    # Train end-to-end. Be careful to stop before you overfit!
    # ~ model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
    #BIG NOTE: Whenever you change `trainable` or other settings in the
    #model, you must re-compile it, otherwise it won't do much.
    print("\n\n\n\n#########################################################")
    print("If using a custom fit method instead of .fit()...")
    # Create base model
    base_model = keras.applications.Xception(
        weights='imagenet',
        input_shape=(150, 150, 3),
        include_top=False)
    # Freeze base model
    base_model.trainable = False
    # Create new model on top.
    inputs = keras.Input(shape=(150, 150, 3))
    x = base_model(inputs, training=False)
    x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = keras.layers.Dense(1)(x)
    model = keras.Model(inputs, outputs)
    loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
    optimizer = keras.optimizers.Adam()
    # Iterate over the batches of a dataset.
    # ~ for inputs, targets in new_dataset:
    # ~     # Open a GradientTape.
    # ~     with tf.GradientTape() as tape:
    # ~         # Forward pass.
    # ~         predictions = model(inputs)
    # ~         # Compute the loss value for this batch.
    # ~         loss_value = loss_fn(targets, predictions)
    # ~     # Get gradients of loss wrt the *trainable* weights.
    # ~     gradients = tape.gradient(loss_value, model.trainable_weights)
    # ~     # Update the weights of the model.
    # ~     optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    print("\n\n\n\n#########################################################")
#My evaluation grapher thingy. Saves graphs to file.
def performEvaluation(history, tmpFolder, model, test_ds):
    print("Performing evaluation...")
    scores = model.evaluate(test_ds)
    if IS_GLOBAL_PRINTING_ON:
        print("%s: %.2f%%" % (model.metrics_names[1], scores[1] * 100))
        print("history...")
        print(history)
        print("history.history...")
        print(history.history)
    accuracy = history.history["binary_accuracy"]
    val_accuracy = history.history["val_binary_accuracy"]
    loss = history.history["loss"]
    val_loss = history.history["val_loss"]
    epochs = range(1, len(accuracy) + 1)
    plt.plot(epochs, accuracy, "o", label="Training accuracy")
    plt.plot(epochs, val_accuracy, "^", label="Validation accuracy")
    plt.title("Training and validation accuracy")
    plt.legend()
    plt.savefig(tmpFolder + "trainvalacc.png")
    plt.clf()
    plt.plot(epochs, loss, "o", label="Training loss")
    plt.plot(epochs, val_loss, "^", label="Validation loss")
    plt.title("Training and validation loss")
    plt.legend()
    plt.savefig(tmpFolder + "trainvalloss.png")
    plt.clf()
if __name__ == '__main__':
    sys.exit(main(sys.argv))