目录
通过爬虫爬取了一些网站的gif图片,结果等全部爬完后发现图片太多了,根本没工夫一张一张去看。于是打算把这些图片做成一个本地图库,方便需要的时候按照内容检索(主要是有时候找不到来源)。
由于gif图片可以看成是由很多张图片叠在一起的,所以在特征提取的时候我的想法是先按照等间距取一定帧数的图片(我的数据里帧数最少的只有5帧,因此我每张图片都等间隔的取了5帧,特殊情况的用最后一帧补齐5帧),再把这若干张图片叠成一个batch输入到网络中,最后的输出看作是整个gif的特征。
在实际操作中,发现不能直接通过opencv的imread来读取gif图片,要使用VideoCapture像读取视频文件一样按帧读取。而且我在使用中还发现,使用 CAP_PROP_FRAME_COUNT 不能返回gif的帧数,会返回负数(不知道是我环境的原因还是其他的)。
def gif_split_to(gpath: str, fstep: int) -> list:
'''
分割gif图片并挑选特定帧数
:param gpath: gif路径
:param fstep: 目标帧数
:return: 一个列表,包含fstep个帧数
'''
# 获取所有帧数
frames = []
cap = cv2.VideoCapture(gpath)
ret, frame = cap.read()
while ret:
frames.append(frame)
ret, frame = cap.read()
cap.release()
fnum = len(frames)
step_frame = math.ceil(fnum / fstep)
# 防止步长大于帧数总数
if step_frame <= 0:
step_frame = 1
ret = list()
# 等间距取帧
for idx in range(0, len(frames)):
if idx % step_frame == 0 and ret:
frame = cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB)
ret.append(frame)
# 重复最后一帧补齐到目标帧数
while len(ret) < g_gif_need:
ret.append(frames[len(frames) - 1])
frames = None
return ret
再把图片列表转为tensor,就可以输入到网络里面了:
def gif_to_tensor(gpath: str, fstep: int) -> torch.FloatTensor:
'''
gif转目标维度tensor
:param gpath:gif路径
:param fstep:取得目标帧数
:return:一个tensor
'''
list_img = gif_split_to(gpath, fstep)
gif_np = [np.array(x).transpose((2, 0, 1)) for x in list_img]
gif_tensor = torch.FloatTensor(np.array(gif_np))
return gif_tensor
前文提到,使用ResNet作为特征提取器,具体我采用的是ResNet18。
ResNet最后有个全局平均池化和全连接层,是用于分类的,这里只是用它提取特征,所以可以直接去掉。并且我也具体实验了有后三层和没后三层的效果,发现差别还是有的,具体如下:
去掉全连接层和全局平均池化层,从返回的top3(1.gif是目标图片)结果可以看出,目标图片和非目标图片的相似度差距是很大的:
[['1.gif', 1.0], ['2.gif', 0.6843166351318359], ['3.gif', 0.6693203002214432]]
直接使用原网络,从返回的top3(1.gif是目标图片)结果可以看出,目标图片和非目标图片的相似度差距不是很明显:
[['1.gif', 1.0], ['2.gif', 0.9977700412273407], ['3.gif', 0.9970085620880127]]
要调整网络,一种方法是加载网络后,使用下面这种语句修改使用层数:
net = torchvision.models.resnet50(pretrained=True)
net = nn.Sequential(*list(resnet_50_s.children())[:-2])
我使用的方法是直接在源代码上修改,具体是找到torchvison里的ResNet实现,单独放到本地项目里,再在forward里面注释掉:
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# 注释掉
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
# x = self.fc(x)
return x
预训练参数下载地址可以参考:
model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', }
使用方法可以参考下面的语句,这里要注意的是下载的参数要用对应的模型去加载,否则会报错
net = resnet18(pretrained=False).to(device)
net.load_state_dict(torch.load(g_pth_path)) # g_pth_path 为下载的参数保存地址
前面提到把从gif图片中提取的图片数组作为一个batch输入到网络中(一般net的输入shape要是[batch, channel, h, w]),再把获取到的特征转为numpy,使用np.savez_compressed函数保存,这样保存的数组是经过压缩的,可以节省硬盘空间。
inputs = gif_to_tensor(gif_path, g_gif_need).to(device) # g_gif_need 为每个gif要取的帧数
out = net(inputs)
out_np = out.cpu().numpy()
save_name = str(gif_path).split("\\")[-1].split('.')[0]
np.savez_compressed(f"./tmp/{save_name}", a=out_np)
要搜索的gif图片也是要经过前面的预处理到特征提取过程,然后加载本地的特征库,一一比对余弦相似度。余弦相似度的比较函数是参考的,具体作者是谁就不知道了:
def mtx_similar1(arr1: np.ndarray, arr2: np.ndarray) -> float:
'''
计算矩阵相似度的一种方法。将矩阵展平成向量,计算向量的乘积除以模长。
注意有展平操作。
:param arr1:矩阵1
:param arr2:矩阵2
:return:实际是夹角的余弦值,ret = (cos+1)/2
'''
farr1 = arr1.ravel()
len1 = len(farr1)
len2 = len(arr2)
if len1 > len2:
farr1 = farr1[:len2]
else:
arr2 = arr2[:len1]
numer = np.sum(farr1 * arr2)
denom = np.sqrt(np.sum(farr1 ** 2) * np.sum(arr2 ** 2))
similar = numer / denom
return (similar + 1) / 2
程序总体上是粗糙的,还有很多改进的地方,比如:
完整代码等整理后考虑放到github上。
GitHub - ashortname/localGifSearcher: Build a local GIF feature library for search by image.