
Classifying the poisonous mushroom dataset with xgboost

太叔俊侠
2023-12-01

Installing xgboost

pip3 install xgboost

The mushroom dataset

For a description of the mushroom dataset, see:
https://archive.ics.uci.edu/ml/datasets/Mushroom
The attributes are described as follows:

Attribute Information:

1. cap-shape: bell=b,conical=c,convex=x,flat=f, knobbed=k,sunken=s 
2. cap-surface: fibrous=f,grooves=g,scaly=y,smooth=s 
3. cap-color: brown=n,buff=b,cinnamon=c,gray=g,green=r, pink=p,purple=u,red=e,white=w,yellow=y 
4. bruises?: bruises=t,no=f 
5. odor: almond=a,anise=l,creosote=c,fishy=y,foul=f, musty=m,none=n,pungent=p,spicy=s 
6. gill-attachment: attached=a,descending=d,free=f,notched=n 
7. gill-spacing: close=c,crowded=w,distant=d 
8. gill-size: broad=b,narrow=n 
9. gill-color: black=k,brown=n,buff=b,chocolate=h,gray=g, green=r,orange=o,pink=p,purple=u,red=e, white=w,yellow=y 
10. stalk-shape: enlarging=e,tapering=t 
11. stalk-root: bulbous=b,club=c,cup=u,equal=e, rhizomorphs=z,rooted=r,missing=? 
12. stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s 
13. stalk-surface-below-ring: fibrous=f,scaly=y,silky=k,smooth=s 
14. stalk-color-above-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y 
15. stalk-color-below-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y 
16. veil-type: partial=p,universal=u 
17. veil-color: brown=n,orange=o,white=w,yellow=y 
18. ring-number: none=n,one=o,two=t 
19. ring-type: cobwebby=c,evanescent=e,flaring=f,large=l, none=n,pendant=p,sheathing=s,zone=z 
20. spore-print-color: black=k,brown=n,buff=b,chocolate=h,green=r, orange=o,purple=u,white=w,yellow=y 
21. population: abundant=a,clustered=c,numerous=n, scattered=s,several=v,solitary=y 
22. habitat: grasses=g,leaves=l,meadows=m,paths=p, urban=u,waste=w,woods=d

The raw mushroom data looks like this:

p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u
e,x,s,y,t,a,f,c,b,k,e,c,s,s,w,w,p,w,o,p,n,n,g

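The first column is the class label: edible=e, poisonous=p. The remaining columns are the features.
All of these feature values are categorical, and they correspond one-to-one to the 22 attributes listed above. Categorical features like these are usually one-hot encoded. You can do this with sklearn's DictVectorizer or with a hand-written encoder. The result is then written in xgboost's LIBSVM-style text input format:
label feature_id:feature_value ...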
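The sketch below takes the hand-written route and one-hot encodes raw rows into LIBSVM-style lines. It is not the code that produced xgboost's bundled files. It assumes poisonous `p` maps to label 1 and edible `e` to 0. It assigns feature ids in order of first appearance, so the ids will not match those in the xgboost demo data. `build_vocab` and `to_libsvm` are hypothetical helper names.

```python
# Hypothetical helpers: one-hot encode raw UCI mushroom rows into LIBSVM lines.
# Assumptions: 'p' (poisonous) -> label 1, 'e' (edible) -> label 0; feature ids
# are assigned in order of first appearance, so they will NOT match the ids in
# xgboost's bundled agaricus files.


def build_vocab(rows):
    """Map each (attribute column, category value) pair to an integer feature id."""
    vocab = {}
    for row in rows:
        for col, value in enumerate(row.strip().split(",")[1:]):
            key = (col, value)
            if key not in vocab:
                vocab[key] = len(vocab)
    return vocab


def to_libsvm(row, vocab):
    """Convert one raw CSV row into 'label id:1 id:1 ...'."""
    fields = row.strip().split(",")
    label = 1 if fields[0] == "p" else 0
    # one active column per attribute; raises KeyError for unseen values
    ids = sorted(vocab[(col, v)] for col, v in enumerate(fields[1:]))
    return " ".join([str(label)] + [f"{i}:1" for i in ids])


raw = [
    "p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u",
    "e,x,s,y,t,a,f,c,b,k,e,c,s,s,w,w,p,w,o,p,n,n,g",
]
vocab = build_vocab(raw)
for line in raw:
    print(to_libsvm(line, vocab))
```

Each output line carries exactly 22 `id:1` entries, one per attribute. A category value missing from the vocabulary raises a `KeyError`. In practice you would build the vocabulary from the training data and handle unseen values explicitly.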

For example, the mushroom dataset bundled with the xgboost tutorial looks like this:

1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1
0 3:1 10:1 20:1 21:1 23:1 34:1 36:1 39:1 41:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 120:1
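Each line starts with the label, followed by the indices of the one-hot columns that are set to 1. Both example lines contain exactly 22 such entries, one per attribute. The sketch below decodes a line into a dense 0/1 vector; `parse_libsvm` is a hypothetical helper. The width of 127 is an assumption taken from the "6513x127 matrix" message xgboost prints when it loads agaricus.txt.train.

```python
# Hypothetical helper: parse one LIBSVM line into (label, dense feature vector).
# n_features=127 is assumed from the "6513x127 matrix" load message.


def parse_libsvm(line, n_features=127):
    parts = line.split()
    label = int(parts[0])
    dense = [0.0] * n_features
    for item in parts[1:]:
        idx, value = item.split(":")
        dense[int(idx)] = float(value)  # indices are used as 0-based positions
    return label, dense


row = "1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1"
label, x = parse_libsvm(row)
print(label, sum(x))  # prints: 1 22.0
```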

Since the xgboost repo already ships the mushroom dataset in this ready-to-use format, we'll simply use it (being lazy~).

Classification and cross-validation

First clone the https://github.com/dmlc/xgboost repo. Its demo/data directory contains the preprocessed mushroom dataset.
Then run the following cross-validation code:

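#!/usr/bin/python
import json

import xgboost as xgb

### load data
# Path is relative to the demo/ directory of the cloned repo; "?format=libsvm"
# tells newer xgboost versions how to parse the text file.
dtrain = xgb.DMatrix('data/agaricus.txt.train?format=libsvm')
# 'silent' is deprecated in newer xgboost; 'verbosity': 0 replaces it
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0, 'objective': 'binary:logistic'}
num_round = 2

print('running cross validation')
# do cross validation, this will print result out as
# [iteration]  metric_name:mean_value+std_value
# std_value is standard deviation of the metric
# (the old print_evaluation/early_stop callbacks no longer exist in recent
# xgboost; EvaluationMonitor/EarlyStopping replace them)
xgb.cv(param, dtrain, num_round, nfold=5,
       metrics=['error'], seed=0,
       callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)])

print('running cross validation, disable standard deviation display')
# do cross validation, this will print result out as
# [iteration]  metric_name:mean_value
# metrics is a list (not a set) so its order is fixed: early stopping watches
# the last metric ('rmse') and stops after 3 rounds without improvement
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
             metrics=['error', 'auc', 'rmse'], seed=0,
             callbacks=[xgb.callback.EvaluationMonitor(show_stdv=False),
                        xgb.callback.EarlyStopping(rounds=3)],
             as_pandas=False)  # return a plain dict so json.dumps can serialize it
print(json.dumps(res, indent=4))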
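According to the comments in the code, each printed value is `mean_value+std_value` over the 5 folds. The sketch below shows that aggregation. It assumes the population standard deviation (numpy's default `np.std`, ddof=0), which as far as I know is what xgboost's cv aggregation uses. `summarize` is a hypothetical helper, and the fold scores are made up rather than taken from the run below.

```python
import math


def summarize(fold_scores):
    """Return (mean, population standard deviation) of per-fold metric values."""
    n = len(fold_scores)
    mean = sum(fold_scores) / n
    # divide by n (ddof=0), matching numpy.std's default
    var = sum((s - mean) ** 2 for s in fold_scores) / n
    return mean, math.sqrt(var)


# hypothetical test-error values from 5 folds, NOT taken from the run below
folds = [0.0, 0.002, 0.002, 0.0, 0.001]
mean, std = summarize(folds)
print(f"test-error:{mean:g}+{std:g}")
```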

The evaluation metrics are error, auc and rmse. The output is as follows:

[22:46:13] 6513x127 matrix with 143286 entries loaded from ../data/agaricus.txt.train
running cross validation
[0]	train-error:0.0506682+0.009201	test-error:0.0557316+0.0158887
[1]	train-error:0.0213034+0.00205561	test-error:0.0211884+0.00365323
running cross validation, disable standard deviation display
[0]	train-auc:0.962827	train-error:0.0506682	train-rmse:0.230558	test-auc:0.960993	test-error:0.0557316	test-rmse:0.234907
Multiple eval metrics have been passed: 'test-rmse' will be used for early stopping.

Will train until test-rmse hasn't improved in 3 rounds.
[1]	train-auc:0.984829	train-error:0.0213034	train-rmse:0.159816	test-auc:0.984753	test-error:0.0211884	test-rmse:0.159676
[2]	train-auc:0.99763	train-error:0.0099418	train-rmse:0.111782	test-auc:0.997216	test-error:0.0099786	test-rmse:0.113232
[3]	train-auc:0.998845	train-error:0.0141256	train-rmse:0.104002	test-auc:0.998575	test-error:0.0144336	test-rmse:0.105863
[4]	train-auc:0.999404	train-error:0.0059878	train-rmse:0.078452	test-auc:0.999001	test-error:0.0062948	test-rmse:0.0814638
[5]	train-auc:0.999571	train-error:0.0020344	train-rmse:0.0554116	test-auc:0.999236	test-error:0.0016886	test-rmse:0.0549618
[6]	train-auc:0.999643	train-error:0.0012284	train-rmse:0.0442974	test-auc:0.999389	test-error:0.001228	test-rmse:0.0447266
[7]	train-auc:0.999736	train-error:0.0012284	train-rmse:0.0409082	test-auc:0.999535	test-error:0.001228	test-rmse:0.0408704
[8]	train-auc:0.999967	train-error:0.0009212	train-rmse:0.0325856	test-auc:0.99992	test-error:0.001228	test-rmse:0.0378632
[9]	train-auc:0.999982	train-error:0.0006142	train-rmse:0.0305786	test-auc:0.999959	test-error:0.001228	test-rmse:0.0355032
{
    "train-auc-mean": [
        0.9628270000000001,
        0.9848285999999999,
        0.9976303999999999,
        0.9988453999999999,
        0.9994036000000002,
        0.9995708000000001,
        0.9996430000000001,
        0.9997361999999999,
        0.9999673999999998,
        0.9999824
    ],
    "train-auc-std": [
        0.008448006628785306,
        0.006803294807664875,
        0.0009387171245907945,
        0.0003538839357755705,
        0.0002618454506001579,
        0.00025375531521528084,
        0.00020217319307960657,
        0.00020191027710344742,
        3.975726348731663e-05,
        2.0401960690100178e-05
    ],
    "train-error-mean": [
        0.0506682,
        0.0213034,
        0.009941799999999999,
        0.014125599999999999,
        0.0059878,
        0.0020344,
        0.0012284,
        0.0012284,
        0.0009212,
        0.0006142
    ],
    "train-error-std": [
        0.009200997193782855,
        0.0020556122786167634,
        0.006076479256938181,
        0.0017057689878761427,
        0.0018779069625516596,
        0.001469605198684327,
        0.00026026494193417596,
        0.00026026494193417596,
        0.0005061973528180487,
        0.000506318634853587
    ],
    "train-rmse-mean": [
        0.2305582,
        0.1598156,
        0.11178239999999999,
        0.104002,
        0.07845200000000001,
        0.05541159999999999,
        0.0442974,
        0.04090819999999999,
        0.0325856,
        0.0305786
    ],
    "train-rmse-std": [
        0.0037922719522734682,
        0.01125541159798254,
        0.002488777016930202,
        0.004049203625405865,
        0.0013676392799272755,
        0.0037808676041353262,
        0.0029096398127603355,
        0.002357843285716845,
        0.006475110519520113,
        0.005710078619423729
    ],
    "test-auc-mean": [
        0.9609932000000001,
        0.9847534,
        0.9972156,
        0.9985745999999999,
        0.9990005999999999,
        0.9992364,
        0.9993892000000001,
        0.9995353999999999,
        0.9999202,
        0.9999590000000002
    ],
    "test-auc-std": [
        0.01004388481415432,
        0.0073321434137637925,
        0.001719973674217141,
        0.0008885571675474929,
        0.0007358270448957148,
        0.0007796844489920071,
        0.0006559303011753682,
        0.00040290524940736174,
        9.667760857612425e-05,
        4.69212105555625e-05
    ],
    "test-error-mean": [
        0.055731600000000006,
        0.021188400000000003,
        0.009978599999999999,
        0.0144336,
        0.006294800000000001,
        0.0016885999999999997,
        0.001228,
        0.001228,
        0.001228,
        0.001228
    ],
    "test-error-std": [
        0.015888666194492227,
        0.0036532266614597024,
        0.004827953421482027,
        0.003517125508138713,
        0.0031231752688570006,
        0.0005741844999649501,
        0.0010409403441119958,
        0.0010409403441119958,
        0.0010409403441119958,
        0.0010409403441119958
    ],
    "test-rmse-mean": [
        0.2349068,
        0.1596758,
        0.113232,
        0.10586340000000001,
        0.08146379999999999,
        0.054961800000000005,
        0.04472659999999999,
        0.040870399999999994,
        0.0378632,
        0.0355032
    ],
    "test-rmse-std": [
        0.009347684363520205,
        0.013838324138420807,
        0.006189637210693368,
        0.008622184470306812,
        0.007872053592297248,
        0.009846015506792583,
        0.010218353460318349,
        0.0130493180756697,
        0.013843686595701305,
        0.013998153155327313
    ]
}
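Although early stopping was set to 3 rounds, all 10 rounds ran. The reason is that test-rmse-mean decreased in every round. The simplified sketch below checks this with patience-based early stopping applied to the test-rmse-mean values printed above. `early_stop_round` is a hypothetical helper. It ignores options such as `min_delta` and `maximize` that xgboost's `EarlyStopping` also supports.

```python
# Simplified patience-based early stopping: stop once the watched metric
# (lower is better) has not improved for `rounds` consecutive rounds.


def early_stop_round(history, rounds=3):
    """Return (best_iteration, stopped_at); stopped_at is None if never stopped."""
    best, best_iter = float("inf"), 0
    for i, value in enumerate(history):
        if value < best:
            best, best_iter = value, i
        elif i - best_iter >= rounds:
            return best_iter, i
    return best_iter, None


# test-rmse-mean values copied from the cross-validation output above
test_rmse_mean = [0.2349068, 0.1596758, 0.113232, 0.10586340000000001,
                  0.08146379999999999, 0.054961800000000005, 0.04472659999999999,
                  0.040870399999999994, 0.0378632, 0.0355032]
print(early_stop_round(test_rmse_mean))  # prints: (9, None)
```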

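Judging by the cross-validation results, xgboost classifies this dataset very well. After 10 rounds the mean test error is 0.001228, roughly 1 to 2 misclassified mushrooms per held-out fold of about 1,300 rows. The mean test AUC is 0.999959, and the mean test RMSE has fallen from 0.235 to 0.0355.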
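The three metrics measure different things:

* error depends on a 0.5 threshold: xgboost's `error` metric counts predictions above 0.5 as positive.
* auc depends only on how predictions are ranked.
* rmse depends on how close the predicted probabilities are to the 0/1 labels.

The pure-Python sketch below illustrates each one. The function names are hypothetical and the AUC uses the pairwise (Mann-Whitney) definition; this is not xgboost's own implementation.

```python
import math


def error_rate(labels, preds, threshold=0.5):
    # predictions above the threshold count as positive (label 1)
    wrong = sum((p > threshold) != bool(y) for y, p in zip(labels, preds))
    return wrong / len(labels)


def rmse(labels, preds):
    # root mean squared difference between probability and 0/1 label
    return math.sqrt(sum((p - y) ** 2 for y, p in zip(labels, preds)) / len(labels))


def auc(labels, preds):
    pos = [p for y, p in zip(labels, preds) if y == 1]
    neg = [p for y, p in zip(labels, preds) if y == 0]
    # fraction of (positive, negative) pairs ranked correctly; ties count 0.5
    score = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return score / (len(pos) * len(neg))


labels = [1, 0, 1, 0]
preds = [0.9, 0.2, 0.4, 0.1]  # hypothetical predicted probabilities
print(error_rate(labels, preds), rmse(labels, preds), auc(labels, preds))
```

In this toy case the AUC is a perfect 1.0 while the error is 0.25. The positive scored 0.4 is still ranked above every negative, but it falls below the 0.5 threshold, which is why it is worth watching all three metrics.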
