注意class_weight 一定是一个字典, 不然虽然不会报错,但是是没有效果的,loss完全不会变。
sklearn 直接提供了一个函数来计算类别权重:
# 计算类别权重
my_class_weight = class_weight.compute_class_weight('balanced'
,np.unique(train_Y)
,train_Y).tolist()
# 需要转成字典
class_weight_dict = dict(zip([x for x in np.unique(train_Y)], my_class_weight))
但是返回值也需要转换为字典。