当前位置: 首页 > 知识库问答 >
问题:

Sigkit学习模型持久性:泡菜vs pmml vs...?

通啸
2023-03-14

我构建了一个scikit学习模型,并希望在日常python cron作业中重用它(注意:不涉及其他平台-没有R,没有Java)

我对它进行了pickle(实际上,我对自己的对象进行了pickle,其中一个字段是GradientBoostingClassifier),然后在cron作业中取消了pickle。到目前为止还不错(在scikit learn中将分类器保存到磁盘和在scikit learn中将模型持久化中进行了讨论?)。

但是,我升级了skLearning,现在我得到了这些警告:

.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator DecisionTreeRegressor from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)
.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator PriorProbabilityEstimator from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)
.../.local/lib/python2.7/site-packages/sklearn/base.py:315: 
UserWarning: Trying to unpickle estimator GradientBoostingClassifier from version 0.18.1 when using version 0.18.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)

我现在该怎么办?

>

我可以解开文件并重新腌制它。这在0.18.2时有效,但在0.19时中断。joblib看起来也好不到哪去。

我希望我可以将数据保存为独立于版本的ASCII格式(例如,JSON或XML)。显然,这是最佳的解决方案,但似乎没有办法做到这一点(另请参见Sklearn-model persistence without-pkl file)。

我可以将模型保存到PMML,但它的支持充其量也不温不火:我可以使用sklearn2pmml保存模型(尽管不容易),并使用augustus/lightpmmlpredictor应用(尽管不加载)模型。但是,pip无法直接使用这些工具,这使得部署成为一场噩梦。还有,奥古斯都的

上述方法的一种变体:使用sklearn2pmml保存PMML,并使用openscoring进行评分。需要与外部过程接口。大笑。

建议?

共有1个答案

丘华翰
2023-03-14

通常情况下,不可能跨不同版本的Scikit学习实现模型持久性。原因很明显:您用一个定义来泡制Class1,并希望用另一个定义将其解泡制到Class2中。

你可以:

  • 仍然尝试坚持使用sklearn的一个版本
  • 忽略这些警告,并希望对Class1有效的方法也适用于Class2
  • 编写您自己的类,该类可以序列化您的GradientBoostingClassifier,并从该序列化表单还原它,希望它比pickle工作得更好

我举了一个例子,说明了如何将单个DecisionTreeRegressor转换为完全兼容JSON的纯列表和dict格式,并将其还原回来。

import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_classification

### Code to serialize and deserialize trees

LEAF_ATTRIBUTES = ['children_left', 'children_right', 'threshold', 'value', 'feature', 'impurity', 'weighted_n_node_samples']
TREE_ATTRIBUTES = ['n_classes_', 'n_features_', 'n_outputs_']

def serialize_tree(tree):
    """ Convert a sklearn.tree.DecisionTreeRegressor into a json-compatible format """
    encoded = {
        'nodes': {},
        'tree': {},
        'n_leaves': len(tree.tree_.threshold),
        'params': tree.get_params()
    }
    for attr in LEAF_ATTRIBUTES:
        encoded['nodes'][attr] = getattr(tree.tree_, attr).tolist()
    for attr in TREE_ATTRIBUTES:
        encoded['tree'][attr] = getattr(tree, attr)
    return encoded

def deserialize_tree(encoded):
    """ Restore a sklearn.tree.DecisionTreeRegressor from a json-compatible format """
    x = np.arange(encoded['n_leaves'])
    tree = DecisionTreeRegressor().fit(x.reshape((-1,1)), x)
    tree.set_params(**encoded['params'])
    for attr in LEAF_ATTRIBUTES:
        for i in range(encoded['n_leaves']):
            getattr(tree.tree_, attr)[i] = encoded['nodes'][attr][i]
    for attr in TREE_ATTRIBUTES:
        setattr(tree, attr, encoded['tree'][attr])
    return tree

## test the code

X, y = make_classification(n_classes=3, n_informative=10)
tree = DecisionTreeRegressor().fit(X, y)
encoded = serialize_tree(tree)
decoded = deserialize_tree(encoded)
assert (decoded.predict(X)==tree.predict(X)).all()

这样,您就可以继续序列化和反序列化整个GradientBoostingClassifier

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble.gradient_boosting import PriorProbabilityEstimator

def serialize_gbc(clf):
    encoded = {
        'classes_': clf.classes_.tolist(),
        'max_features_': clf.max_features_, 
        'n_classes_': clf.n_classes_,
        'n_features_': clf.n_features_,
        'train_score_': clf.train_score_.tolist(),
        'params': clf.get_params(),
        'estimators_shape': list(clf.estimators_.shape),
        'estimators': [],
        'priors':clf.init_.priors.tolist()
    }
    for tree in clf.estimators_.reshape((-1,)):
        encoded['estimators'].append(serialize_tree(tree))
    return encoded

def deserialize_gbc(encoded):
    x = np.array(encoded['classes_'])
    clf = GradientBoostingClassifier(**encoded['params']).fit(x.reshape(-1, 1), x)
    trees = [deserialize_tree(tree) for tree in encoded['estimators']]
    clf.estimators_ = np.array(trees).reshape(encoded['estimators_shape'])
    clf.init_ = PriorProbabilityEstimator()
    clf.init_.priors = np.array(encoded['priors'])
    clf.classes_ = np.array(encoded['classes_'])
    clf.train_score_ = np.array(encoded['train_score_'])
    clf.max_features_ = encoded['max_features_']
    clf.n_classes_ = encoded['n_classes_']
    clf.n_features_ = encoded['n_features_']
    return clf

# test on the same problem
clf = GradientBoostingClassifier()
clf.fit(X, y);
encoded = serialize_gbc(clf)
decoded = deserialize_gbc(encoded)
assert (decoded.predict(X) == clf.predict(X)).all()

这适用于scikit学习v0。19,但不要问我在下一个版本中会出现什么来破坏这个代码。我既不是先知,也不是sklearn的开发者。

如果希望完全独立于新版本的sklearn,最安全的方法是编写一个函数,遍历序列化树并进行预测,而不是重新创建sklearn树。

 类似资料:
  • 快住手!域模型不是持久性模型 如果这是真的,将持久性对象与域对象分开有什么好处呢?

  • 我是Spring框架的新手,并试图为sql数据库创建动态搜索查询-类似于下面的线程中描述的内容。 使用spring数据jpa和spring mvc筛选数据库行 这个线程提供了一个有用的指南,但是我在生成所描述的元模型文件时遇到了问题。我知道这是一个常见的问题,我曾尝试实施网上已有的解决方案,但它们都不起作用。 我对持久性的概念有着特殊的理解。xml文件以及在哪里找到它。据我所知,它的功能似乎与我的

  • 校验者: @why2lyj(Snow Wang) @小瑶 翻译者: @那伊抹微笑 在训练完 scikit-learn 模型之后,最好有一种方法来将模型持久化以备将来使用,而无需重新训练。 以下部分为您提供了有关如何使用 pickle 来持久化模型的示例。 在使用 pickle 序列化时,我们还将回顾一些安全性和可维护性方面的问题。 3.4.1. 持久化示例 可以通过使用 Python 的内置持久化

  • 决策树(decision tree)是一种基本的分类与回归方法。决策树模型呈树形结构。通常包含3个步骤:特征选择、决策树的生成和决策树的修剪。 决策树模型 分类决策树树模型是一种描述对实例进行分类的树形结构。决策树由节点(node)和有向边(directed edge)组成。节点有两种类型:内部节点(internal node)和叶节点。内部节点表示一个特征或属性,叶节点表示一个类。 用决策树分类

  • 英文原文:http://emberjs.com/guides/models/persisting-records/ Ember Data中的记录都基于实例来进行持久化。调用DS.Model实例的save()会触发一个网络请求,来进行记录的持久化。 下面是几个示例: 1 2 3 4 5 6 var post = store.createRecord('post', { title: 'Rail

  • 我正在寻找从经典Akka持久化迁移到Akka持久化类型。在这里找到的Lagom留档:1说“注意:从Lagom持久化(经典)迁移到Akka持久化类型时的唯一限制是需要完全关闭集群。即使所有持久数据都是兼容的,Lagom持久化(经典)和Akka持久化类型也不能共存。” 有人知道这是否适用于服务器可能知道的所有持久实体吗?例如,我使用的服务有3个独立的持久实体。我需要一次迁移所有3个,还是可以一次迁移一

  • EJB 3.0,EJB 2.0中使用的实体bean在很大程度上被持久性机制所取代。 现在,实体bean是一个简单的POJO,它具有与表的映射。 以下是持久性API中的关键角色 - Entity - 表示数据存储记录的持久对象。 可序列化是件好事。 EntityManager - 持久性接口,用于对持久对象(实体)执行添加/删除/更新/查找等数据操作。 它还有助于使用Query接口执行查询。 Per

  • 问题内容: 我正在Java中寻找一个持久的哈希结构,这是一个简单的键值存储,其中key是唯一的字符串,value是一个int。每次将现有密钥添加到存储中时,密钥的值将增加。 我需要它很大-可能有5亿-10亿个密钥。我一直在评估tokyo-cabinet http://fallabs.com/tokyocabinet/javadoc/,但不确定其扩展性如何- 随着哈希值的增加,插入时间似乎越来越长。