
Medical Image Segmentation with U-Net



U-Net is a typical encoder-decoder network and is widely used for image segmentation.

 

First, prepare the data.

 

Step 1: convert the NIfTI (.nii) medical images to PNG slices, truncating intensities with the given window width and window center.

import os

import cv2 as cv
import imageio
import nibabel as nib
import numpy as np
from PIL import Image


def transform_ctdata(image, windowWidth, windowCenter, normal=False):
    """
    Note: the input image must be of float type, otherwise the windowing has no effect!
    return: truncated image according to window center and window width
    """
    minWindow = float(windowCenter) - 0.5 * float(windowWidth)
    newimg = (image - minWindow) / float(windowWidth)
    newimg[newimg < 0] = 0
    newimg[newimg > 1] = 1
    if not normal:
        newimg = (newimg * 255).astype("uint8")
    return newimg


# Directories containing the .nii files
train_path = "D:/pycharm_project/graduate_design_next_semester/dataset/data_nii/train"
label_path = "D:/pycharm_project/graduate_design_next_semester/dataset/data_nii/label"
# Directories where the extracted slices are saved
train_save_path = "./data_raw_slice_tumour/train"
label_save_path = "./data_raw_slice_tumour/label"
if not os.path.isdir(train_save_path):
    os.makedirs(train_save_path)
if not os.path.isdir(label_save_path):
    os.makedirs(label_save_path)


# List the training volumes and labels, sorted by the index in the file name
# (e.g. "volume-12.nii" -> 12)
train_images = os.listdir(train_path)
train_images.sort(key=lambda x: int(x.split("-")[1].split(".")[0]))
label_images = os.listdir(label_path)
label_images.sort(key=lambda x: int(x.split("-")[1].split(".")[0]))


def create_train_label_slice():
    print("-" * 30)
    print("Slicing volume and segmentation files together")
    for i in range(len(train_images)):
        train_image = nib.load(train_path + "/" + train_images[i])
        label_image = nib.load(label_path + "/" + label_images[i])
        # Height, width and number of slices of this volume
        height, width, slices = train_image.shape
        print("dir " + str(i), "(", height, width, slices, ")")
        # Per-patient subfolder name: 0, 1, 2, ...
        slice_save_path = train_images[i].split("-")[1].split(".")[0]

        train_slice_path = train_save_path + "/" + slice_save_path
        label_slice_path = label_save_path + "/" + slice_save_path

        img_fdata = label_image.get_fdata()
        for j in range(slices):
            train_img = train_image.dataobj[:, :, j]
            label_img = img_fdata[:, :, j]

            # Keep only the tumour class: liver (label 1) -> 0, tumour (label 2) -> 1
            label_img[label_img == 1] = 0
            label_img[label_img == 2] = 1

            white_pixel = label_img == 1
            white_pixel_num = len(label_img[white_pixel])

            # A slice counts as valid only if its tumour mask has at least 50 pixels;
            # this also drops all-black (empty) labels, which are useless for training
            if white_pixel_num >= 50:
                set_slice = np.array(train_img).copy()
                set_slice = set_slice.astype("float32")
                # Window width/center used for training
                # set_slice = transform_ctdata(set_slice, 350, 25)
                # Reference from Zhihu: the liver window center is about 40~60
                set_slice = transform_ctdata(set_slice, 200, 30)

                # Alternative normalization that was tried instead of windowing:
                # mean = set_slice.mean()
                # std = np.std(set_slice)
                # set_slice -= mean
                # set_slice /= std
                # set_slice = (set_slice - set_slice.min()) / (set_slice.max() - set_slice.min())
                # set_slice *= 255
                # set_slice = set_slice.astype("uint8")

                # Median filtering to remove salt-and-pepper noise
                set_slice = cv.medianBlur(set_slice, 3)

                if not os.path.isdir(train_slice_path):
                    os.makedirs(train_slice_path)
                if not os.path.isdir(label_slice_path):
                    os.makedirs(label_slice_path)

                # Optional histogram equalization
                # set_slice = cv.equalizeHist(set_slice)
                cv.imwrite(train_slice_path + "/" + str(j) + ".png", set_slice)
                label_img = Image.fromarray(np.uint8(label_img * 255))
                imageio.imwrite(label_slice_path + "/" + str(j) + ".png", label_img)
    print("Generating train data set done!")


if __name__ == "__main__":
    create_train_label_slice()
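
For reference, with window width 200 and center 30 the transform keeps HU values in [30 - 100, 30 + 100] = [-70, 130] and linearly rescales them to [0, 255]; everything outside the window saturates. A quick sanity check with made-up HU values (not from the dataset):

import numpy as np

# Hypothetical HU values spanning the window boundaries
hu = np.array([-1000.0, -70.0, 30.0, 130.0, 400.0], dtype="float32")
out = transform_ctdata(hu, 200, 30)
print(out)  # [  0   0 127 255 255] -- values outside [-70, 130] saturate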

 

Step 2: compute the mean intensity of each patient's slices, in preparation for reducing the brightness differences between patients.

 

import os

import cv2 as cv
import numpy as np


train_raw_path = "./data_raw_slice_tumour/train"
label_raw_path = "./data_raw_slice_tumour/label"


def count_dir_mean_fun():
    dirs = os.listdir(train_raw_path)
    dirs.sort(key=lambda x: int(x))

    mean1 = []   # per-patient mean brightness
    mean2 = 0.0  # global mean over all patients
    mean3 = []   # per-patient offset from the global mean

    for dir in dirs:
        train_dir_path = os.path.join(train_raw_path, dir)
        images = os.listdir(train_dir_path)
        images.sort(key=lambda x: int(x.split(".")[0]))

        image_num = len(images)
        mean = 0.0
        for name in images:
            image_path = os.path.join(train_dir_path, name)
            image = cv.imread(image_path, 0)
            image = image.astype("float32")

            # Exclude (near-)black background pixels (value <= 5) from the mean:
            # mean over all pixels * N / (N - n_black) is approximately the mean
            # over foreground pixels only; the slices are 512x512
            black_pixel = image <= 5
            black_num = len(image[black_pixel])
            # mean += image.mean()
            mean += image.mean() * 512 * 512 / (512 * 512 - black_num)

        mean /= image_num
        mean1.append(mean)
    mean2 = sum(mean1) / len(mean1)
    mean3[:] = [x - mean2 for x in mean1]
    print(mean1)
    print(mean2)
    print(mean3)
    mean1 = np.array(mean1)
    mean2 = np.array(mean2)
    mean3 = np.array(mean3)
    # mean_save_path = "./mean_array"
    mean_save_path = "./mean_array_without_background"
    if not os.path.isdir(mean_save_path):
        os.makedirs(mean_save_path)
    np.save(mean_save_path + "/" + "mean1.npy", mean1)
    np.save(mean_save_path + "/" + "mean2.npy", mean2)
    np.save(mean_save_path + "/" + "mean3.npy", mean3)


if __name__ == "__main__":
    count_dir_mean_fun()
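
The background-excluded mean works because image.mean() * N / (N - n_black) equals sum(image) / (N - n_black), and the excluded pixels contribute at most 5 each to the sum, so the result is approximately the mean over foreground pixels only. A tiny check with a made-up array:

import numpy as np

# Hypothetical 2x2 "slice": two background pixels (0) and two foreground pixels
img = np.array([[0.0, 0.0], [100.0, 200.0]], dtype="float32")
n = img.size
n_black = int((img <= 5).sum())
fg_mean = img.mean() * n / (n - n_black)
print(img.mean(), fg_mean)  # 75.0 150.0 -- 150 is the true foreground mean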

 

Step 3: reduce the brightness differences between patients.

 

import os

import cv2 as cv
import numpy as np


train_raw_path = "./data_raw_slice_tumour/train"
label_raw_path = "./data_raw_slice_tumour/label"


dirs = os.listdir(train_raw_path)
dirs.sort(key=lambda x: int(x))

train_save_path = "./data_slice_tumour_modify_brightness/train"
label_save_path = "./data_slice_tumour_modify_brightness/label"


# mean1 = np.load("./mean_array/mean1.npy")
# mean2 = np.load("./mean_array/mean2.npy")
mean1 = np.load("./mean_array_without_background/mean1.npy")
mean2 = np.load("./mean_array_without_background/mean2.npy")
for i in range(len(dirs)):
    dir_name = dirs[i]
    mean_sub = mean1[i]  # this patient's mean
    mean_add = mean2     # global mean

    train_dir_path = os.path.join(train_raw_path, dir_name)
    images = os.listdir(train_dir_path)
    images.sort(key=lambda x: int(x.split(".")[0]))
    for name in images:
        image_path = os.path.join(train_dir_path, name)
        image = cv.imread(image_path, 0)
        image = image.astype("float32")
        # Shift each slice so the patient's mean brightness matches the global mean
        image -= mean_sub
        image += mean_add

        save_path = os.path.join(train_save_path, dir_name)
        if not os.path.isdir(save_path):
            os.makedirs(save_path)
        cv.imwrite(save_path + "/" + name, image)
    print("Finished dir {}".format(i))
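
Note that cv.imwrite receives a float32 image here, and after the brightness shift some pixel values can fall below 0 or above 255. Clipping to the 8-bit range before writing makes the conversion unambiguous; a possible guard (my addition, not in the original pipeline):

# Optional: clip to the valid 8-bit range before writing (assumption,
# not part of the original code)
image = np.clip(image, 0, 255).astype("uint8")
cv.imwrite(save_path + "/" + name, image)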

 

Step 4: crop the slices to remove unnecessary background.

 

import os

import cv2 as cv


raw_slice_train_path = "./data_slice_tumour_modify_brightness/train"
raw_slice_label_path = "./data_raw_slice_tumour/label"


train_clip_save_path = "./data_cv_clip_whole/train"
label_clip_save_path = "./data_cv_clip_whole/label"


dirs = os.listdir(raw_slice_train_path)
dirs.sort(key=lambda x: int(x))

j = 0
for dir in dirs:
    train_dir_path = os.path.join(raw_slice_train_path, dir)
    label_dir_path = os.path.join(raw_slice_label_path, dir)

    names = os.listdir(train_dir_path)
    names.sort(key=lambda x: int(x.split(".")[0]))

    i = 0
    for name in names:
        train_img_path = os.path.join(train_dir_path, name)
        label_img_path = os.path.join(label_dir_path, name)

        train_img = cv.imread(train_img_path, 0)
        label_img = cv.imread(label_img_path, 0)

        # Crop from 512x512 down to 400x400
        train_clip_img = train_img[0:400, 50:450]
        label_clip_img = label_img[0:400, 50:450]
        # Re-binarize the label: keep only pure-white (255) tumour pixels
        label_clip_img[label_clip_img != 255] = 0

        # Skip slices whose tumour region was cropped away entirely
        if label_clip_img.max() == 0:
            continue

        train_save_path = os.path.join(train_clip_save_path, dir)
        label_save_path = os.path.join(label_clip_save_path, dir)

        if not os.path.isdir(train_save_path):
            os.makedirs(train_save_path)
        if not os.path.isdir(label_save_path):
            os.makedirs(label_save_path)

        train_save_name = os.path.join(train_save_path, str(i) + ".png")
        label_save_name = os.path.join(label_save_path, str(i) + ".png")

        cv.imwrite(train_save_name, train_clip_img)
        cv.imwrite(label_save_name, label_clip_img)
        i += 1
    j += 1
    print("Finished dir {}".format(j))

 

Step 5: pack each patient's slices into a .npy file, for uploading to the server.

 

 1 """
 2 本程序制作上传服务器的肝脏切片数据集
 3 一张张切片,按照dir分布
 4 """
 5 import os
 6 import cv2 as cv
 7 import numpy as np
 8 
 9 
10 # 处理过床位窗宽的train图像
11 train_png_path = "./data_cv_clip_whole/train"
12 # label标签
13 label_png_path = "./data_cv_clip_whole/label"
14 train_dir_npy_save_path = "./data_dir_npy/train"
15 label_dir_npy_save_path = "./data_dir_npy/label"
16 if not os.path.isdir(train_dir_npy_save_path):
17     os.makedirs(train_dir_npy_save_path)
18 if not os.path.isdir(label_dir_npy_save_path):
19     os.makedirs(label_dir_npy_save_path)
20 
21 train_dirs = os.listdir(train_png_path)
22 label_dirs = os.listdir(label_png_path)
23 train_dirs.sort(key=lambda x: int(x))
24 label_dirs.sort(key=lambda x: int(x))
25 
26 j = 0
27 for dir in train_dirs:
28     train_dir_path = os.path.join(train_png_path, dir)
29     label_dir_path = os.path.join(label_png_path, dir)
30 
31     dir_length = len(os.listdir(train_dir_path))
32     train_dir_npy = np.ndarray((dir_length, 400, 400, 1), dtype=np.uint8)
33     label_dir_npy = np.ndarray((dir_length, 400, 400, 1), dtype=np.uint8)
34 
35     train_imgs = os.listdir(train_dir_path)
36     label_imgs = os.listdir(label_dir_path)
37     train_imgs.sort(key=lambda x: int(x.split(.)[0]))
38     label_imgs.sort(key=lambda x: int(x.split(.)[0]))
39 
40     i = 0
41     for img in train_imgs:
42         train_img_path = os.path.join(train_dir_path, img)
43         label_img_path = os.path.join(label_dir_path, img)
44         train_img = cv.imread(train_img_path, 0)
45         label_img = cv.imread(label_img_path, 0)
46 
47         # cv.imshow("train", train_img)
48         # cv.imshow("label", label_img)
49         # cv.waitKey(0)
50         # cv.destroyAllWindows()
51 
52         train_img = np.reshape(train_img, (400, 400, 1))
53         label_img = np.reshape(label_img, (400, 400, 1))
54         train_dir_npy[i] = train_img
55         label_dir_npy[i] = label_img
56 
57         i += 1
58 
59     np.save(train_dir_npy_save_path + "/" + str(j) + ".npy", train_dir_npy)
60     np.save(label_dir_npy_save_path + "/" + str(j) + ".npy", label_dir_npy)
61     j += 1
62     print("第{}个文件夹".format(j))
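
To sanity-check a packed file, it can be loaded back and its shape inspected (a quick sketch; patient file 0.npy is assumed to exist):

import numpy as np

arr = np.load("./data_dir_npy/train/0.npy")
print(arr.shape, arr.dtype)  # expected: (num_slices, 400, 400, 1) uint8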

 

Step 6: on the server, unpack the .npy files back into per-patient dirs of PNG slices.

 

 1 """
 2 本程序将服务器的dir---npy文件拆解成dir---png文件
 3 注意:这里标签是纯肝脏数据,肿瘤作为背景
 4 """
 5 import os
 6 import cv2 as cv
 7 import numpy as np
 8 
 9 
10 train_png_path = "./data_dir_png/train"
11 label_png_path = "./data_dir_png/label"
12 train_npy_path = "./data_dir_npy/train"
13 label_npy_path = "./data_dir_npy/label"
14 if not os.path.isdir(train_png_path):
15     os.makedirs(train_png_path)
16 if not os.path.isdir(label_png_path):
17     os.makedirs(label_png_path)
18 
19 train_npys = os.listdir(train_npy_path)
20 label_npys = os.listdir(label_npy_path)
21 train_npys.sort(key=lambda x: int(x.split(".")[0]))
22 label_npys.sort(key=lambda x: int(x.split(".")[0]))
23 
24 j = 0
25 for npy in train_npys:
26     npy_path1 = os.path.join(train_npy_path, npy)
27     npy_path2 = os.path.join(label_npy_path, npy)
28     train_npy = np.load(npy_path1)
29     label_npy = np.load(npy_path2)
30     for i in range(len(train_npy)):
31         train_img = train_npy[i]
32         label_img = label_npy[i]
33         train_img = np.reshape(train_img, (400, 400))
34         label_img = np.reshape(label_img, (400, 400))
35 
36         # cv.imshow("train", train_img)
37         # cv.imshow("label", label_img)
38         # cv.waitKey(0)
39         # cv.destroyAllWindows()
40         train_save_dir_path = os.path.join(train_png_path, str(j))
41         label_save_dir_path = os.path.join(label_png_path, str(j))
42         if not os.path.isdir(train_save_dir_path):
43             os.makedirs(train_save_dir_path)
44         if not os.path.isdir(label_save_dir_path):
45             os.makedirs(label_save_dir_path)
46         cv.imwrite(train_save_dir_path + "/" + str(i) + ".png", train_img)
47         cv.imwrite(label_save_dir_path + "/" + str(i) + ".png", label_img)
48     j += 1
49     print("完成第{}个dir".format(j))

 

Then, train the U-Net network.

 

The U-Net network code:

 

import keras
from keras.models import *
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Dropout
from keras.optimizers import *

from keras.layers import Concatenate

from keras import backend as K

from keras.callbacks import ModelCheckpoint
from fit_generator import get_path_list, get_train_batch
import matplotlib.pyplot as plt

# Three things to update before each training run: the training data paths,
# the model save path, and the training-curve save path

train_batch_size = 2
epoch = 5
img_size = 400

data_train_path = "./data_dir_png/train"
data_label_path = "./data_dir_png/label"


train_path_list, label_path_list, count = get_path_list(data_train_path, data_label_path)


# A LossHistory callback class that records loss and accuracy
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = {"batch": [], "epoch": []}
        self.accuracy = {"batch": [], "epoch": []}
        self.val_loss = {"batch": [], "epoch": []}
        self.val_acc = {"batch": [], "epoch": []}

    def on_batch_end(self, batch, logs={}):
        self.losses["batch"].append(logs.get("loss"))
        self.accuracy["batch"].append(logs.get("dice_coef"))
        # val metrics are None unless validation data is provided
        self.val_loss["batch"].append(logs.get("val_loss"))
        self.val_acc["batch"].append(logs.get("val_acc"))

    def on_epoch_end(self, batch, logs={}):
        self.losses["epoch"].append(logs.get("loss"))
        self.accuracy["epoch"].append(logs.get("dice_coef"))
        self.val_loss["epoch"].append(logs.get("val_loss"))
        self.val_acc["epoch"].append(logs.get("val_acc"))

    def loss_plot(self, loss_type):
        iters = range(len(self.losses[loss_type]))
        plt.figure(1)
        # dice
        plt.plot(iters, self.accuracy[loss_type], "r", label="train dice")
        if loss_type == "epoch":
            # val_acc
            plt.plot(iters, self.val_acc[loss_type], "b", label="val acc")
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel("dice")
        plt.legend(loc="best")
        # plt.savefig("./curve_figure/unet_pure_liver_raw_0_129_entropy_dice_curve.png")
        plt.savefig("./curve_figure/unet_tumour_dice.png")

        plt.figure(2)
        # loss
        plt.plot(iters, self.losses[loss_type], "g", label="train loss")
        if loss_type == "epoch":
            # val_loss
            plt.plot(iters, self.val_loss[loss_type], "k", label="val loss")
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel("loss")
        plt.legend(loc="best")
        # plt.savefig("./curve_figure/unet_pure_liver_raw_0_129_entropy_loss_curve.png")
        plt.savefig("./curve_figure/unet_tumour_loss.png")
        plt.show()
def dice_coef(y_true, y_pred):
    # Soft Dice coefficient with smoothing to avoid division by zero
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f * y_true_f) + K.sum(y_pred_f * y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    return 1. - dice_coef(y_true, y_pred)


def mycrossentropy(y_true, y_pred, e=0.1):
    # Label-smoothed cross-entropy (kept for reference; not used in the compile below)
    nb_classes = 10
    loss1 = K.categorical_crossentropy(y_true, y_pred)
    loss2 = K.categorical_crossentropy(K.ones_like(y_pred) / nb_classes, y_pred)
    return (1 - e) * loss1 + e * loss2


class myUnet(object):
    def __init__(self, img_rows=img_size, img_cols=img_size):
        self.img_rows = img_rows
        self.img_cols = img_cols

    def BN_operation(self, input):
        output = keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True,
                                                 scale=True,
                                                 beta_initializer="zeros", gamma_initializer="ones",
                                                 moving_mean_initializer="zeros",
                                                 moving_variance_initializer="ones",
                                                 beta_regularizer=None,
                                                 gamma_regularizer=None, beta_constraint=None,
                                                 gamma_constraint=None)(input)
        return output

    def get_unet(self):
        inputs = Input((self.img_rows, self.img_cols, 1))

        # Contracting path (encoder)
        conv1 = Conv2D(64, 3, activation="relu", padding="same", kernel_initializer="he_normal")(inputs)
        conv1 = Conv2D(64, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv1)
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
        # BN
        # pool1 = self.BN_operation(pool1)

        conv2 = Conv2D(128, 3, activation="relu", padding="same", kernel_initializer="he_normal")(pool1)
        conv2 = Conv2D(128, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv2)
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
        # BN
        # pool2 = self.BN_operation(pool2)

        conv3 = Conv2D(256, 3, activation="relu", padding="same", kernel_initializer="he_normal")(pool2)
        conv3 = Conv2D(256, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv3)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        # BN
        # pool3 = self.BN_operation(pool3)

        conv4 = Conv2D(512, 3, activation="relu", padding="same", kernel_initializer="he_normal")(pool3)
        conv4 = Conv2D(512, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv4)
        drop4 = Dropout(0.5)(conv4)
        pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
        # BN
        # pool4 = self.BN_operation(pool4)

        conv5 = Conv2D(1024, 3, activation="relu", padding="same", kernel_initializer="he_normal")(pool4)
        conv5 = Conv2D(1024, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv5)
        drop5 = Dropout(0.5)(conv5)
        # BN
        # drop5 = self.BN_operation(drop5)

        # Expanding path (decoder) with skip connections
        up6 = Conv2D(512, 2, activation="relu", padding="same", kernel_initializer="he_normal")(
            UpSampling2D(size=(2, 2))(drop5))
        merge6 = Concatenate(axis=3)([drop4, up6])
        conv6 = Conv2D(512, 3, activation="relu", padding="same", kernel_initializer="he_normal")(merge6)
        conv6 = Conv2D(512, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv6)

        up7 = Conv2D(256, 2, activation="relu", padding="same", kernel_initializer="he_normal")(
            UpSampling2D(size=(2, 2))(conv6))
        merge7 = Concatenate(axis=3)([conv3, up7])
        conv7 = Conv2D(256, 3, activation="relu", padding="same", kernel_initializer="he_normal")(merge7)
        conv7 = Conv2D(256, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv7)

        up8 = Conv2D(128, 2, activation="relu", padding="same", kernel_initializer="he_normal")(
            UpSampling2D(size=(2, 2))(conv7))
        merge8 = Concatenate(axis=3)([conv2, up8])
        conv8 = Conv2D(128, 3, activation="relu", padding="same", kernel_initializer="he_normal")(merge8)
        conv8 = Conv2D(128, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv8)

        up9 = Conv2D(64, 2, activation="relu", padding="same", kernel_initializer="he_normal")(
            UpSampling2D(size=(2, 2))(conv8))
        merge9 = Concatenate(axis=3)([conv1, up9])
        conv9 = Conv2D(64, 3, activation="relu", padding="same", kernel_initializer="he_normal")(merge9)
        conv9 = Conv2D(64, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv9)
        conv9 = Conv2D(2, 3, activation="relu", padding="same", kernel_initializer="he_normal")(conv9)
        conv10 = Conv2D(1, 1, activation="sigmoid")(conv9)

        model = Model(inputs=inputs, outputs=conv10)

        # The loss function and accuracy metrics can be customized here
        # model.compile(optimizer=Adam(lr=1e-4), loss="binary_crossentropy", metrics=["accuracy"])
        model.compile(optimizer=Adam(lr=1e-4), loss="binary_crossentropy", metrics=["accuracy",
                                                                                    dice_coef])
        print("model compiled")
        return model

    def train(self):
        print("loading data")
        print("loading data done")
        model = self.get_unet()
        # To resume from a checkpoint instead:
        # model = load_model("./model/unet_tumour1.hdf5",
        #                    custom_objects={"dice_coef": dice_coef, "dice_coef_loss": dice_coef_loss})
        print("got unet")

        # Saves the full model (architecture plus weights)
        model_checkpoint = ModelCheckpoint("./model/unet_tumour1.hdf5", monitor="loss", verbose=1,
                                           save_best_only=True)
        print("Fitting model...")

        # Create a history instance; it must be passed in callbacks,
        # otherwise the curves cannot be plotted afterwards
        history = LossHistory()
        model.fit_generator(
            generator=get_train_batch(train_path_list, label_path_list, train_batch_size, img_size, img_size),
            epochs=epoch, verbose=1,
            steps_per_epoch=count // train_batch_size,
            callbacks=[model_checkpoint, history],
            workers=1)
        # Plot the dice/loss curves (the ./curve_figure directory must already exist)
        history.loss_plot("epoch")


if __name__ == "__main__":
    myunet = myUnet()
    myunet.train()
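
Once training finishes, the saved checkpoint can be reloaded for inference. A minimal sketch (my addition; it assumes dice_coef is in scope as defined above, and that test slices follow the same pipeline as training: grayscale, 400x400, scaled to [0, 1]):

import cv2 as cv
import numpy as np
from keras.models import load_model

# dice_coef was compiled into the model, so it must be passed as a custom object
model = load_model("./model/unet_tumour1.hdf5", custom_objects={"dice_coef": dice_coef})

img = cv.imread("./data_dir_png/train/0/0.png", 0).astype("float32") / 255
pred = model.predict(img.reshape(1, 400, 400, 1))[0, :, :, 0]
mask = (pred > 0.5).astype("uint8") * 255  # threshold the sigmoid output
cv.imwrite("./pred_0_0.png", mask)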

 

The data-loading code (the fit_generator module imported above):

 

import numpy as np
import cv2 as cv
import os

data_train_path = "./deform/train"
data_label_path = "./deform/label"


def get_path_list(data_train_path, data_label_path):
    dirs = os.listdir(data_train_path)
    dirs.sort(key=lambda x: int(x))

    count = 0
    train_path_list = []
    label_path_list = []
    for dir in dirs:
        train_dir_path = os.path.join(data_train_path, dir)
        label_dir_path = os.path.join(data_label_path, dir)
        imgs = os.listdir(train_dir_path)
        imgs.sort(key=lambda x: int(x.split(".")[0]))
        count += len(imgs)
        for img in imgs:
            train_img_path = os.path.join(train_dir_path, img)
            label_img_path = os.path.join(label_dir_path, img)
            train_path_list.append(train_img_path)
            label_path_list.append(label_img_path)
    print("{} training pairs in total".format(count))
    return train_path_list, label_path_list, count


def get_train_img(paths, img_rows, img_cols):
    """
    Args:
        paths: list of image paths to read
        img_rows: image rows
        img_cols: image columns
    Returns:
        imgs: image array
    """
    # Load as grayscale
    imgs = []
    for path in paths:
        img = cv.imread(path, 0)
        resized = np.reshape(img, (img_rows, img_cols, 1))
        resized = resized.astype("float32")
        resized /= 255
        # mean = resized.mean(axis=0)
        # resized -= mean
        imgs.append(resized)
    imgs = np.array(imgs)
    return imgs


def get_label_img(paths, img_rows, img_cols):
    """
    Args:
        paths: list of label paths to read
        img_rows: image rows
        img_cols: image columns
    Returns:
        imgs: label array
    """
    # Load as grayscale
    imgs = []
    for path in paths:
        img = cv.imread(path, 0)
        resized = np.reshape(img, (img_cols, img_rows, 1))
        resized = resized.astype("float32")
        resized /= 255
        imgs.append(resized)
    imgs = np.array(imgs)
    return imgs


def get_train_batch(train, label, batch_size, img_w, img_h):
    """
    Args:
        train: list of all image paths
        label: list of the corresponding label paths
        batch_size: batch size
        img_w: image width
        img_h: image height
    Returns:
        a generator; x: batch of images, y: the corresponding labels
    """
    while 1:
        for i in range(0, len(train), batch_size):
            x = get_train_img(train[i:i + batch_size], img_w, img_h)
            y = get_label_img(label[i:i + batch_size], img_w, img_h)
            # yield hands back one batch and resumes the loop on the next request,
            # so the generator can run indefinitely
            yield (np.array(x), np.array(y))


if __name__ == "__main__":
    train_path_list, label_path_list, count = get_path_list(data_train_path, data_label_path)
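
Note that this generator yields batches in the same fixed order every epoch. If shuffling is wanted, one option (my addition, not in the original) is to shuffle the two path lists in unison before each pass so image/label pairs stay aligned:

import random

def shuffle_paths(train_paths, label_paths, seed=None):
    # Shuffle image/label path pairs together so they remain aligned
    pairs = list(zip(train_paths, label_paths))
    random.Random(seed).shuffle(pairs)
    train_shuffled, label_shuffled = zip(*pairs)
    return list(train_shuffled), list(label_shuffled)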

 


Original post: https://www.cnblogs.com/jinyiyexingzc/p/12996391.html
