当前位置: 首页 > 工具软件 > SpikingJelly > 使用案例 >

spikingjelly的20201221版本跑通ANN2SNN

刘丰羽
2023-12-01

使用自己的数据集最大的一个问题在于,源代码使用的是MNIST数据集,而我使用的是自己构建的图片集。
主要区别在以下几点:
1.
源代码数据加载使用的方法为:

    # train_data_dataset = torchvision.datasets.MNIST(
    #     root=dataset_dir,
    #     train=True,
    #     transform=torchvision.transforms.ToTensor(),
    #     download=True)
    # train_data_loader = torch.utils.data.DataLoader(
    #     train_data_dataset,
    #     batch_size=batch_size,
    #     shuffle=True,
    #     drop_last=True)
    # test_data_loader = torch.utils.data.DataLoader(
    #     dataset=torchvision.datasets.MNIST(
    #         root=dataset_dir,
    #         train=False,
    #         transform=torchvision.transforms.ToTensor(),
    #         download=True),
    #     batch_size=100,
    #     shuffle=True,
    #     drop_last=False)

使用自己的数据,则使用ImageFolder这种方法:

data_transform = transforms.Compose([
        transforms.Resize(32),  # 等比例转换为32长度
        transforms.CenterCrop(28),  # 中心裁剪为28*28的
        transforms.Grayscale(num_output_channels=1),  # 读取单通道
        transforms.ToTensor(),
        # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # 读入图片
    train_dataset = datasets.ImageFolder(root='F:\\my_code\\data\\0dB\\train',
                                         transform=data_transform,
                                         )
    train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True
                                               )
    # print(train_dataset[0][0].size())
    test_dataset = datasets.ImageFolder(root='F:\\my_code\\data\\0dB\\test',
                                         transform=data_transform,
                                         )
    test_data_loader = torch.utils.data.DataLoader(test_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               )

2.以上通过MNIST转换为ImageFolder的方式后,出现的一个衍生问题就是,源代码中的train_data_dataset里面是有data这个属性的,但是ImageFolder里面就没有。
解决方法:

#print(dataset[0][1])# 第一维是第几张图,第二维为1返回label
#print(dataset[0][0]) # 为0返回图片数据

首先定义一个空的张量norm_set = torch.zeros(norm_set_len,28,28)
我是使用一个for循环来依次取出每个图片,并把这些值赋予到这个空张量里面

    for ii in range(norm_set_len):
        norm_set[ii]=train_dataset[ii][0]

接下来就可以使用啦

 类似资料: