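The code below computes evaluation metrics for transliteration and speech recognition. It relies on the jiwer package from PyPI, which can be installed directly with pip.
jiwer provides five measures: Character Error Rate (CER), Word Error Rate (WER), Match Error Rate (MER), Word Information Lost (WIL), and Word Information Preserved (WIP). In practice only four of them are independent, because WIL and WIP are complementary (WIL = 1 - WIP).
For a detailed explanation of these metrics, see the paper: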
From WER and RIL to MER and WIL: improved evaluation measures for connected speech recognition
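To make the relationship between the word-level measures concrete, here is a minimal sketch that derives WER, MER, WIP and WIL from alignment counts (hits H, substitutions S, deletions D, insertions I), following the definitions in the paper above. The function name `asr_measures` and the hand-counted example values are assumptions made for illustration; they are not part of jiwer.

```python
def asr_measures(hits, substitutions, deletions, insertions):
    """Return WER, MER, WIP and WIL from word-alignment counts.

    Hypothetical helper for illustration; assumes a non-empty reference.
    """
    errors = substitutions + deletions + insertions
    ref_len = hits + substitutions + deletions    # number of reference words
    hyp_len = hits + substitutions + insertions   # number of hypothesis words

    wer = errors / ref_len
    mer = errors / (hits + errors)
    # WIP is the product of the hit rates seen from each side of the alignment.
    wip = (hits / ref_len) * (hits / hyp_len) if hyp_len > 0 else 0.0
    wil = 1.0 - wip                               # WIL and WIP are complementary
    return {"wer": wer, "mer": mer, "wip": wip, "wil": wil}


# Counts taken by hand from the pairs
#   "hello world"          vs "hello duck"     (1 hit, 1 substitution)
#   "i like monthy python" vs "i like python"  (3 hits, 1 deletion)
scores = asr_measures(hits=4, substitutions=1, deletions=1, insertions=0)
print(scores)  # wer = mer = 2/6, wip = (4/6) * (4/5), wil = 1 - wip
```

Since WIL is always 1 - WIP, reporting both adds no information; that is why the five measures reduce to four.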
import jiwer


def compute_single_metric(gt, pred, metric):
    """Dispatch to the jiwer function that matches the metric name."""
    if metric == 'cer':
        return jiwer.cer(gt, pred)
    elif metric == 'wer':
        return jiwer.wer(gt, pred)
    elif metric == 'mer':
        return jiwer.mer(gt, pred)
    elif metric == 'wil':
        return jiwer.wil(gt, pred)
    elif metric == 'wip':
        return jiwer.wip(gt, pred)
    else:
        raise KeyError("invalid metric: {} !".format(metric))


def compute_metrics(ground_truth: list, prediction: list, metrics: list) -> dict:
    """Compute automatic speech recognition (ASR) metrics, including:
    Character Error Rate (CER),
    Word Error Rate (WER),
    Match Error Rate (MER),
    Word Information Lost (WIL) and Word Information Preserved (WIP)

    Args:
        ground_truth (list): list of ground truth answers, e.g., ['apple','marry','mark twin']
        prediction (list): list of predictions, e.g., ['appl','malli','mark twen']
        metrics (list): any subset of ['cer','wer','mer','wil','wip']

    Returns:
        dict: maps each requested metric name to its score as a percentage
    """
    choices = ['cer', 'wer', 'mer', 'wil', 'wip']
    assert len(ground_truth) == len(prediction), 'length mismatch!'
    assert all(c in choices for c in metrics), "metrics must be chosen from ['cer','wer','mer','wil','wip']"
    results = {c: 0.0 for c in metrics}
    # each jiwer call scores the whole list at once (corpus level);
    # the score is then converted to a percentage
    for metric in metrics:
        score = compute_single_metric(ground_truth, prediction, metric)
        score = score * 100
        results[metric] = score
    return results


if __name__ == "__main__":
    ground_truth = ["hello world", "i like monthy python"]
    hypothesis = ["hello duck", "i like python"]
    metrics_1 = ['cer', 'wer', 'mer', 'wil', 'wip']  # all metrics
    metrics_2 = []                                   # no metric: returns {}
    metrics_3 = ['cer']                              # a single metric
    metrics_4 = ['ccc']                              # invalid name: fails the assertion
    print(compute_metrics(ground_truth, hypothesis, metrics_1))
    print(compute_metrics(ground_truth, hypothesis, metrics_2))
    print(compute_metrics(ground_truth, hypothesis, metrics_3))
    try:
        print(compute_metrics(ground_truth, hypothesis, metrics_4))
    except AssertionError as err:
        print("rejected:", err)
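One point worth keeping in mind when reading the numbers: to my understanding, when jiwer is given lists of sentences it pools the edit counts over all pairs before dividing, which is not the same as averaging per-sentence scores. The sketch below illustrates the difference for WER with a plain word-level Levenshtein distance; the helpers `word_edit_distance`, `pooled_wer` and `mean_sentence_wer` are hypothetical names written for this example and do not come from jiwer.

```python
def word_edit_distance(ref_words, hyp_words):
    """Levenshtein distance between two word lists (all edits cost 1)."""
    prev = list(range(len(hyp_words) + 1))
    for i, r in enumerate(ref_words, start=1):
        curr = [i]
        for j, h in enumerate(hyp_words, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # delete a reference word
                            curr[j - 1] + 1,      # insert a hypothesis word
                            prev[j - 1] + cost))  # match or substitute
        prev = curr
    return prev[-1]


def pooled_wer(references, hypotheses):
    """Corpus-level WER: total edits divided by total reference words."""
    edits = sum(word_edit_distance(r.split(), h.split())
                for r, h in zip(references, hypotheses))
    words = sum(len(r.split()) for r in references)
    return edits / words


def mean_sentence_wer(references, hypotheses):
    """Unweighted mean of per-sentence WER (assumes non-empty references)."""
    rates = [word_edit_distance(r.split(), h.split()) / len(r.split())
             for r, h in zip(references, hypotheses)]
    return sum(rates) / len(rates)


refs = ["hello world", "i like monthy python"]
hyps = ["hello duck", "i like python"]
print(pooled_wer(refs, hyps))         # 2 edits over 6 reference words
print(mean_sentence_wer(refs, hyps))  # mean of 1/2 and 1/4
```

The two values differ whenever sentence lengths differ, because the pooled version weights each sentence by its number of reference words.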
In addition, the following repositories also compute transliteration metrics, although they are less comprehensive than jiwer: