# rezaunet.py (5.3 KB)
### This code was taken from Reza Azad's BCDU-Net implementation and lightly modified to fit my project.
  2. from keras.models import Model
  3. from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout
  4. ###from keras.optimizers import Adam ### Old way
  5. from tensorflow.keras.optimizers import Adam ## my update to new way
  6. from keras.callbacks import ModelCheckpoint, LearningRateScheduler
  7. from keras import backend as K
  8. from keras.utils.vis_utils import plot_model as plot
  9. ##from keras.optimizers import SGD ## old
  10. from tensorflow.keras.optimizers import SGD ## new
  11. from keras.optimizers import *
  12. from keras.layers import *
  13. import numpy as np ## my addition
# input_size is a tuple (height, width, channels)
  15. def BCDU_net_D3(input_size):
  16. N = input_size[0]
  17. inputs1 = Input(input_size) ## changed inputs to inputs1 here and everywhere else too
  18. sfilter = N / 4
  19. conv1 = Conv2D(sfilter, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs1) ## inputs1
  20. conv1 = Conv2D(sfilter, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
  21. pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
  22. conv2 = Conv2D(sfilter * 2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
  23. conv2 = Conv2D(sfilter * 2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
  24. pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
  25. conv3 = Conv2D(sfilter * 4, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
  26. conv3 = Conv2D(sfilter * 4, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
  27. drop3 = Dropout(0.5)(conv3)
  28. pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
  29. # D1
  30. conv4 = Conv2D(sfilter * 8, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
  31. conv4_1 = Conv2D(sfilter * 8, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
  32. drop4_1 = Dropout(0.5)(conv4_1)
  33. # D2
  34. conv4_2 = Conv2D(sfilter * 8, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(drop4_1)
  35. conv4_2 = Conv2D(sfilter * 8, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4_2)
  36. conv4_2 = Dropout(0.5)(conv4_2)
  37. # D3
  38. merge_dense = concatenate([conv4_2,drop4_1], axis = 3)
  39. conv4_3 = Conv2D(sfilter * 8, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge_dense)
  40. conv4_3 = Conv2D(sfilter * 8, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4_3)
  41. drop4_3 = Dropout(0.5)(conv4_3)
  42. up6 = Conv2DTranspose(sfilter * 4, kernel_size=2, strides=2, padding='same',kernel_initializer = 'he_normal')(drop4_3)
  43. up6 = BatchNormalization(axis=3)(up6)
  44. up6 = Activation('relu')(up6)
  45. x1 = Reshape(target_shape=(1, np.int32(N/4), np.int32(N/4), sfilter * 4))(drop3)
  46. x2 = Reshape(target_shape=(1, np.int32(N/4), np.int32(N/4), sfilter * 4))(up6)
  47. merge6 = concatenate([x1,x2], axis = 1)
  48. merge6 = ConvLSTM2D(filters = sfilter * 2, kernel_size=(3, 3), padding='same', return_sequences = False, go_backwards = True,kernel_initializer = 'he_normal' )(merge6)
  49. conv6 = Conv2D(sfilter * 4, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
  50. conv6 = Conv2D(sfilter * 4, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
  51. up7 = Conv2DTranspose(sfilter * 2, kernel_size=2, strides=2, padding='same',kernel_initializer = 'he_normal')(conv6)
  52. up7 = BatchNormalization(axis=3)(up7)
  53. up7 = Activation('relu')(up7)
  54. x1 = Reshape(target_shape=(1, np.int32(N/2), np.int32(N/2), sfilter * 2))(conv2)
  55. x2 = Reshape(target_shape=(1, np.int32(N/2), np.int32(N/2), sfilter * 2))(up7)
  56. merge7 = concatenate([x1,x2], axis = 1)
  57. merge7 = ConvLSTM2D(filters = sfilter, kernel_size=(3, 3), padding='same', return_sequences = False, go_backwards = True,kernel_initializer = 'he_normal' )(merge7)
  58. conv7 = Conv2D(sfilter * 2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
  59. conv7 = Conv2D(sfilter * 2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
  60. up8 = Conv2DTranspose(sfilter, kernel_size=2, strides=2, padding='same',kernel_initializer = 'he_normal')(conv7)
  61. up8 = BatchNormalization(axis=3)(up8)
  62. up8 = Activation('relu')(up8)
  63. x1 = Reshape(target_shape=(1, N, N, sfilter))(conv1)
  64. x2 = Reshape(target_shape=(1, N, N, sfilter))(up8)
  65. merge8 = concatenate([x1,x2], axis = 1)
  66. merge8 = ConvLSTM2D(filters = sfilter / 2, kernel_size=(3, 3), padding='same', return_sequences = False, go_backwards = True,kernel_initializer = 'he_normal' )(merge8)
  67. conv8 = Conv2D(sfilter, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
  68. conv8 = Conv2D(sfilter, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
  69. conv8 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
  70. conv9 = Conv2D(1, 1, activation = 'sigmoid')(conv8)
  71. model = Model(inputs = inputs1, outputs = conv9) ## changed from input ouput to inputs1 outputs
  72. # ~ model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ["acc", jaccardIndex, diceIndex])
  73. model.compile(optimizer = Adam(), loss = 'binary_crossentropy', metrics = ["acc", jaccardIndex, diceIndex])
  74. return model