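本篇博客为 UNet 网络的学习笔记。本部分优化了由 labelme 标注的 json 文件生成 mask 的代码:cv2.fillConvexPoly 只适用于凸多边形,用它填充非凸图形时结果可能不正确,改用 cv2.fillPoly 可以得到更好的结果。需要注意这两个函数的第 2 个参数有所不同:cv2.fillConvexPoly 接收单个多边形的顶点数组(points),而 cv2.fillPoly 接收由一个或多个顶点数组组成的列表(例如 [points]);顶点坐标需要先转换为 int32 类型。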
1.Unet网络结构 unet.py
import torch
from torch import nn
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by BN and ReLU.

    Both convolutions use ``padding=1``. The first stage maps ``in_ch`` to
    ``out_ch`` channels; the second keeps ``out_ch`` channels.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first stage changes the channel count, the second one keeps it.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Kept under the attribute name ``conv`` so parameter keys are unchanged.
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through both convolution stages and return the result."""
        return self.conv(input)
class Unet(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    The encoder applies four ``DoubleConv`` + 2x2 max-pool stages
    (64, 128, 256, 512 channels) followed by a 1024-channel bottleneck.
    The decoder upsamples four times with stride-2 transposed convolutions;
    after each upsampling the matching encoder feature map is concatenated
    along the channel axis and fused by another ``DoubleConv``. A final 1x1
    convolution maps the 64 decoder channels to ``out_ch`` raw scores (no
    activation is applied here).

    Args:
        in_ch: Number of channels of the input image.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Encoder. Attribute names are part of the checkpoint format
        # (state_dict keys), so they must not be renamed.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        # Bottleneck.
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: each DoubleConv sees upsampled + skip channels concatenated.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its height and width equal those of ``skip``.

        Max-pooling floors odd sizes, so after the stride-2 upsampling the map
        can be one pixel smaller than the encoder map it is concatenated with.
        When the sizes already agree the tensor is returned untouched.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Return per-pixel raw scores for a batched input.

        Args:
            x: Tensor of shape ``(N, in_ch, H, W)``. H and W need not be
                multiples of 16; upsampled maps are padded to fit the skips.

        Returns:
            Tensor of shape ``(N, out_ch, H, W)`` holding logits.
        """
        # Encoder path; c1..c4 are kept for the skip connections.
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # Decoder path: upsample, align with the skip, concatenate, fuse.
        up_6 = self._match_size(self.up6(c5), c4)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self._match_size(self.up7(c6), c3)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self._match_size(self.up8(c7), c2)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self._match_size(self.up9(c8), c1)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        return c10
2.训练代码train.py
# -*- coding: utf-8 -*-
import torch
from torchvision.transforms import transforms
from torch import nn, optim
from unet import Unet
import numpy as np
# from tqdm import tqdm
import os
import cv2
import json
import matplotlib.pyplot as plt
# Folder holding the "<id>_color_0.png" images and their labelme
# "<id>_color_0.json" annotation files.
path = r'E:\datasets\24\2022-01-05'

# File the model weights are saved to and loaded from.
PATH = './unet_model.pt'

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Spatial size every image and mask is resized to before entering the network.
_NET_INPUT_SIZE = (512, 512)

# Image pipeline. torchvision's ToTensor turns [h, w, c] into [c, h, w].
# The mean/std values are presumably the ImageNet channel statistics.
x_transforms = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(_NET_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Mask pipeline: only resized and converted to a tensor, never normalised.
y_transforms = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(_NET_INPUT_SIZE),
        transforms.ToTensor(),
    ]
)
def _read_labelme_sample(sample_id):
    """Load one training image and rasterise its labelme polygon into a mask.

    Args:
        sample_id: File-name prefix; ``<sample_id>_color_0.png`` and
            ``<sample_id>_color_0.json`` are read from the module-level ``path``.

    Returns:
        Tuple ``(rgb_image, mask)``: the image converted from BGR to RGB and a
        ``(H, W, 1)`` uint8 mask that is 255 inside the first labelled shape
        and 0 elsewhere.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image file.
    """
    image_file = os.path.join(path, sample_id + '_color_0.png')
    json_file = os.path.join(path, sample_id + '_color_0.json')

    bgr = cv2.imread(image_file, cv2.IMREAD_COLOR)
    if bgr is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError('cannot read image: ' + image_file)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    with open(json_file, encoding='utf-8') as fh:
        annotation = json.load(fh)
    polygon = np.array(annotation['shapes'][0]['points']).astype(np.int32)

    # The mask follows the image size because the polygon is in image pixels.
    mask = np.zeros(rgb.shape[:2] + (1,), dtype=np.uint8)
    # fillPoly (unlike fillConvexPoly) also fills non-convex outlines
    # correctly; it expects a list of point arrays.
    cv2.fillPoly(mask, [polygon], (255,))
    return rgb, mask


def train_model(model, criterion, optimizer, num_epochs=20):
    """Train ``model`` one sample per step and save the best epoch's weights.

    Sample ids are the file-name prefixes (text before the first ``_``) of
    the PNG files matched by ``dir_path``; ids ending in ``z`` are skipped.
    After every epoch the mean loss is compared with the best one so far and
    a copy of the weights is kept when it improves. The best weights are
    written to ``PATH`` and loaded back into ``model`` before returning.

    Args:
        model: Network to optimise; it is switched to training mode.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimiser already bound to ``model``'s parameters.
        num_epochs: Number of passes over the sample ids.

    Returns:
        ``model`` carrying the weights of the lowest-mean-loss epoch.

    Raises:
        ValueError: If no sample id is found.
        FileNotFoundError: If a sample image cannot be read.
    """
    import copy
    import glob

    dir_path = r'E:\datasets\ear\*.png'
    # NOTE(review): ids are discovered under E:\datasets\ear, but the samples
    # themselves are read from the module-level `path` -- confirm that both
    # locations describe the same data set.
    found = set()
    for file_path in glob.glob(dir_path):
        sample_id = os.path.basename(file_path).split('_')[0]
        # Skip ids whose last character is 'z' (and empty prefixes).
        if sample_id and not sample_id.endswith('z'):
            found.add(sample_id)
    ids = sorted(found)
    print('ids after sort:', ids)
    if not ids:
        raise ValueError('no training samples match ' + dir_path)

    model.train()
    min_loss = float('inf')
    # Snapshot the weights: keeping a reference to `model` would silently
    # track later epochs, so "best" would always be the last epoch.
    best_state = copy.deepcopy(model.state_dict())
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        epoch_loss = 0
        for step, sample_id in enumerate(ids, start=1):
            print('i,id:', step - 1, sample_id)
            image, mask = _read_labelme_sample(sample_id)
            inputs = x_transforms(image).unsqueeze(0).to(device)
            labels = y_transforms(mask).unsqueeze(0).to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d, train_loss:%0.3f" % (step, loss.item()))
        mean_loss = epoch_loss / len(ids)
        print("epoch %d loss:%0.3f" % (epoch, mean_loss))
        if mean_loss < min_loss:
            min_loss = mean_loss
            best_state = copy.deepcopy(model.state_dict())

    torch.save(best_state, PATH)
    model.load_state_dict(best_state)
    return model
# Train the model
def train():
    """Create a 3-channel-in / 1-channel-out U-Net on ``device`` and train it.

    The loss is ``BCEWithLogitsLoss`` and the optimiser is Adam with its
    default settings; both are handed to ``train_model``.
    """
    net = Unet(3, 1).to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    adam = optim.Adam(net.parameters())
    train_model(net, loss_fn, adam)
def test():
    """Predict the mask of sample ``20`` and show its labelme ground truth.

    Loads the weights stored at ``PATH`` into a fresh CPU ``Unet(3, 1)``,
    runs the image ``20_color_0.png`` from ``path`` through it and converts
    the sigmoid output to 8-bit. The polygon stored in the matching JSON file
    is rasterised and displayed in an OpenCV window named ``label``; the
    caller is responsible for calling ``cv2.waitKey``.

    Returns:
        ``(H, W, 1)`` uint8 array with the predicted probabilities scaled to
        0..255, at the resolution produced by ``x_transforms``.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    model = Unet(3, 1)
    # The model stays on the CPU here, so remap weights saved from a GPU run.
    model.load_state_dict(torch.load(PATH, map_location='cpu'))
    # Inference: BatchNorm must use its running statistics, not the
    # statistics of the single image being predicted.
    model.eval()

    id_img_path = os.path.join(path, '20' + '_color_0.png')
    json_0_path = os.path.join(path, '20' + '_color_0.json')
    image = cv2.imread(id_img_path, cv2.IMREAD_COLOR)
    if image is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError('cannot read image: ' + id_img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    print('image:', image.shape)

    with torch.no_grad():
        inputs = x_transforms(image).unsqueeze(0)
        print('inputs:', inputs.shape)
        y = model(inputs)
        # (1, C, H, W) -> (H, W, C) so the result can be used as an image.
        y = y.squeeze(0).permute(1, 2, 0)
        y = torch.sigmoid(y).numpy()
    y = (y * 255).astype(np.uint8)

    # Build the ground-truth mask from the labelme annotation.
    with open(json_0_path, encoding='utf-8') as fh:
        labelme_json_0 = json.load(fh)
    points = np.array(labelme_json_0['shapes'][0]['points'])
    print('points.shape', points.shape)
    print(points)
    points = points.astype(np.int32)
    # The mask follows the image size because the polygon is in image pixels.
    mask = np.zeros(image.shape[:2] + (1,), dtype=np.uint8)
    # fillPoly takes a list of point arrays and handles non-convex shapes.
    cv2.fillPoly(mask, [points], (255))
    cv2.imshow('label', mask)
    print('mask.sum()', mask.sum() / 255)
    print(mask)
    return y
if __name__ == '__main__':
    # Set to True to retrain; otherwise the weights already saved at PATH are
    # used. The training messages are only printed when training really runs.
    run_training = False
    if run_training:
        print("开始训练")
        train()
        print("训练完成,保存模型")
        print("-" * 20)

    print("开始预测")
    y = test()

    id_img_path = os.path.join(path, '20' + '_color_0.png')
    image = cv2.imread(id_img_path, cv2.IMREAD_COLOR)
    if image is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError('cannot read image: ' + id_img_path)

    # Bring the prediction back to the image resolution; cv2.resize takes
    # (width, height). The result is used as a 2-D mask below.
    y = cv2.resize(y, (image.shape[1], image.shape[0]))

    # Binarise at 20% of the way from the darkest to the brightest response.
    # Done in float, and with a single comparison so that pixels exactly at
    # the threshold are not left un-binarised.
    lowest, highest = float(y.min()), float(y.max())
    if lowest < highest:
        threshold = lowest + 0.2 * (highest - lowest)
        y = np.where(y > threshold, 255, 0).astype(np.uint8)

    # Black out everything the mask marks as background.
    image[y == 0, :] = 0
    cv2.imshow('image', image)
    cv2.imshow('tt', y)
    cv2.waitKey(0)
3.测试代码test.py
# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from unet import Unet
# Dataset root; test images and their labels live in two sub-folders of it.
path = r'E:\datasets\24\2022-01-05'
test_image_path, test_label_path = (
    os.path.join(path, sub_dir) for sub_dir in ('test', 'test_labels')
)

# Image pipeline applied before the network: PIL conversion, resize to
# 512x512, tensor conversion, then per-channel normalisation.
# The mean/std values are presumably the ImageNet channel statistics.
_PREPROCESS_STEPS = [
    transforms.ToPILImage(),
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
x_transforms = transforms.Compose(_PREPROCESS_STEPS)

# Weights file to load. NOTE: "mopel" is a misspelling of "model"; the name
# is kept because the script entry point refers to it.
mopel_path = 'unet_model.pt'
def test(test_image_path, model_path):
    """Predict and display binary masks for up to five images in a folder.

    Each image is run through a ``Unet(3, 1)`` loaded from ``model_path``.
    The sigmoid output is scaled to 8-bit, expanded to three channels,
    binarised at the midpoint between its minimum and maximum, and shown
    with matplotlib (``plt.show()`` blocks until the window is closed).

    Args:
        test_image_path: Folder containing the images to predict.
        model_path: File holding the model's saved state_dict.

    Returns:
        ``(H, W, 3)`` uint8 mask (values 0 or 255) of the last image processed.

    Raises:
        FileNotFoundError: If the folder contains no files.
        ValueError: If a file cannot be decoded as an image.
    """
    import time

    # Sorted for a reproducible selection; slicing copes with < 5 files.
    image_names = sorted(os.listdir(test_image_path))[:5]
    if not image_names:
        raise FileNotFoundError('no test images found in ' + test_image_path)

    # Use CUDA when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Unet(3, 1).to(device)
    # map_location lets GPU-trained weights load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Inference: BatchNorm must use its running statistics.
    model.eval()

    y = None
    with torch.no_grad():
        for image_name in image_names:
            image_file = os.path.join(test_image_path, image_name)
            image = cv2.imread(image_file, cv2.IMREAD_COLOR)
            if image is None:  # cv2.imread signals failure by returning None
                raise ValueError('cannot decode image: ' + image_file)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            print('image:', image.shape)
            inputs = x_transforms(image).unsqueeze(0).to(device)
            print('inputs.shape:', inputs.shape)

            # Timestamps before and after the forward pass.
            print(time.time())
            y = model(inputs)
            print(time.time())

            # (1, C, H, W) -> (H, W, C), then probabilities scaled to 0..255.
            y = torch.sigmoid(y.squeeze(0).permute(1, 2, 0))
            y = (y.cpu().numpy() * 255).astype(np.uint8)
            y = cv2.cvtColor(y, cv2.COLOR_GRAY2RGB)

            # Midpoint threshold computed in float: adding two uint8 scalars
            # wraps around once min + max exceeds 255. A single comparison
            # also binarises pixels lying exactly on the threshold.
            lowest, highest = float(y.min()), float(y.max())
            if lowest < highest:
                threshold = (lowest + highest) / 2
                y = np.where(y > threshold, 255, 0).astype(np.uint8)

            plt.imshow(y)
            plt.show()
    return y
if __name__ == '__main__':
    print("开始预测")
    y = test(test_image_path, mopel_path)

    # Binarise the returned mask at the midpoint of its value range. The
    # midpoint is computed in float because adding two uint8 scalars wraps
    # around once min + max exceeds 255; a constant image is left untouched.
    lowest, highest = float(y.min()), float(y.max())
    if lowest < highest:
        threshold = (lowest + highest) / 2
        y = np.where(y > threshold, 255, 0).astype(np.uint8)

    plt.imshow(y)
    plt.show()