当前位置: 首页 > 知识库问答 >
问题:

PyTorch dataloader显示字符串数据集的奇怪行为

梁豪
2023-03-14

我正在处理NLP问题,正在使用PyTorch。由于某些原因,我的数据加载器返回了格式错误的批。我有由句子和整数标签组成的输入数据。这些句子可以是句子列表,也可以是标记列表。稍后我将在下游组件中将标记转换为整数。

list_labels = [ 0, 1, 0]

# List of sentences.
list_sentences = [ 'the movie is terrible',
                   'The Film was great.',
                   'It was just awful.']

# Or list of list of tokens.
list_sentences = [['the', 'movie', 'is', 'terrible'],
                  ['The', 'Film', 'was', 'great.'],
                  ['It', 'was', 'just', 'awful.']]

我创建了以下自定义数据集:

import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, sentences, labels):

        self.sentences = sentences
        self.labels = labels

    def __getitem__(self, i):
        result = {}
        result['sentences'] = self.sentences[i]
        result['label'] = self.labels[i]
        return result

    def __len__(self):
        return len(self.labels)

当我以句子列表的形式提供输入时,数据加载器正确地返回成批完整的句子。请注意,batch\u size=2

list_sentences = [ 'the movie is terrible', 'The Film was great.', 'It was just awful.']
list_labels = [ 0, 1, 0]


dataset = MyDataset(list_sentences, list_labels)
dataloader = DataLoader(dataset, batch_size=2)

batch = next(iter(dataloader))
print(batch)
# {'sentences': ['the movie is terrible', 'The Film was great.'], <-- Great! 2 sentences in batch!
#  'label': tensor([0, 1])}

批次正确地包含两句话和两个标签,因为batch\u size=2

然而,当我将句子作为标记列表的预标记列表输入时,我得到了奇怪的结果:

list_sentences = [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.'], ['It', 'was', 'just', 'awful.']]
list_labels = [ 0, 1, 0]


dataset = MyDataset(list_sentences, list_labels)
dataloader = DataLoader(dataset, batch_size=2)

batch = next(iter(dataloader))
print(batch)
# {'sentences': [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')], <-- WHAT?
#  'label': tensor([0, 1])}

请注意,此批的句子是一个包含单词对元组的单一列表。我希望句子是两个列表的列表,如下所示:

{'sentences': [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']

到底是怎么回事?


共有2个答案

缪朝
2023-03-14

另一种解决方案是在数据集中将字符串编码为字节,然后在转发过程中对其进行解码。如果您希望包含元数据的字符串(如数据来自的文件路径),但实际上不需要将数据传递到模型中,这将非常有用。

例如:

class MyDataset(torch.utils.data.Dataset):
    def __next__(self):
        return np.array("this is a sentence").bytes()

然后在向前传球时,你会做:

sentences: List[str] = []
for sentence in batch:
    sentences.append(sentence.decode("ascii"))

印子平
2023-03-14

此行为是因为默认的collate\u fn在必须整理列表时会执行以下操作(这是[“句子”]的情况):

# [...]
elif isinstance(elem, container_abcs.Sequence):
    # check to make sure that the elements in batch have consistent size
    it = iter(batch)
    elem_size = len(next(it))
    if not all(len(elem) == elem_size for elem in it):
        raise RuntimeError('each element in list of batch should be of equal size')
    transposed = zip(*batch)
    return [default_collate(samples) for samples in transposed]

出现“问题”的原因是,在最后两行中,当批处理是容器时,它将递归调用zip(*batch)。序列(和列表是),而zip的行为如下。

如你所见:

batch = [['the', 'movie', 'is', 'terrible'], ['The', 'Film', 'was', 'great.']]
list(zip(*batch))

# [('the', 'The'), ('movie', 'Film'), ('is', 'was'), ('terrible', 'great.')]

在您的案例中,除了实现一个新的collator并将其传递给数据加载程序(…,collate\u fn=mycollator)之外,我看不到任何解决方法。例如,一个简单丑陋的例子可能是:

def mycollator(batch):
    assert all('sentences' in x for x in batch)
    assert all('label' in x for x in batch)
    return {
        'sentences': [x['sentences'] for x in batch],
        'label': torch.tensor([x['label'] for x in batch])
    }

 类似资料:
  • 问题内容: 我有一些奇怪的字符串池行为的问题。我正在使用比较相等的字符串来找出它们是否在池中。 输出为: 这对我来说是一个很大的惊喜。有人可以解释一下吗?我认为这是在编译时发生的。但是,为什么添加到String 根本没有任何区别呢? 问题答案: 是一个编译时常量,而 不是。因此,前者仅编译为字符串常量“ 555”,而后者则编译为实际的方法调用和串联,从而生成一个新的String实例。 另请参见JL

  • 我使用StringTokenizer将字符串分成标记,但是每当出现在中间而没有任何时,就跳过它并取下一个标记 实际产出 想要的Outupt

  • 问题内容: 我已经将其输入python shell: 我期望0.1 * 0.1不是0.01,因为我知道以10为底的0.1是周期性的,以2为底。 我已经看过20个以上的字符,因此我希望能得到20个。为什么我得到4? 好吧,这解释了为什么我给了我4,但为什么回归? 为什么不回合?(我已经阅读了这个答案,但是我想知道他们如何决定何时对浮点数进行取整,以及何时对不进行浮点取整) 因此,浮子的准确性似乎是一

  • 问题内容: 我有一个小文件,其中包含一些我想用“ |”分割的内容 字符。 当我尝试使用其他任何字符(例如“>”)时,它都可以正常工作,但是使用“ |” 性格,有一些意想不到的结果。 行本身(此处带有 >字符) addere> to add>(1) 分割“ >”结果 [加法,加法(1)] 分割“ |” 结果 [,a,d,d,e,r,e,|,t,o,,a,d,d,|,(,1,)] 为什么要拆分所有内容

  • 问题内容: 如何从字符串中删除奇怪的和不需要的Unicode字符(例如带问号的黑色菱形)? 更新: 请告诉我对应于“其中带有问号的黑色菱形”的Unicode字符串或正则表达式。 问题答案: 带问号的黑色菱形不是unicode字符- 它是字体无法显示的字符的占位符。如果字符串中存在一个字形,而该字形不是用于显示该字符串的字体,则将看到占位符。定义为U + FFFD:它的外观取决于您使用的字体。 您可

  • 当标题[currSlotNo]的值为空时,该行将导致错误: } HomeVu1.VdslotService updVideoSlotData pictureInst:331 Picid[currSlotNo]:331 HomeVu1.VdslotService updVideoSlotData Picmk:331 HomeVu1.VdslotService updVideoSlotData aud