当前位置: 首页 > 工具软件 > Gif123 > 使用案例 >

本地gif图片以图搜图,构建本地gif图库

仲鸿风
2023-12-01

目录

1.项目起因

2.总体思路

3.用到的环境和包

4.具体实现

1.Gif预处理

2.网络调整

3.下载预训练参数并加载

4.提取特征并保存

5.gif搜索

5.结束语

6.完整代码【更新】


1.项目起因

        通过爬虫爬取了一些网站的gif图片,结果等全部爬完后发现图片太多了,根本没工夫一张一张去看。于是打算把这些图片做成一个本地图库,方便需要的时候按照内容检索(主要是有时候找不到来源)。

2.总体思路

  1. 使用神经网络对图片进行特征提取,打算直接采用预训练的ResNet作为特征提取器;
  2. 搜索的时候,通过计算特征向量之间的余弦相似度来确定两张gif之间的相似度。

3.用到的环境和包

  1. Python 3.9;
  2. Pytorch 1.8;
  3. Opencv;
  4. Numpy

4.具体实现

1.Gif预处理

        由于gif图片可以看成是由很多张图片叠在一起的,所以在特征提取的时候我的想法是先按照等间距取一定帧数的图片(我的数据里帧数最少的只有5帧,因此我每张图片都等间隔的取了5帧,特殊情况的用最后一帧补齐5帧),再把这若干张图片叠成一个batch输入到网络中,最后的输出看作是整个gif的特征。

        在实际操作中,发现不能直接通过opencv的imread来读取gif图片,要使用VideoCapture像读取视频文件一样按帧读取。而且我在使用中还发现,使用 CAP_PROP_FRAME_COUNT 不能返回gif的帧数,会返回负数(不知道是我环境的原因还是其他的)。

def gif_split_to(gpath: str, fstep: int) -> list:
    '''
    分割gif图片并挑选特定帧数
    :param gpath: gif路径
    :param fstep: 目标帧数
    :return: 一个列表,包含fstep个帧数
    '''
    # 获取所有帧数
    frames = []
    cap = cv2.VideoCapture(gpath)
    ret, frame = cap.read()
    while ret:
        frames.append(frame)
        ret, frame = cap.read()
    cap.release()
    fnum = len(frames)
    step_frame = math.ceil(fnum / fstep)
    # 防止步长大于帧数总数
    if step_frame <= 0:
        step_frame = 1
    ret = list()
    # 等间距取帧
    for idx in range(0, len(frames)):
        if idx % step_frame == 0 and ret:
            frame = cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB)
            ret.append(frame)
    # 重复最后一帧补齐到目标帧数
    while len(ret) < g_gif_need:
        ret.append(frames[len(frames) - 1])
    frames = None
    return ret

        再把图片列表转为tensor,就可以输入到网络里面了:

def gif_to_tensor(gpath: str, fstep: int) -> torch.FloatTensor:
    '''
    gif转目标维度tensor
    :param gpath:gif路径
    :param fstep:取得目标帧数
    :return:一个tensor
    '''
    list_img = gif_split_to(gpath, fstep)
    gif_np = [np.array(x).transpose((2, 0, 1)) for x in list_img]
    gif_tensor = torch.FloatTensor(np.array(gif_np))
    return gif_tensor

2.网络调整

        前文提到,使用ResNet作为特征提取器,具体我采用的是ResNet18。

        ResNet最后有个全局平均池化和全连接层,是用于分类的,这里只是用它提取特征,所以可以直接去掉。并且我也具体实验了有后三层和没后三层的效果,发现差别还是有的,具体如下:

        去掉全连接层和全局平均池化层,从返回的top3(1.gif是目标图片)结果可以看出,目标图片和非目标图片的相似度差距是很大的:

        [['1.gif', 1.0], ['2.gif', 0.6843166351318359], ['3.gif', 0.6693203002214432]]

        直接使用原网络,从返回的top3(1.gif是目标图片)结果可以看出,目标图片和非目标图片的相似度差距不是很明显:

        [['1.gif', 1.0], ['2.gif', 0.9977700412273407], ['3.gif', 0.9970085620880127]]

        要调整网络,一种方法是加载网络后,使用下面这种语句修改使用层数:

    net = torchvision.models.resnet50(pretrained=True)
    net = nn.Sequential(*list(resnet_50_s.children())[:-2])

        我使用的方法是直接在源代码上修改,具体是找到torchvison里的ResNet实现,单独放到本地项目里,再在forward里面注释掉:

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # 注释掉
        # x = self.avgpool(x)
        # x = torch.flatten(x, 1)
        # x = self.fc(x)

        return x

3.下载预训练参数并加载

        预训练参数下载地址可以参考:

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

        使用方法可以参考下面的语句,这里要注意的是下载的参数要用对应的模型去加载,否则会报错

    net = resnet18(pretrained=False).to(device)
    net.load_state_dict(torch.load(g_pth_path)) # g_pth_path 为下载的参数保存地址

4.提取特征并保存

        前面提到把从gif图片中提取的图片数组作为一个batch输入到网络中(一般net的输入shape要是[batch, channel, h, w]),再把获取到的特征转为numpy,使用np.savez_compressed函数保存,这样保存的数组是经过压缩的,可以节省硬盘空间。

    inputs = gif_to_tensor(gif_path, g_gif_need).to(device) # g_gif_need 为每个gif要取的帧数
    out = net(inputs)
    out_np = out.cpu().numpy()
    save_name = str(gif_path).split("\\")[-1].split('.')[0]
    np.savez_compressed(f"./tmp/{save_name}", a=out_np)

5.gif搜索

        要搜索的gif图片也是要经过前面的预处理到特征提取过程,然后加载本地的特征库,一一比对余弦相似度。余弦相似度的比较函数是参考的,具体作者是谁就不知道了:

def mtx_similar1(arr1: np.ndarray, arr2: np.ndarray) -> float:
    '''
    计算矩阵相似度的一种方法。将矩阵展平成向量,计算向量的乘积除以模长。
    注意有展平操作。
    :param arr1:矩阵1
    :param arr2:矩阵2
    :return:实际是夹角的余弦值,ret = (cos+1)/2
    '''
    farr1 = arr1.ravel()
    len1 = len(farr1)
    len2 = len(arr2)
    if len1 > len2:
        farr1 = farr1[:len2]
    else:
        arr2 = arr2[:len1]
    numer = np.sum(farr1 * arr2)
    denom = np.sqrt(np.sum(farr1 ** 2) * np.sum(arr2 ** 2))
    similar = numer / denom
    return (similar + 1) / 2

5.结束语

        程序总体上是粗糙的,还有很多改进的地方,比如:

  1. gif关键帧的提取;
  2. gif的尺寸不是一样的,是否要归一化处理;
  3. 除了余弦相似度,还有其他计算方法吗;
  4. 更合适的网络模型。

完整代码等整理后考虑放到github上。

6.完整代码【更新】

GitHub - ashortname/localGifSearcher: Build a local GIF feature library for search by image.

 类似资料: