当前位置: 首页 > 工具软件 > AllenNLP > 使用案例 >

【AllenNLP】: 自定义predictor—输入文本输出中文

巩衡
2023-12-01

说明

我们在使用AllenNLP的时候,当使用自定义predictor的时候,默认的是输入json,我们可以修改为输入以行为单位的文本格式;

另外默认的输出是json,我们也可以自定义修改为文本,特别是在json.dumps的时候中文会默认是ASCII码,我们自定义的时候可以设置为False来输出中文字符;

另外默认的输出只有label,没有input_text作为参考,我们可以在outputs中新增,来方便地查看预测输出:

from copy import deepcopy
from typing import List, Dict

from overrides import overrides
import numpy
import json
from allennlp.common.util import JsonDict, sanitize

from allennlp.common.util import JsonDict
from allennlp.data import Instance
from allennlp.predictors.predictor import Predictor
from allennlp.data.fields import LabelField
from allennlp.data.tokenizers import CharacterTokenizer


@Predictor.register("cnews_text_classifier")
class TextClassifierPredictor(Predictor):
    """
    Predictor for any model that takes in a sentence and returns
    a single class for it.  In particular, it can be used with
    the [`BasicClassifier`](../models/basic_classifier.md) model.
    Registered as a `Predictor` with name "text_classifier".
    """
    def __init__(self, model, dataset_reader):
        super(TextClassifierPredictor, self).__init__(model, dataset_reader)
        self.input_text = ""

    def predict(self, sentence: str) -> JsonDict:
        return self.predict_json({"sentence": sentence})

    @overrides
    def load_line(self, line: str) -> JsonDict:
        """
        如果你不想输入为json格式,可以可以@overrides这个函数
        """
        return {"text": line}

    @overrides
    def dump_line(self, outputs: JsonDict) -> str:
        """
        如果你不想输出json格式,可以@overrides这个函数
        """
        return json.dumps(outputs, ensure_ascii=False) + "\n"
    
    @overrides
    def predict_instance(self, instance: Instance) -> JsonDict:
        outputs = self._model.forward_on_instance(instance)
        outputs["input_text"] = self.input_text
        
        return sanitize(outputs)

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Expects JSON that looks like `{"sentence": "..."}`.
        Runs the underlying model, and adds the `"label"` to the output.
        """
        sentence = json_dict["text"]
        self.input_text = json_dict["text"]
        return self._dataset_reader.text_to_instance(sentence)

    @overrides
    def predictions_to_labeled_instances(self, instance: Instance, outputs: Dict[str, numpy.ndarray]) -> List[Instance]:
        new_instance = deepcopy(instance)
        label = numpy.argmax(outputs["probs"])
        new_instance.add_field("label", LabelField(int(label), skip_indexing=True))
        return [new_instance]
 类似资料: