当前位置: 首页 > 知识库问答 >
问题:

在ResNet50中对pytorch中包含10个类的图像进行分类时遇到此错误。我的代码是:

郤瀚
2023-03-14

这是我正在实现的代码:我使用CalTech256数据集的子集对10种不同动物的图像进行分类。我们将介绍数据集准备、数据扩充,然后介绍构建分类器的步骤。

def train_and_validate(model, loss_criterion, optimizer, epochs=25):
    '''
    Function to train and validate
    Parameters
        :param model: Model to train and validate
        :param loss_criterion: Loss Criterion to minimize
        :param optimizer: Optimizer for computing gradients
        :param epochs: Number of epochs (default=25)

    Returns
        model: Trained Model with best validation accuracy
        history: (dict object): Having training loss, accuracy and validation loss, accuracy
    '''

    start = time.time()
    history = []
    best_acc = 0.0

    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1, epochs))

        # Set to training mode
        model.train()

        # Loss and Accuracy within the epoch
        train_loss = 0.0
        train_acc = 0.0

        valid_loss = 0.0
        valid_acc = 0.0

        for i, (inputs, labels) in enumerate(train_data_loader):

            inputs = inputs.to(device)
            labels = labels.to(device)

            # Clean existing gradients
            optimizer.zero_grad()

            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)

            # Compute loss
            loss = loss_criterion(outputs, labels)

            # Backpropagate the gradients
            loss.backward()

            # Update the parameters
            optimizer.step()

            # Compute the total loss for the batch and add it to train_loss
            train_loss += loss.item() * inputs.size(0)

            # Compute the accuracy
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))

            # Convert correct_counts to float and then compute the mean
            acc = torch.mean(correct_counts.type(torch.FloatTensor))

            # Compute total accuracy in the whole batch and add to train_acc
            train_acc += acc.item() * inputs.size(0)

            #print("Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}".format(i, loss.item(), acc.item()))


        # Validation - No gradient tracking needed
        with torch.no_grad():

            # Set to evaluation mode
            model.eval()

            # Validation loop
            for j, (inputs, labels) in enumerate(valid_data_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass - compute outputs on input data using the model
                outputs = model(inputs)

                # Compute loss
                loss = loss_criterion(outputs, labels)

                # Compute the total loss for the batch and add it to valid_loss
                valid_loss += loss.item() * inputs.size(0)

                # Calculate validation accuracy
                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))

                # Convert correct_counts to float and then compute the mean
                acc = torch.mean(correct_counts.type(torch.FloatTensor))

                # Compute total accuracy in the whole batch and add to valid_acc
                valid_acc += acc.item() * inputs.size(0)

                #print("Validation Batch number: {:03d}, Validation: Loss: {:.4f}, Accuracy: {:.4f}".format(j, loss.item(), acc.item()))

        # Find average training loss and training accuracy
        avg_train_loss = train_loss/train_data_size 
        avg_train_acc = train_acc/train_data_size

        # Find average training loss and training accuracy
        avg_valid_loss = valid_loss/valid_data_size 
        avg_valid_acc = valid_acc/valid_data_size

        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])

        epoch_end = time.time()

        print("Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))

        # Save if the model has best accuracy till now
        torch.save(model, dataset+'_model_'+str(epoch)+'.pt')

    return model, history

# Load pretrained ResNet50 Model
resnet50 = models.resnet50(pretrained=True)
#resnet50 = resnet50.to('cuda:0')


# Freeze model parameters
for param in resnet50.parameters():
    param.requires_grad = False
# Change the final layer of ResNet50 Model for Transfer Learning
fc_inputs = resnet50.fc.in_features

resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, num_classes), # Since 10 possible outputs
    nn.LogSoftmax(dim=1) # For using NLLLoss()
)

# Convert model to be used on GPU
# resnet50 = resnet50.to('cuda:0')

# Change the final layer of ResNet50 Model for Transfer Learning
fc_inputs = resnet50.fc.in_features

resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, num_classes), # Since 10 possible outputs
    nn.LogSoftmax(dienter code herem=1) # For using NLLLoss()
)

# Convert model to be used on GPU
# resnet50 = resnet50.to('cuda:0')`enter code here`

错误是:

运行时错误回溯(最近一次调用last)in()6#为25个历元7个历元=30训练模型----

列车内和验证(模型、损耗标准、优化器、EPOCH)43 44#计算损耗---

~\anaconda3\lib\site-包\torch\nn\模块\module.pyhtml" target="_blank">调用(自我,*输入,**kwargs)539结果=自我。_slow_forward(*输入,**kwargs)540其他:-

~\anaconda3\lib\site-包\torch\nn\模块\loss.py在前进(自我,输入,目标)202 203 def前进(自我,输入,目标):-

~\Anaconda3\lib\site packages\torch\nn\functional。nll_损失中的py(输入、目标、重量、尺寸平均值、忽略指数、减少、减少)1836。格式(input.size(0)、target.size(0)))1837如果dim==2:-

运行时错误:断言'cur\u target

共有1个答案

居和顺
2023-03-14

当数据集中有不正确的标签,或者标签是1索引(而不是0索引)时,就会发生这种情况。从错误消息中,cur_target必须小于类的总数(10)。若要验证此问题,请检查数据集中的最大和最小标签。如果数据确实是1索引的,只需从所有注释中减去一个,就可以了。

注意,另一个可能的原因是数据中存在一些-1标签。一些(特别是较旧的)数据集使用-1作为错误/可疑标签的指示。如果你发现这样的标签,丢弃就好了。

 类似资料:
  • 渲染图像时,我会得到下面的运行时错误。奇怪的是,图像在设备上显示得很好,在其他页面上显示时没有出现此错误。错误消息没有提供任何有用的信息。 下面是一段代码,后跟转储:

  • 问题内容: 我想创建一个PHP类,可以说Myclass.php。现在在该类中,我只想定义类本身和一些实例变量。但是所有方法都必须来自Myclass_methods.php文件。我可以只将该文件包含到班级正文中吗? 我有充分的理由为什么要分开这个。简而言之,我将拥有一个后端,在其中可以更改类的业务逻辑,而所有其他内容必须保持不变。系统为我维护所有ORM和其他内容。 但是,如果这不是一个好主意,则最好

  • 错误:将字节码转换为dex时出错:原因:com . Android . dex . dex异常:多个dex文件定义了Lcom/Google/Android/GMS/internal/measurement/zza bn;:app:transformClassesWithDexForDebug失败错误:任务执行失败”:app:transformClassesWithDexForDebug。com .

  • 当我试图运行这个骨架代码时,我一直收到这个错误。我试图在Eclipse中使用OpenGL。我不知道是什么导致了这个问题。我如何解决这个问题?我也已经将jar文件添加到用户库中。 代码: 这是我一直在犯的错误。 错误:错误1 错误2 Plhd-19/>(JComponent. java: 4839)在java.桌面/java. awt.容器. addNotify(容器. java: 2804)在ja

  • 本文向大家介绍使用Keras预训练模型ResNet50进行图像分类方式,包括了使用Keras预训练模型ResNet50进行图像分类方式的使用技巧和注意事项,需要的朋友参考一下 Keras提供了一些用ImageNet训练过的模型:Xception,VGG16,VGG19,ResNet50,InceptionV3。在使用这些模型的时候,有一个参数include_top表示是否包含模型顶部的全连接层,如