抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

最近在读 U-Net 论文时,网上看到从零构建网络模型的代码。代码足够间接,而且结构比较完整,因此记录一下学习结果。

本文重点在于如何代码的实现,对于 U-Net 论文中的细节未涉略,关于论文的讨论可移步。

学习的资源链接在文章末尾。

U-net 模型

首先对于模型有一个简单的认识:

U-net 模型
U-net 模型

对于 U-net 模型的构建,主要的在于卷积层和转置卷积(下采样和上采样)的实现,以及如何实现镜像对应部分的连接。请各位读者理解 U-net 模型,并且牢记每一步的通道数。

代码实现

按照工业上或者竞赛上常见的解决问题的步骤,主要包括数据集的获取、模型的构建、模型的训练(损失函数的选择、模型的优化)、训练结果的验证。因此接下来将从这几方面对代码进行解读

数据集的获取

数据集网址:Carvana Image Masking Challenge | Kaggle

百度网链接:https://pan.baidu.com/s/1bhKCyd226__fDhWbYLGPJQ 提取码:4t3y

其中需要读者根据自己的需求先训练集中分出部分的数据用作验证集。

数据集的读取

1
2
3
4
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class CarvanaDataset(Dataset):
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = os.listdir(image_dir)

def __len__(self):
return len(self.images)

def __getitem__(self, idx):
image_path = os.path.join(self.image_dir, self.images[idx])
mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '_mask.gif'))
image = np.array(Image.open(image_path).convert('RGB'))
mask = np.array(Image.open(mask_path).convert('L'), dtype=np.float32)
mask[mask == 255.0] = 1.0

if self.transform is not None:
augmentations = self.transform(image=image, mask=mask)
image = augmentations['image']
mask = augmentations['mask']

return image, mask

为了后续操作比较方便,直接继承 Dataset,然后返回 image 和对应的 mask。

os.listdir(path) 返回指定路径下的文件(文件夹),在上面代码中,返回整个训练集图片对应的列表。

os.path.join() 该操作直接获得每一张图片对应的储存路径

image.open().convert(), 该函数将图片按照指定的模式转变图片,例如RGB图像,或者灰度图像。(具体的官方释义我还没找到,如果有官网的解释,请赐教)

mask[mask==255.0] = 1.0 方便后续的 sigmoid()函数的计算?(存疑)

模型的构建

首先观察 U-net 模型的构建,在 pool 层之前,总会有有两次卷积,将原图片的通道数增加。因此首先建立类 DoubleConv。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
def __init__(self,in_channels,out_channels):
super(DoubleConv,self).__init__()
self.conv=nn.Sequential(
nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels,out_channels,3,1,1,bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)

def forward(self,x):
return self.conv(x)

下一步,观察 U-net 模型,由于具有对称性显得格外的优雅,而且每一步的处理显得很有规律,正是因为有这样的规律,因此我们在写代码的时候可以不那么繁琐,重复的卷积-池化-卷积-池化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class UNET(nn.Module):
def __init__(
self, in_channels=3, out_channels=1, features=[64,128,256,512]
):
super(UNET,self).__init__()
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2,stride=2)

for feature in features:
self.downs.append(DoubleConv(in_channels,feature))
in_channels = feature

for feature in reversed(features):
self.ups.append(
nn.ConvTranspose2d(feature*2,feature,kernel_size=2,stride=2)
)
self.ups.append(DoubleConv(feature*2, feature))

self.bottleneck = DoubleConv(features[-1], features[-1]*2)
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

上述代码中,将整个 u-Net 模型分为卷积层(下采样)、转置卷积层(上采样)、池化层、瓶颈层以及最后的卷积层。

对于下采样阶段,使用 ModelList(),然后确定每一次卷积的输入、输出通道,然后使用循环结构。

1
2
3
4
5
6
7
features=[64,128,256,512]

self.downs = nn.ModuleList()

for feature in features:
self.downs.append(DoubleConv(in_channels,feature))
in_channels = feature
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def forward(self,x):
skip_connections = []

