当前位置: 首页 > 知识库问答 >
问题:

python - 飞桨参数类型报错?

羊舌诚
2024-05-20

class Reader(Dataset):

def __init__(self, data_path: str, is_val: bool = False):    super().__init__()    self.data_path = data_path    with open(os.path.join(self.data_path, "label_dict.txt"), "r", encoding="utf-8") as f:        self.info = ast.literal_eval(f.read())    self.img_paths = [os.path.join(self.data_path, img_name) for img_name in self.info]    self.img_paths = self.img_paths[-250:] if is_val else self.img_paths[:-250]def __getitem__(self, index):    file_path = self.img_paths[index]    file_name = os.path.basename(file_path)    with open(file_path, 'rb') as f:        img = Image.open(f).convert('RGB')        img = np.array(img, dtype="float32") / 255        img = img.reshape((IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W))    label = [CHAR_TO_IDX[char] for char in self.info[file_name]]    label = np.array(label, dtype="int64")  # 确保 label 是 int64 类型    label_length = len(label)    input_length = np.array([IMAGE_SHAPE_W], dtype="int64")  # 确保 input_length 是 int64 类型    return img, label, label_length, input_lengthdef __len__(self):    return len(self.img_paths)

...

class CTCLoss(paddle.nn.Layer):

def __init__(self):    super().__init__()def forward(self, ipt, label, label_lengths, input_lengths):    # 转换 ipt 的维度顺序,并确保是 float32 类型    ipt = paddle.transpose(ipt, perm=[1, 0, 2])    ipt = paddle.cast(ipt, 'float32')    # 确保 label, label_lengths, 和 input_lengths 是 int64 类型    label = paddle.to_tensor(label, dtype='int64')    label_lengths = paddle.to_tensor(label_lengths, dtype='int64')    input_lengths = paddle.to_tensor(input_lengths, dtype='int64')    # 计算损失,确保 blank 索引正确    loss = paddle.nn.functional.ctc_loss(ipt, label, input_lengths, label_lengths, blank=0)    return loss

来回报错:
ValueError: (InvalidArgument) The type of data we are trying to retrieve (int32) does not match the type of data (int64) currently contained in the container.
[Hint: Expected dtype() == phi::CppTypeToDataType<T>::Type(), but received dtype():9 != phi::CppTypeToDataType<T>::Type():7.] (at ..\paddle\phi\core\dense_tensor.cc:171)
修改完要输入的类型为int32后运行又报错
ValueError: (InvalidArgument) The type of data we are trying to retrieve (int64) does not match the type of data (int32) currently contained in the container.
[Hint: Expected dtype() == phi::CppTypeToDataType<T>::Type(), but received dtype():7 != phi::CppTypeToDataType<T>::Type():9.] (at ..\paddle\phi\core\dense_tensor.cc:185)
我又改回相关参数为int64后又报错

形成了一个闭环,求大佬解决,这是一个通过飞桨训练验证码识别模型的代码

共有1个答案

漆雕欣德
2024-05-20

根据您提供的代码和错误信息,您遇到的问题与数据类型的转换和匹配有关。飞桨(PaddlePaddle)框架期望某些操作或函数的输入参数具有特定的数据类型。在您的例子中,您可能在某个地方将数据类型错误地设置为了 int32,而飞桨期望的是 int64

错误提示表明,在某些地方,您正在尝试获取或操作与容器当前包含的数据类型不匹配的数据。这通常发生在将 NumPy 数组转换为 PaddlePaddle 张量时,或者在执行某些需要特定数据类型的操作时。

让我们看看如何修复这个问题。首先,请确保您的 Reader 类中的 labelinput_length 变量在转换为 PaddlePaddle 张量之前已经是 int64 类型。您的代码中已经包含了将这两个变量转换为 int64 的步骤,这是正确的。

然后,在 CTCLoss 类的 forward 方法中,您也确保了 iptlabellabel_lengthsinput_lengths 都被转换为了正确的类型。但是,错误可能发生在其他地方。

请检查以下几点:

  1. 确保您的全局变量 IMAGE_SHAPE_CIMAGE_SHAPE_HIMAGE_SHAPE_WCHAR_TO_IDXReader 类之外被正确定义,并且它们的值没有问题。
  2. Reader 类的 __getitem__ 方法中,您已经将 labelinput_length 转换为了 int64 类型的 NumPy 数组。当您将这些数组传递给 CTCLoss 类的 forward 方法时,您又将它们转换为了 PaddlePaddle 张量,并确保它们是 int64 类型。这是正确的。
  3. 检查您的模型其他地方是否还有数据类型不匹配的问题。特别是,检查您在训练循环中如何处理 imglabellabel_lengthinput_length。确保在将它们传递给模型或损失函数之前,它们的数据类型是正确的。
  4. 如果您在训练循环中使用了 DataLoader,请检查 DataLoader 是否正确地处理了数据类型。有时,DataLoader 可能会自动将数据转换为不同的类型,这可能会导致类型不匹配的问题。

如果您已经检查了以上所有点,并且仍然遇到问题,那么请提供更详细的代码和完整的错误堆栈信息,以便进一步分析问题。

此外,请确保您的 PaddlePaddle 版本是最新的,或者至少是一个稳定的版本,因为旧版本可能包含已知的 bug 或不稳定的特性。

最后,如果您仍然无法解决问题,请考虑在 PaddlePaddle 的官方论坛或 GitHub 仓库上寻求帮助,那里有许多经验丰富的开发者和用户可以帮助您解决问题。

 类似资料:
  • 本文向大家介绍Python函数参数类型*、**的区别,包括了Python函数参数类型*、**的区别的使用技巧和注意事项,需要的朋友参考一下 刚开始学习python,python相对于java确实要简洁易用得多。内存回收类似hotspot的可达性分析, 不可变对象也如同java得Integer类型,with函数类似新版本C++的特性,总体来说理解起来比较轻松。只是函数部分参数的"*"与"**",闭包

  • 一、泛型 Scala 支持类型参数化,使得我们能够编写泛型程序。 1.1 泛型类 Java 中使用 <> 符号来包含定义的类型参数,Scala 则使用 []。 class Pair[T, S](val first: T, val second: S) { override def toString: String = first + ":" + second } object ScalaAp

  • 我试图理解为什么当我在MyModel中为T使用更高类型的参数时,以下代码无法编译 但是如果我把它改成< code > new Bar[my model[Any]]它就会编译。这是为什么呢?

  • 问题内容: 有时需要检查Python中的参数。例如,我有一个函数可以接受网络中其他节点的地址作为原始字符串地址,也可以接受封装其他节点信息的类Node。 我使用type()函数,如下所示: 这是这样做的好方法吗? 更新1: Python 3具有函数参数的注释。可以使用以下工具将其用于类型检查:http ://mypy-lang.org/ 问题答案: 使用。样品:

  • 问题内容: 我有一个方法以a 作为参数。 在中,我如何知道a 是还是a 是? 问题答案: 根据用户omain的回答“如果使用<?>,则意味着您将不会在任何地方使用参数化类型。要么转到特定类型(在您的情况下,似乎是),要么转到非常通用的“ 另外,我相信如果您使用问号,编译器将在运行时(类型;有效Java的第119页)消除类型不匹配的情况,绕过擦除,并有效地消除了使用泛型类型所带来的好处? 要回答发问

  • 我对这个flutter简单图表代码有问题。在我尝试运行代码时显示此错误。有人能帮我吗...... 参数类型'List 这是代码示例: