我们在使用AllenNLP的时候,如果使用自定义的predictor,默认的输入格式是json,我们可以将其修改为以行为单位的文本格式;
另外默认的输出是json,我们也可以自定义修改为文本;特别是在调用json.dumps的时候,中文默认会被转义成ASCII码,我们自定义的时候可以把ensure_ascii设置为False来直接输出中文字符;
另外默认的输出只有label,没有input_text作为参考,我们可以在outputs中新增input_text字段,来方便地查看预测输出:
from copy import deepcopy
from typing import List, Dict
from overrides import overrides
import numpy
import json
from allennlp.common.util import JsonDict, sanitize
from allennlp.common.util import JsonDict
from allennlp.data import Instance
from allennlp.predictors.predictor import Predictor
from allennlp.data.fields import LabelField
from allennlp.data.tokenizers import CharacterTokenizer
@Predictor.register("cnews_text_classifier")
class TextClassifierPredictor(Predictor):
    """
    Predictor for any model that takes in a piece of text and returns
    a single class for it. In particular, it can be used with
    the [`BasicClassifier`](../models/basic_classifier.md) model.

    This predictor reads its input as plain text (one example per line)
    instead of JSON, writes its output as JSON without escaping non-ASCII
    characters, and echoes the input text back under the `"input_text"` key
    so predictions can be inspected next to the text they belong to.

    Registered as a `Predictor` with name "cnews_text_classifier".
    """

    def __init__(self, model, dataset_reader):
        super().__init__(model, dataset_reader)
        # Text of the most recent input handled by `_json_to_instance`;
        # `predict_instance` copies it into the outputs.
        self.input_text = ""

    def predict(self, sentence: str) -> JsonDict:
        """
        Classify a single piece of text.

        # Parameters

        sentence : `str`
            The raw text to classify.

        # Returns

        The output dictionary produced by `predict_json`.
        """
        # The key has to be "text" because that is what `_json_to_instance`
        # reads; the previous "sentence" key always raised `KeyError`.
        return self.predict_json({"text": sentence})

    @overrides
    def load_line(self, line: str) -> JsonDict:
        """
        Turn one line of plain-text input into the JSON dictionary used
        internally. Override this method when the input should not be JSON.
        """
        # Drop a trailing line terminator, if any, so that it is neither fed
        # to the model as text nor echoed back in "input_text".
        return {"text": line.rstrip("\r\n")}

    @overrides
    def dump_line(self, outputs: JsonDict) -> str:
        """
        Serialize one prediction as a single output line. Override this
        method when the output should not be (default) JSON.
        """
        # ensure_ascii=False keeps non-ASCII characters (e.g. Chinese)
        # readable instead of emitting \uXXXX escapes.
        return json.dumps(outputs, ensure_ascii=False) + "\n"

    @overrides
    def predict_instance(self, instance: Instance) -> JsonDict:
        """
        Run the model on `instance` and add the remembered input text to the
        result under `"input_text"`.
        """
        # Delegate to the base class rather than re-implementing it, so any
        # preprocessing it performs before the forward pass is kept.
        # NOTE(review): confirm against the installed AllenNLP version.
        outputs = super().predict_instance(instance)
        outputs["input_text"] = self.input_text
        return outputs

    @overrides
    def predict_batch_json(self, inputs: List[JsonDict]) -> List[JsonDict]:
        """
        Batched counterpart of `predict_json` that also attaches
        `"input_text"` to every result, so batched and single-instance
        predictions share the same output fields.
        """
        results = super().predict_batch_json(inputs)
        # Pairs each result with its input by position.
        # NOTE(review): assumes results come back one per input, in order.
        for json_dict, result in zip(inputs, results):
            result["input_text"] = json_dict["text"]
        return results

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Expects JSON that looks like `{"text": "..."}`.
        Remembers the text in `self.input_text` and converts it to an
        `Instance` with the dataset reader's `text_to_instance`.
        """
        text = json_dict["text"]
        self.input_text = text
        return self._dataset_reader.text_to_instance(text)

    @overrides
    def predictions_to_labeled_instances(
        self, instance: Instance, outputs: Dict[str, numpy.ndarray]
    ) -> List[Instance]:
        """
        Return a one-element list holding a copy of `instance` with a
        `"label"` field set to the index of the largest value in
        `outputs["probs"]`.
        """
        # Work on a deep copy so the caller's instance is left untouched.
        new_instance = deepcopy(instance)
        label = numpy.argmax(outputs["probs"])
        # The label is already an integer index, hence skip_indexing=True.
        new_instance.add_field("label", LabelField(int(label), skip_indexing=True))
        return [new_instance]