# -*- coding: utf-8 -*-
# Implementation of the residual network from https://arxiv.org/pdf/1512.03385.pdf
# (see section 4.2 for the CIFAR-10 architecture this model is derived from).
# Parts of the code were adapted from:
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
6
import os
import time

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torch.nn import DataParallel
19
# Extra keyword arguments shared by the train and test DataLoaders:
# a single worker process, and pinned (page-locked) host memory for the
# returned batches.
kwargs = dict(num_workers=1, pin_memory=True)

# Custom dataset below: yields each image together with its label vector.
22
23
def default_loader(path):
    """Load the image file at *path* and return it as an RGB ``PIL.Image``.

    The file is opened explicitly and closed before returning. ``convert``
    reads the pixel data while the file is still open, so the returned
    image does not hold on to the file handle.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class myImageFloder(data.Dataset):  # subclass of torch.utils.data.Dataset
    """Dataset of ``(image, label_tensor)`` pairs listed in a label file.

    Label file format (as parsed in ``__init__``):

    * line 1: whitespace-separated label/column names (see ``getName``);
    * every later non-blank line: ``<image path relative to root> <float> ...``.

    Entries whose image file does not exist under *root* are skipped.

    Args:
        root: directory the image paths in the label file are relative to.
        label: path of the label file.
        transform: optional callable applied to the loaded image.
        target_transform: optional callable applied to the label tensor.
        loader: callable mapping a full image path to an image object.
    """

    def __init__(self, root, label, transform=None, target_transform=None, loader=default_loader):
        imgs = []
        class_names = []
        # `with` guarantees the label file is closed, even if parsing raises.
        with open(label) as fh:
            for line_no, line in enumerate(fh):
                if line_no == 0:
                    # Header line. split() (not split(' ')) so repeated
                    # spaces or tabs do not produce empty names.
                    class_names = line.split()
                    continue
                fields = line.split()
                if not fields:
                    # Blank line (e.g. a trailing newline): nothing to parse.
                    continue
                fn = fields[0]
                if os.path.isfile(os.path.join(root, fn)):
                    # Each entry: (relative path, tuple of float label values).
                    imgs.append((fn, tuple(float(v) for v in fields[1:])))
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, target)`` for the *index*-th listed entry.

        ``target`` is ``torch.Tensor(label_values)`` (float32, 1-D), passed
        through ``target_transform`` when one was given.
        """
        fn, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        target = torch.Tensor(label)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of entries whose image file exists."""
        return len(self.imgs)

    def getName(self):
        """Return the label names read from the label file's header line."""
        return self.classes
# Only ToTensor(): converts each PIL image to a float tensor; no augmentation.
mytransform = transforms.Compose([transforms.ToTensor()])

# All data lives under one directory; these resolve to the same absolute
# paths as before.
data_dir = "/home/ying/shiyongjie/rjp/generate_distortion_image_2016_03_15/0_Distorted_Image"
train_data_root = os.path.join(data_dir, "Training")
test_data_root = os.path.join(data_dir, "Testing")
train_label = os.path.join(data_dir, "NameList_train.txt")
test_label = os.path.join(data_dir, "NameList_test.txt")

batch_size = 64

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    myImageFloder(root=train_data_root, label=train_label, transform=mytransform),
    batch_size=batch_size, shuffle=True, **kwargs)

# Evaluation data is NOT shuffled, so test results are reproducible
# batch-by-batch.
test_loader = torch.utils.data.DataLoader(
    myImageFloder(root=test_data_root, label=test_label, transform=mytransform),
    batch_size=batch_size, shuffle=False, **kwargs)
79
# 3x3 Convolution
def conv3x3(in_channels, out_channels, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With padding 1 the spatial size is preserved at stride 1 and becomes
    ceil(H/2) x ceil(W/2) at stride 2. No bias is used: every call site in
    this file feeds the result into a BatchNorm layer.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **conv_options)
85
# Residual Block
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``, where
    ``shortcut`` is ``downsample`` when one is given and the identity
    otherwise.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by both convolutions.
        stride: stride of the first convolution (2 roughly halves H and W).
        downsample: optional module mapping ``x`` to the main branch's output
            shape. Needed whenever ``stride != 1`` or the channel count
            changes; otherwise the shapes will generally not match for the
            residual addition.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)  # 3x3 kernel
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Compare against None explicitly: truthiness of an nn.Module is not
        # a reliable "is set" test (nn.Sequential falls back to __len__).
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
110
# ResNet Module
class ResNet(nn.Module):
    """Five-stage residual network with a linear head.

    Pipeline:
    * a 3x3 stem convolution (3 -> 16 channels) followed by BN and ReLU;
    * five residual stages of widths 16/32/64/128/256, where stages 2-5
      use stride 2;
    * an 8x8 average pool (stride 8);
    * a linear layer with ``num_classes`` outputs.

    Args:
        block: residual block class, constructed as
            ``block(in_channels, out_channels, stride, downsample)``.
        layers: per-stage block counts. NOTE(review): only ``layers[0]``
            (stages 1-2) and ``layers[1]`` (stages 3-5) are read;
            ``layers[2]`` is never used. Confirm this mapping is intended.
        num_classes: output features of the final linear layer.

    NOTE(review): ``fc`` expects 256*2*2 inputs, so the last feature map
    must pool (kernel 8, stride 8) down to 2x2. For example, 256x256 inputs
    give 16x16 after the four stride-2 stages. Other input sizes may fail;
    confirm the expected image size.
    """

    def __init__(self, block, layers, num_classes=1):
        super(ResNet, self).__init__()
        # Channel count fed into the next make_layer call (updated by it).
        self.in_channels = 16
        # Construction order fixes both the state_dict key order and the
        # sequence of random parameter initialisations -- keep it stable.
        # Stem.
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Residual stages (width, block count, stride).
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[0], stride=2)
        self.layer3 = self.make_layer(block, 64, layers[1], stride=2)
        self.layer4 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer5 = self.make_layer(block, 256, layers[1], stride=2)
        # Head.
        self.avg_pool = nn.AvgPool2d(kernel_size=8, stride=8)
        self.fc = nn.Linear(256 * 2 * 2, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks of width ``out_channels``.

        The first block applies ``stride``. When the stride or the channel
        count changes, it also gets a projection shortcut (3x3 conv + BN).
        The remaining blocks keep the shape unchanged.

        Side effect: sets ``self.in_channels = out_channels``.
        """
        needs_projection = stride != 1 or self.in_channels != out_channels
        shortcut = None
        if needs_projection:
            # Project e.g. 256x256x16 -> 128x128x32 so the residual add lines up.
            shortcut = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        stage = [block(self.in_channels, out_channels, stride, shortcut)]
        self.in_channels = out_channels
        # The first block is already built, hence blocks - 1 more.
        stage.extend(block(out_channels, out_channels) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            out = stage(out)
        out = self.avg_pool(out)
        # Flatten (N, C, H, W) -> (N, C*H*W) for the linear head.
        return self.fc(out.view(out.size(0), -1))
154
# Build the network (block counts [3, 3, 3], default num_classes=1) and wrap
# it in DataParallel, which by default splits each batch across all visible
# GPUs. Module.cuda() returns the module itself, so the call can be chained.
resnet = DataParallel(ResNet(ResidualBlock, [3, 3, 3])).cuda()
# Loss and Optimizer: mean-squared error against the float label vectors,
# optimised with Adam starting from a learning rate of 1e-3.
criterion = nn.MSELoss()
lr = 1e-3
optimizer = torch.optim.Adam(params=resnet.parameters(), lr=lr)
# Training
num_epochs = 50
# time.clock() was removed in Python 3.8; perf_counter() is the
# recommended replacement for measuring elapsed time.
start = time.perf_counter()
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Variable() has been a no-op since PyTorch 0.4; plain tensors
        # already carry autograd information.
        images = images.cuda()
        labels = labels.cuda()

        # Forward + Backward + Optimize
        optimizer.zero_grad()
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            # Report the real epoch/iteration totals (previously hard-coded
            # as 80 and 500). loss is a 0-dim tensor, so .item() replaces
            # loss.data[0], which raises IndexError in modern PyTorch.
            print("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f"
                  % (epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

    # Decaying Learning Rate: divide by 3 every 20 epochs. Update the existing
    # optimizer in place; constructing a new Adam would silently discard its
    # running moment estimates.
    if (epoch + 1) % 20 == 0:
        lr /= 3
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

elapsed = time.perf_counter() - start
print("time used:", elapsed)
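## Test
# NOTE(review): this disabled evaluation treats the outputs as class scores
# (torch.max over dim 1) and counts exact matches. The model, however, is
# trained as a regressor (MSELoss, num_classes=1), so this "accuracy" would
# not be meaningful. Use e.g. the mean MSE over test_loader instead.
#correct = 0
#total = 0
#for images, labels in test_loader:
#    images = Variable(images.cuda())
#    outputs = resnet(images)
#    _, predicted = torch.max(outputs.data, 1)
#    total += labels.size(0)
#    correct += (predicted.cpu() == labels).sum()
#
#print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))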
197
# Save the Model: only the state_dict (parameters and buffers) is written,
# not the module object itself.
# NOTE(review): `resnet` is a DataParallel wrapper, so the saved keys are
# expected to carry a "module." prefix. Load them into a DataParallel-wrapped
# ResNet, or strip the prefix first.
torch.save(obj=resnet.state_dict(), f='resnet.pkl')