bert 中文基于文本的问答系统

冷宏茂
2023-12-01

bert 中文基于文本的问答系统

# -!- coding: utf-8 -!-
import torch
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('there are %d GPU(s) available.'% torch.cuda.device_count())
    print('we will use the GPU: ', torch.cuda.get_device_name(0))
else:
    print('No GPU availabel, using the CPU instead.')
    device = torch.device('cpu')
there are 1 GPU(s) available.
we will use the GPU:  GeForce GTX 1070
import os
import time
import json
import random
import datetime
import numpy as np
from tqdm import tqdm
from transformers import AdamW
from torch.utils.tensorboard import SummaryWriter
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset, DataLoader,random_split
from transformers import WEIGHTS_NAME, CONFIG_NAME

from transformers import (
    DataProcessor,
    BertTokenizer,
    squad_convert_examples_to_features,
    BertForQuestionAnswering,
)
# 设置随机种子.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

定义DataProcessor

class MySquadProcessor(DataProcessor):
    def get_train_examples(self, data_dir, filename=None):
        """
        Returns the training examples from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the training file has a different name than the original one
                which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.

        """
        if data_dir is None:
            data_dir = ""

        if self.train_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "train")

    def get_dev_examples(self, data_dir, filename=None):
        """
        Returns the evaluation example from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the evaluation file has a different name than the original one
                which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
        """
        if data_dir is None:
            data_dir = ""

        if self.dev_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "dev")

    def _create_examples(self, input_data, set_type):
        is_training = set_type == "train"
        examples = []
        for entry in tqdm(input_data):
            title = entry["title"]
            for paragraph in entry["paragraphs"]:
                context_text = paragraph["context"]
                for qa in paragraph["qas"]:
                    qas_id = qa["id"]
                    question_text = qa["question"]
                    start_position_character = None
                    answer_text = None
                    answers = []

                    if "is_impossible" in qa:
                        is_impossible = qa["is_impossible"]
                    else:
                        is_impossible = False

                    if not is_impossible:
                        answer = qa["answers"][0]
                        answer_text = qa["answers"][0]["text"]
                        start_position_character = qa["answers"][0]["answer_start"]  

                    example = ChineseSquadExample(
                        qas_id=qas_id,
                        question_text=question_text,
                        context_text=context_text,
                        answer_text=answer_text,
                        start_position_character = start_position_character,
                        title=title,
                        is_impossible=is_impossible,
                        answers=answers,
                    )

                    examples.append(example)
        return examples

导入json 文件

class SquadV3Processor(MySquadProcessor):
    train_file = "train-v2.0.json"
    dev_file = "dev-v2.0.json"

定义 中文的SquadExample (中文的SquadExample 和 英文的SquadExample不同,所有我们要自己编写)

class ChineseSquadExample(object):
    """
    A single training/test example for the Squad dataset, as loaded from disk.

    Args:
        qas_id: The example's unique identifier
        question_text: The question string
        context_text: The context string
        answer_text: The answer string
        start_position_character: The character position of the start of the answer
        title: The title of the example
        answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
        is_impossible: False by default, set to True if the example has no possible answer.
    """

    def __init__(
        self,
        qas_id,
        question_text,
        context_text,
        answer_text,
        start_position_character,
        title,
        answers=[],
        is_impossible=False,
    ):

        self.qas_id = qas_id
        self.question_text = question_text
        self.context_text = context_text.replace(" ","").replace("  ","").replace(" ","")
        self.answer_text =""
        for e in answer_text.replace(" ","").replace("  ","").replace(" ",""):
            self.answer_text += e
            self.answer_text +=" "
        self.answer_text = self.answer_text[0:-1]

        self.title = title
        self.is_impossible = is_impossible
        self.answers = answers
        self.doc_tokens = [e for e in self.context_text]
        self.char_to_word_offset = [i for i, e in enumerate(self.context_text)]
        self.start_position = self.context_text.find(answer_text.replace(" ","").replace("  ","").replace(" ",""))
        self.end_position = self.start_position + len(answer_text.replace(" ","").replace("  ","").replace(" ",""))

定义计时函数