for down in self.downs:
x = down(x)
skip_connections.append(x)
x = self.pool(x)

x = self.bottleneck(x)
skip_connections = skip_connections[::-1]

for idx in range(0, len(self.ups) ,2):
x = self.ups[idx](x)
skip_connection = skip_connections[idx//2]

if x.shape != skip_connection.shape:
x=TF.resize(x,size=skip_connection.shape[2:])

concat_skip = torch.cat((skip_connection,x), dim=1)
x = self.ups[idx+1](concat_skip)

return self.final_conv(x)

在前向传播的时候,需要注意的是,U-net 每一层都有一个 skip—connnection

skip-connections=[] ,将经过卷积的x保存到列表中,在上采样的时候进行连接

skip_connections=skip_connections[::-1], 保存顺序与使用顺序相反,因此需要反序

concat_skip=torch.cat((skip_connection, x),dim=1) 对两者进行连接

一些实用操作

我觉得我们在写代码的时候,为什么代码结构看的比较凌乱,主要因为我们没有能够将每一个功能、操作整合起来,下面给一个具体的例子。

1
2
3
def save_checkpoint(state,filename='my_checkpoint.pth.tar'):
print('=>Saving checkpoint')
torch.save(state, filename)

将训练模型保存起来的函数

torch.save() 官网torch.save()注释

1
2
3
def load_checkpoint(checkpoint, model):
print('=>Loading checkpoint')
model.load_state_dict(checkpoint['state_dict'])

加载模型,可以将上次未训练完的模型再次进行训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def get_loader(
train_dir,
train_maskdir,
val_dir,
val_maskdir,
batch_size,
train_transform,
val_transform,
num_workers=1,
pin_momory=True,
):
train_ds = CarvanaDataset(
image_dir=train_dir,
mask_dir=train_maskdir,
transform=train_transform
)

train_loader = DataLoader(
train_ds,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_momory,
shuffle=True
)

val_ds = CarvanaDataset(
image_dir=val_dir,
mask_dir=val_maskdir,
transform=val_transform
)

val_loader = DataLoader(
val_ds,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_momory,
shuffle=False
)

return train_loader,val_loader

加载数据的常用函数,其中 CarvanaDataset 自定义,也可以直接使用 Dataset()

DataLoader() 函数中参数:

pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.

训练模型

超参数的确定:

1
2
3
4
5
6
7
8
9
10
11
12
13
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKER = 2
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "data/train/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val/"
VAL_MASK_DIR = "data/val_masks/"

训练函数 train_fn()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def train_fn(loader, model, optimizer, loss_fn, scaler):
loop = tqdm(loader)

for batch_idx, (data, targets) in enumerate(loop):
data = data.to(device=DEVICE)
targets = targets.float().unsqueeze(1).to(device=DEVICE)

#forward
'''混合精度训练'''
with torch.cuda.amp.autocast():
preds = model(data)
loss = loss_fn(preds,targets)

#backward
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

#update tqdm loop
loop.set_postfix(loss=loss.item())

loop = tqdm(loader) 简单理解为快速、可扩展的python进度条

loop.set_postfix() 设置进度条的输出内容

具体关于 tqdm 的使用,本人未深入研究

在上面的代码中,需要注意的是前向传播和反向传播时的代码,与常见的代码不同,因为该代码引入了混合精度训练,具体的请自行查阅。

资源链接

源代码:https://github.com/aladdinpersson/Machine-Learning-Collection

视频资源:https://www.youtube.com/watch?v=IHq1t7NxS8k

哔哩哔哩:【CV教程】从零开始:Pytorch图像分割教程与U-NET_哔哩哔哩_bilibili

一些感想

视频中可以清楚的了解到如何从零开始构建一个模型,如何运行,在使用的过程中一些附加的功能如何实现,对于我这种小白来讲,还是大有裨益的。

而且 up 主的 GitHub 网页上还有许多其他的项目,基本上都是从零开始的,以后可以试试自己去一步步的来参加 kaggle竞赛。

如果还有深度学习刚入门的小伙伴,也可以一起交流学习。

评论