当前位置: 首页 > 工具软件 > AllenNLP > 使用案例 >

【AllenNLP】: 自定义predictor—输入文本输出中文






from copy import deepcopy
from typing import List, Dict

from overrides import overrides
import numpy
import json
from allennlp.common.util import JsonDict, sanitize

from allennlp.common.util import JsonDict
from allennlp.data import Instance
from allennlp.predictors.predictor import Predictor
from allennlp.data.fields import LabelField
from allennlp.data.tokenizers import CharacterTokenizer

class TextClassifierPredictor(Predictor):
    """
    Predictor for any model that takes in raw text and returns a single class for it.

    In particular, it can be used with the
    [`BasicClassifier`](../models/basic_classifier.md) model. This version:

    * reads plain-text input (one example per line) instead of JSON lines,
    * echoes each input back in its output under the `"input_text"` key, and
    * writes output with `ensure_ascii=False`, so Chinese (or any other
      non-ASCII) text is emitted as readable characters rather than ASCII
      escape sequences.

    Note: this class is not decorated with `@Predictor.register(...)`, so it
    is not registered under any name by this module.
    """

    def __init__(self, model, dataset_reader):
        super().__init__(model, dataset_reader)
        # Text of the most recently converted input. `_json_to_instance` writes
        # it and `predict_instance` copies it into the output. Because this is
        # a single slot on the predictor, sharing one predictor between threads
        # could attach the wrong text to an output.
        self.input_text = ""

    @staticmethod
    def _extract_text(json_dict: JsonDict) -> str:
        """
        Return the raw input text from `json_dict`.

        Prefers the `"text"` key (what `load_line` produces) and falls back to
        `"sentence"` so callers that pass `{"sentence": "..."}` still work.

        Raises:
            KeyError: if neither key is present.
        """
        if "text" in json_dict:
            return json_dict["text"]
        if "sentence" in json_dict:
            return json_dict["sentence"]
        raise KeyError('expected a "text" or "sentence" key in the input JSON')

    def predict(self, sentence: str) -> JsonDict:
        """
        Classify a single piece of raw text.

        The text is wrapped as `{"text": sentence}`, the same shape that
        `load_line` produces, so it follows the same path as file input.
        """
        return self.predict_json({"text": sentence})

    def load_line(self, line: str) -> JsonDict:
        """
        Treat each input line as raw text rather than as JSON.

        The trailing line terminator is removed. Otherwise it would be fed to
        the model and echoed back in `"input_text"`.
        """
        return {"text": line.rstrip("\r\n")}

    def dump_line(self, outputs: JsonDict) -> str:
        """
        Serialize one output as a JSON line.

        `ensure_ascii=False` keeps non-ASCII characters (e.g. Chinese) readable
        instead of escaping them.
        """
        return json.dumps(outputs, ensure_ascii=False) + "\n"

    def predict_instance(self, instance: Instance) -> JsonDict:
        """
        Run the model on `instance` and add the current `self.input_text`.

        `self.input_text` is whatever `_json_to_instance` stored last. It only
        matches `instance` when that instance came from the immediately
        preceding `_json_to_instance` call (as in `predict_json`).
        """
        outputs = self._model.forward_on_instance(instance)
        outputs["input_text"] = self.input_text
        return sanitize(outputs)

    def predict_batch_json(self, inputs: List[JsonDict]) -> List[JsonDict]:
        """
        Classify several inputs at once, attaching each one's own `"input_text"`.

        Batched prediction goes through `predict_batch_instance`, not the
        `predict_instance` override above, and the single `self.input_text`
        slot cannot hold one value per item. So the text is read directly
        from each input dict here.
        """
        instances = self._batch_json_to_instances(inputs)
        outputs = self.predict_batch_instance(instances)
        for output, json_dict in zip(outputs, inputs):
            output["input_text"] = self._extract_text(json_dict)
        return outputs

    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Convert `{"text": "..."}` (or `{"sentence": "..."}`) into an `Instance`.

        The text is also stored in `self.input_text` so that `predict_instance`
        can echo it in the output.
        """
        text = self._extract_text(json_dict)
        self.input_text = text
        return self._dataset_reader.text_to_instance(text)

    def predictions_to_labeled_instances(self, instance: Instance, outputs: Dict[str, numpy.ndarray]) -> List[Instance]:
        """
        Return a copy of `instance` labeled with the model's predicted class.

        The label is the index of the largest entry in `outputs["probs"]`. It is
        stored as an integer `LabelField` with `skip_indexing=True`, because it
        is already a class index rather than a label string. The result is a
        one-element list.
        """
        new_instance = deepcopy(instance)
        label = numpy.argmax(outputs["probs"])
        new_instance.add_field("label", LabelField(int(label), skip_indexing=True))
        return [new_instance]