def format_time(elapsed):
    elapsed_rounded = int(round((elapsed)))
    # 返回 hh:mm:ss 形式的时间
    return str(datetime.timedelta(seconds=elapsed_rounded))

定义训练函数

def training(train_dataloader, model):
    t0 = time.time()
    total_train_loss = 0
    total_train_accuracy = 0
    model.train()

    for step, batch in enumerate(train_dataloader):

        # 每隔40个batch 输出一下所用时间.
        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # `batch` 包括5个 tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: token_type_ids
        #   [3]: start_positions
        #   [4]: end_positions

        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        start_positions = batch[3].to(device)
        end_positions = batch[4].to(device)

        # 清空梯度
        model.zero_grad()
        # forward
        # 参考 https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
        loss, start_scores, end_scores = model(input_ids, attention_mask=attention_mask,
                                               token_type_ids=token_type_ids, start_positions=start_positions,
                                               end_positions=end_positions)
        total_train_loss += loss.item()

        # backward 更新 gradients.
        loss.backward()

        # 减去大于1 的梯度,将其设为 1.0, 以防梯度爆炸.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # 更新模型参数
        optimizer.step()

        # 更新 learning rate.
        scheduler.step()

        # 计算batches的平均损失.
    avg_train_loss = total_train_loss / len(train_dataloader)

    print("  平均训练损失 loss: {0:.2f}".format(avg_train_loss))
    return  avg_train_loss

定义校验函数

def train_evalution(test_dataloader,model):

    total_eval_loss = 0
    model.eval()

    for batch in test_dataloader:

        # `batch` 包括5个 tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: token_type_ids
        #   [3]: start_positions
        #   [4]: end_positions

        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        start_positions = batch[3].to(device)
        end_positions = batch[4].to(device)

        # 在valuation 状态,不更新权值,不改变计算图
        with torch.no_grad():
            # 参考 https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification
            loss, start_scores, end_scores = model(input_ids, attention_mask=attention_mask,
                                                   token_type_ids=token_type_ids, start_positions=start_positions,
                                                   end_positions=end_positions)

        # 计算 validation loss.
        total_eval_loss += loss.item()

    return total_eval_loss,len(test_dataloader)

读数据

#if __name__ == '__main__':    
data_dir = ".//data//"
processor = SquadV3Processor()
Train_data = processor.get_train_examples(data_dir)
Dev_data = processor.get_dev_examples(data_dir)

tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
model = BertForQuestionAnswering.from_pretrained('hfl/chinese-roberta-wwm-ext')
model.to(device)

max_seq_length = 1280
max_query_length = 128
100%|██████████████████████████████████████████████████████████████████████████████| 848/848 [00:00<00:00, 1462.47it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 848/848 [00:00<00:00, 1532.02it/s]
Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
is_training = True
Train_features,Train_dataset = squad_convert_examples_to_features(
    examples=Train_data[0:10],
    tokenizer=tokenizer,
    max_seq_length= max_seq_length,
    doc_stride= True,
    max_query_length= max_query_length,
    is_training=is_training,
    return_dataset='pt',
)
convert squad examples to features:   0%|                                                       | 0/10 [00:00<?, ?it/s]
is_training = False
Dev_features,Dev_dataset = squad_convert_examples_to_features(
    examples=Dev_data[0:10],
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    doc_stride=True,
    max_query_length=max_query_length,
    is_training=is_training,
    return_dataset='pt',
)

设计dataloader

train_dataloader = DataLoader(Train_dataset, batch_size=1, shuffle=True)
dev_dataloader = DataLoader(Dev_dataset, batch_size=1, shuffle=True)

设置模型参数

# AdamW 是一个 huggingface library 的类,'W' 是'Weight Decay fix"的意思。
optimizer = AdamW(model.parameters(),
                  lr=2e-5,  # args.learning_rate - 默认是 5e-5
                  eps=1e-8  # args.adam_epsilon  - 默认是 1e-8, 是为了防止衰减率分母除到0
                  )

# bert 推荐 epochs 在2到4之间为好。
epochs = 2

# training steps 的数量: [number of batches] x [number of epochs].
total_steps = len(train_dataloader) * epochs

# 设计 learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,  # Default value in run_glue.py
                                            num_training_steps=total_steps)

训练模型

output_dir = "./this_model/"
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)
writer = SummaryWriter("./log_models/")

# 设置总时间.
total_t0 = time.time()

for epoch_i in range(0, epochs):
    print('Epoch {:} / {:}'.format(epoch_i + 1, epochs))

    # ========================================
    #               training
    # ========================================
    t0 = time.time()
    avg_train_loss = training(train_dataloader, model)
    # 计算训练时间.
    training_time = format_time(time.time() - t0)
    print("  训练时间: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================

    t0 = time.time()

    total_eval_loss, valid_dataloader_length = train_evalution(dev_dataloader, model)

    print("") 

    # 计算batches的平均损失.
    avg_val_loss = total_eval_loss / valid_dataloader_length

    # 计算validation 时间.
    validation_time = format_time(time.time() - t0)

    print("  平均测试损失 Loss: {0:.2f}".format(avg_val_loss))
    print("  测试时间: {:}".format(validation_time))

    writer.add_scalars(f'Acc/Loss', {
        'Training Loss': avg_train_loss,
        'Valid Loss': avg_val_loss,

    }, epoch_i + 1)

print("训练一共用了 {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))
writer.close()
torch.save(model.state_dict(), output_model_file)
model.config.to_json_file(output_config_file)

由于本人GPU 只要8个g, 训练这个模型是非常困难,下面是在pycharm上训练的结果。由于训练时间比较短,训练效果不佳。

there are 1 GPU(s) available.
we will use the GPU: GeForce GTX 1070
2020-08-04 09:12:29.425728: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
100%|██████████| 2402/2402 [00:01<00:00, 1612.06it/s]
100%|██████████| 848/848 [00:00<00:00, 2078.73it/s]
there are 1 GPU(s) available.
we will use the GPU: GeForce GTX 1070
2020-08-04 09:12:47.730465: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
convert squad examples to features: 100%|██████████| 10137/10137 [05:29<00:00, 30.75it/s]
add example index and unique id: 100%|██████████| 10137/10137 [00:00<00:00, 781775.82it/s]
there are 1 GPU(s) available.
we will use the GPU: GeForce GTX 1070
2020-08-04 09:18:27.213451: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
convert squad examples to features: 100%|██████████| 3219/3219 [01:46<00:00, 30.22it/s]
add example index and unique id: 100%|██████████| 3219/3219 [00:00<00:00, 807021.19it/s]
Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertForQuestionAnswering: [‘cls.predictions.bias’, ‘cls.predictions.transform.dense.weight’, ‘cls.predictions.transform.dense.bias’, ‘cls.predictions.transform.LayerNorm.weight’, ‘cls.predictions.transform.LayerNorm.bias’, ‘cls.predictions.decoder.weight’, ‘cls.seq_relationship.weight’, ‘cls.seq_relationship.bias’]

  • This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
  • This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext and are newly initialized: [‘qa_outputs.weight’, ‘qa_outputs.bias’]
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Epoch 1 / 2
Batch 400 of 10,137. Elapsed: 0:03:02.
Batch 800 of 10,137. Elapsed: 0:06:04.
Batch 1,200 of 10,137. Elapsed: 0:09:06.
Batch 1,600 of 10,137. Elapsed: 0:12:08.
Batch 2,000 of 10,137. Elapsed: 0:15:11.
Batch 2,400 of 10,137. Elapsed: 0:18:14.
Batch 2,800 of 10,137. Elapsed: 0:21:20.
Batch 3,200 of 10,137. Elapsed: 0:25:02.
Batch 3,600 of 10,137. Elapsed: 0:28:10.
Batch 4,000 of 10,137. Elapsed: 0:31:45.
Batch 4,400 of 10,137. Elapsed: 0:34:57.
Batch 4,800 of 10,137. Elapsed: 0:38:36.
Batch 5,200 of 10,137. Elapsed: 0:42:16.
Batch 5,600 of 10,137. Elapsed: 0:45:37.
Batch 6,000 of 10,137. Elapsed: 0:48:55.
Batch 6,400 of 10,137. Elapsed: 0:52:27.
Batch 6,800 of 10,137. Elapsed: 0:56:11.
Batch 7,200 of 10,137. Elapsed: 1:00:02.
Batch 7,600 of 10,137. Elapsed: 1:03:15.
Batch 8,000 of 10,137. Elapsed: 1:06:45.
Batch 8,400 of 10,137. Elapsed: 1:10:21.
Batch 8,800 of 10,137. Elapsed: 1:14:00.
Batch 9,200 of 10,137. Elapsed: 1:17:34.
Batch 9,600 of 10,137. Elapsed: 1:20:38.
Batch 10,000 of 10,137. Elapsed: 1:24:19.
平均训练损失 loss: 2.20
训练时间: 1:25:24

平均测试损失 Loss: 10.32
测试时间: 0:09:32
Epoch 2 / 2
Batch 400 of 10,137. Elapsed: 0:03:05.
Batch 800 of 10,137. Elapsed: 0:06:38.
Batch 1,200 of 10,137. Elapsed: 0:10:13.
Batch 1,600 of 10,137. Elapsed: 0:13:43.
Batch 2,000 of 10,137. Elapsed: 0:17:17.
Batch 2,400 of 10,137. Elapsed: 0:20:23.
Batch 2,800 of 10,137. Elapsed: 0:23:57.
Batch 3,200 of 10,137. Elapsed: 0:27:31.
Batch 3,600 of 10,137. Elapsed: 0:31:13.
Batch 4,000 of 10,137. Elapsed: 0:34:53.
Batch 4,400 of 10,137. Elapsed: 0:38:35.
Batch 4,800 of 10,137. Elapsed: 0:42:10.
Batch 5,200 of 10,137. Elapsed: 0:45:12.
Batch 5,600 of 10,137. Elapsed: 0:48:46.
Batch 6,000 of 10,137. Elapsed: 0:52:18.
Batch 6,400 of 10,137. Elapsed: 0:55:53.
Batch 6,800 of 10,137. Elapsed: 0:59:29.
Batch 7,200 of 10,137. Elapsed: 1:02:33.
Batch 7,600 of 10,137. Elapsed: 1:06:07.
Batch 8,000 of 10,137. Elapsed: 1:09:41.
Batch 8,400 of 10,137. Elapsed: 1:13:13.
Batch 8,800 of 10,137. Elapsed: 1:16:54.
Batch 9,200 of 10,137. Elapsed: 1:20:38.
Batch 9,600 of 10,137. Elapsed: 1:24:29.
Batch 10,000 of 10,137. Elapsed: 1:28:17.
平均训练损失 loss: 1.36
训练时间: 1:29:22

平均测试损失 Loss: 10.24
测试时间: 0:09:05
训练一共用了 3:13:23 (h:mm:ss)

测试一下

model.load_state_dict(torch.load('roberta_models/pytorch_model.bin'))
model.to(device)
<All keys matched successfully>
context = "株洲北站全称广州铁路(集团)公司株洲北火车站。除站场主体,另外管辖湘潭站、湘潭东站和三个卫星站,田心站、白马垅站、十里冲站,以及原株洲车站货房。车站办理编组、客运、货运业务。车站机关地址:湖南省株洲市石峰区北站路236号,邮编412001。株洲北站位于湖南省株洲市区东北部,地处中南路网,是京广铁路、沪昆铁路两大铁路干线的交汇处,属双向纵列式三级七场路网性编组站。车站等级为特等站,按技术作业性质为编组站,按业务性质为客货运站,是株洲铁路枢纽的主要组成部分,主要办理京广、沪昆两大干线四个方向货物列车的到发、解编作业以及各方向旅客列车的通过作业。每天办理大量的中转车流作业,并有大量的本地车流产生和集散,在路网车流的组织中占有十分重要的地位,是沟通华东、华南、西南和北方的交通要道,任务艰巨,作业繁忙。此外,株洲北站还有连接石峰区喻家坪工业站的专用线。株洲北站的前身是田心车站。"
qestion = "株洲北站的机关地址是什么"

inputs = tokenizer(context, qestion, return_tensors="pt").to(device)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
start_scores, end_scores = model(**inputs)
answer_start = torch.argmax(start_scores)
answer_end = torch.argmax(end_scores)
answer =tokens[answer_start:answer_end]
str=""
print(str.join(answer))
print("标准答案:湖南省株洲市石峰区北站路236号,邮编412001。" )
车站机关地址:湖南省株洲市石峰区北站路236
标准答案:湖南省株洲市石峰区北站路236号,邮编412001。
context = "地方税务局是一个泛称,是中华人民共和国1994年分税制改革的结果。1994年分税制把税种分为中央税、地方税、中央地方共享税;把征税系统由税务局分为国家税务系统与地方税务系统。其中中央税、中央地方共享税由国税系统(包括国家税务总局及各地的国家税务局)征收,地方税由地方税务局征收。地方税务局在省、市、县、区各级地方政府中设置,国务院中没有地方税务局。地税局长由本级人民政府任免,但要征求上级国家税务局的意见。一般情况下,地方税务局与财政厅(局)是分立的,不是一个机构两块牌子。但也有例外,例如,上海市在2008年政府机构改革之前,上海市财政局、上海市地方税务局和上海市国家税务局为合署办公,一个机构、三块牌子,而2008年政府机构改革之后,上海市财政局被独立设置,上海市地方税务局和上海市国家税务局仍为合署办公,一个机构、两块牌子。同时县一级,财政局长常常兼任地税局长。地方税务局主要征收:营业税、企业所得税、个人所得税、土地增值税、城镇土地使用税、城市维护建设税、房产税、城市房地产税、车船使用税、车辆使用牌照税、屠宰税、资源税、固定资产投资方向调节税、印花税、农业税、农业特产税、契税、耕地占用税、筵席税,城市集体服务事业费、文化事业建设费、教育费附加以及地方税的滞补罚收入和外商投资企业土地使用费。"
qestion = "地方税务局是中华人民共和国哪一年分税制改革的结果"

inputs = tokenizer(context, qestion, return_tensors="pt").to(device)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
start_scores, end_scores = model(**inputs)
answer_start = torch.argmax(start_scores)
answer_end = torch.argmax(end_scores)
answer =tokens[answer_start:answer_end]
str=""
print(str.join(answer))
print("标准答案:19" )
不是一个机构两块牌
标准答案:19
context = "萤火虫工作室是一家总部设在英国伦敦和康涅狄格州坎顿,并在苏格兰阿伯丁设有质量部门的电子游戏开发商。1999年8月,西蒙·布雷德伯里,埃里克·乌列特和大卫·莱斯特成立萤火虫工作室,一起开发了很多游戏,包括非常成功的“凯撒” 和“王国霸主”系列。公司成立后,萤火虫工作室发布了一个未来前景规划:\"“萤火虫工作室要创造一个人们游戏其中的引人瞩目的新世界。我们要提供一个丰富多彩的游戏环境,令玩家在我们的图像和编码技术不断提升的游戏世界中感到愉快。我们的专长是在游戏中开发战略,而我们今后要继续发展,与我们精彩的视觉效果,引人瞩目的人物和易于上手的特点相结合。如果我们能这样完成工作,玩家将会发现一个自己创造的,加进自己个性的世界”\"。该公司将市场定位于PC(Windows)和苹果电脑上的即时战略游戏领域,特别是公司成功的“要塞”系列。目前,他们正在开发PC和Xbox360上的次时代游戏。"
qestion = "萤火虫工作室的总部设在哪里"

inputs = tokenizer(context, qestion, return_tensors="pt").to(device)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
start_scores, end_scores = model(**inputs)
answer_start = torch.argmax(start_scores)
answer_end = torch.argmax(end_scores)
answer =tokens[answer_start:answer_end]
str=""
print(str.join(answer))
print("标准答案:英国伦敦和康涅狄格州坎顿。" )
英国伦敦和康涅狄格州坎
标准答案:英国伦敦和康涅狄格州坎顿。

由此可见 我们的模型只是训练了2个epoch,模型答案就和标准答案十分接近了


 类似资料: