
xgboost Python classification: classifying the poisonous mushroom dataset with xgboost

司空丰
2023-12-01

Install xgboost

pip3 install xgboost

The poisonous mushroom dataset

Attribute Information:

1. cap-shape: bell=b,conical=c,convex=x,flat=f, knobbed=k,sunken=s

2. cap-surface: fibrous=f,grooves=g,scaly=y,smooth=s

3. cap-color: brown=n,buff=b,cinnamon=c,gray=g,green=r, pink=p,purple=u,red=e,white=w,yellow=y

4. bruises?: bruises=t,no=f

5. odor: almond=a,anise=l,creosote=c,fishy=y,foul=f, musty=m,none=n,pungent=p,spicy=s

6. gill-attachment: attached=a,descending=d,free=f,notched=n

7. gill-spacing: close=c,crowded=w,distant=d

8. gill-size: broad=b,narrow=n

9. gill-color: black=k,brown=n,buff=b,chocolate=h,gray=g, green=r,orange=o,pink=p,purple=u,red=e, white=w,yellow=y

10. stalk-shape: enlarging=e,tapering=t

11. stalk-root: bulbous=b,club=c,cup=u,equal=e, rhizomorphs=z,rooted=r,missing=?

12. stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s

13. stalk-surface-below-ring: fibrous=f,scaly=y,silky=k,smooth=s

14. stalk-color-above-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y

15. stalk-color-below-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y

16. veil-type: partial=p,universal=u

17. veil-color: brown=n,orange=o,white=w,yellow=y

18. ring-number: none=n,one=o,two=t

19. ring-type: cobwebby=c,evanescent=e,flaring=f,large=l, none=n,pendant=p,sheathing=s,zone=z

20. spore-print-color: black=k,brown=n,buff=b,chocolate=h,green=r, orange=o,purple=u,white=w,yellow=y

21. population: abundant=a,clustered=c,numerous=n, scattered=s,several=v,solitary=y

22. habitat: grasses=g,leaves=l,meadows=m,paths=p, urban=u,waste=w,woods=d


In its raw form, the mushroom dataset looks like this:

p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u

e,x,s,y,t,a,f,c,b,k,e,c,s,s,w,w,p,w,o,p,n,n,g


The first column is the label: edible=e, poisonous=p. The remaining columns are the features.
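To make this layout concrete, here is a small sketch that splits one raw line into a numeric label and a dict keyed by attribute name. The column names are copied from the attribute list above; the function name `parse_raw_row` and the mapping p→1, e→0 are assumptions made for this example, not something defined by the dataset.

```python
# Attribute names in file order, copied from the attribute list above.
COLUMNS = [
    "cap-shape", "cap-surface", "cap-color", "bruises?", "odor",
    "gill-attachment", "gill-spacing", "gill-size", "gill-color",
    "stalk-shape", "stalk-root", "stalk-surface-above-ring",
    "stalk-surface-below-ring", "stalk-color-above-ring",
    "stalk-color-below-ring", "veil-type", "veil-color", "ring-number",
    "ring-type", "spore-print-color", "population", "habitat",
]

# Assumed label encoding for this sketch: poisonous -> 1, edible -> 0.
LABELS = {"p": 1, "e": 0}


def parse_raw_row(line):
    """Split one raw comma-separated row into (label, {attribute: code})."""
    fields = line.strip().split(",")
    if len(fields) != len(COLUMNS) + 1:
        raise ValueError(f"expected {len(COLUMNS) + 1} fields, got {len(fields)}")
    label = LABELS[fields[0]]
    features = dict(zip(COLUMNS, fields[1:]))
    return label, features


label, feats = parse_raw_row("p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u")
print(label, feats["odor"], feats["habitat"])  # 1 p u
```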

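Clearly, all of these feature values are categorical, and the columns correspond one-to-one to the 22 attributes described above. Categorical features like these are usually one-hot encoded, either with sklearn's DictVectorizer or with a hand-written encoder, and then written out in xgboost's input format (the LibSVM text format):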

label feature_id:feature_value feature_id:feature_value ...
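A minimal hand-written version of that one-hot step might look like the following. It is a sketch, not DictVectorizer: the helper names `build_vocab` and `to_libsvm_line` are hypothetical, rows are assumed to already be dicts of attribute→code, and column indices are assigned from 0 in first-seen order, so they will not match the indices in xgboost's bundled agaricus files.

```python
def build_vocab(rows):
    """Assign one column index to every (attribute, value) pair seen in the data."""
    vocab = {}
    for features in rows:
        for attr, value in features.items():
            key = f"{attr}={value}"
            if key not in vocab:
                vocab[key] = len(vocab)
    return vocab


def to_libsvm_line(label, features, vocab):
    """One-hot encode a row and format it as 'label idx:1 idx:1 ...'."""
    indices = sorted(vocab[f"{attr}={value}"] for attr, value in features.items())
    return " ".join([str(label)] + [f"{i}:1" for i in indices])


# Two toy rows using only three of the 22 attributes, for brevity.
rows = [
    (1, {"cap-shape": "x", "odor": "p", "habitat": "u"}),
    (0, {"cap-shape": "x", "odor": "a", "habitat": "g"}),
]
vocab = build_vocab(features for _, features in rows)
for label, features in rows:
    print(to_libsvm_line(label, features, vocab))
# 1 0:1 1:1 2:1
# 0 0:1 3:1 4:1
```

Keying on `attr=value` keeps the same letter under different attributes (for example `n` for a brown cap vs. a brown gill) in separate columns, which is the point of one-hot encoding each attribute independently.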

For example, the mushroom dataset bundled with the xgboost tutorial looks like this:

1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1

0 3:1 10:1 20:1 21:1 23:1 34:1 36:1 39:1 41:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 120:1
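One way to read these lines: each `index:1` pair marks one active one-hot column, so every row should have exactly 22 active columns, one per original attribute. The hypothetical `parse_libsvm_line` helper below checks this on the first sample line; it assumes integer labels, as in this file.

```python
def parse_libsvm_line(line):
    """Split a LibSVM line into (label, {feature_index: value})."""
    parts = line.split()
    label = int(parts[0])  # labels in this file are 0/1 integers
    features = {}
    for item in parts[1:]:
        index, value = item.split(":")
        features[int(index)] = float(value)
    return label, features


row = ("1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 "
       "77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1")
label, features = parse_libsvm_line(row)
print(label, len(features))  # 1 22
```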


Since the xgboost repo already ships the mushroom dataset, converted to exactly this format, I'll just use it directly (lazy~).

Validating the classifier

First clone the https://github.com/dmlc/xgboost repo; the preprocessed mushroom dataset is in demo/data.

Next, run the cross-validation code:

```python
#!/usr/bin/python
import json

import xgboost as xgb

### load data and do training
# Path is relative to the demo/ directory of the xgboost repo;
# '?format=libsvm' states the text format explicitly, as newer versions expect.
dtrain = xgb.DMatrix('data/agaricus.txt.train?format=libsvm')
# 'silent' is no longer a recognized parameter; 'verbosity': 0 silences logging
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0, 'objective': 'binary:logistic'}
num_round = 2

print('running cross validation')
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
xgb.cv(param, dtrain, num_round, nfold=5,
       metrics=['error'], seed=0,
       callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)])

print('running cross validation, disable standard deviation display')
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value
# A list keeps the metric order fixed; early stopping watches the last one (rmse).
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
             metrics=['error', 'auc', 'rmse'], seed=0,
             as_pandas=False,  # return a plain dict so json.dumps can serialize it
             callbacks=[xgb.callback.EvaluationMonitor(show_stdv=False),
                        xgb.callback.EarlyStopping(rounds=3)])

print(json.dumps(res, indent=4))
```
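The `mean_value+std_value` printed by `xgb.cv` is the mean and standard deviation of a metric over the `nfold` folds. The numpy sketch below imitates that bookkeeping under stated assumptions: rows are shuffled with a seeded generator and split into 5 near-equal folds (xgb.cv's own fold assignment will differ), and the standard deviation is the population one (numpy's default `ddof=0`), which we assume matches xgboost. The helper names and per-fold scores are made up for illustration; 6513 is the training-set row count reported in the log that follows.

```python
import numpy as np


def kfold_indices(n_samples, nfold, seed=0):
    """Shuffle row indices and split them into nfold roughly equal folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n_samples), nfold)


def aggregate(fold_scores):
    """Mean and (population) standard deviation of one metric across folds."""
    scores = np.asarray(fold_scores, dtype=float)
    return scores.mean(), scores.std()  # np.std uses ddof=0 by default


folds = kfold_indices(6513, 5)
print([len(f) for f in folds])  # [1303, 1303, 1303, 1302, 1302]

# Hypothetical per-fold test errors for one boosting round.
mean, std = aggregate([0.001, 0.002, 0.0, 0.001, 0.001])
print(f"test-error:{mean:g}+{std:g}")
```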


The evaluation metrics are error, auc and rmse. The output is:

```
[22:46:13] 6513x127 matrix with 143286 entries loaded from ../data/agaricus.txt.train
running cross validation
[0]	train-error:0.0506682+0.009201	test-error:0.0557316+0.0158887
[1]	train-error:0.0213034+0.00205561	test-error:0.0211884+0.00365323
running cross validation, disable standard deviation display
[0]	train-auc:0.962827	train-error:0.0506682	train-rmse:0.230558	test-auc:0.960993	test-error:0.0557316	test-rmse:0.234907
Multiple eval metrics have been passed: 'test-rmse' will be used for early stopping.

Will train until test-rmse hasn't improved in 3 rounds.
[1]	train-auc:0.984829	train-error:0.0213034	train-rmse:0.159816	test-auc:0.984753	test-error:0.0211884	test-rmse:0.159676
[2]	train-auc:0.99763	train-error:0.0099418	train-rmse:0.111782	test-auc:0.997216	test-error:0.0099786	test-rmse:0.113232
[3]	train-auc:0.998845	train-error:0.0141256	train-rmse:0.104002	test-auc:0.998575	test-error:0.0144336	test-rmse:0.105863
[4]	train-auc:0.999404	train-error:0.0059878	train-rmse:0.078452	test-auc:0.999001	test-error:0.0062948	test-rmse:0.0814638
[5]	train-auc:0.999571	train-error:0.0020344	train-rmse:0.0554116	test-auc:0.999236	test-error:0.0016886	test-rmse:0.0549618
[6]	train-auc:0.999643	train-error:0.0012284	train-rmse:0.0442974	test-auc:0.999389	test-error:0.001228	test-rmse:0.0447266
[7]	train-auc:0.999736	train-error:0.0012284	train-rmse:0.0409082	test-auc:0.999535	test-error:0.001228	test-rmse:0.0408704
[8]	train-auc:0.999967	train-error:0.0009212	train-rmse:0.0325856	test-auc:0.99992	test-error:0.001228	test-rmse:0.0378632
[9]	train-auc:0.999982	train-error:0.0006142	train-rmse:0.0305786	test-auc:0.999959	test-error:0.001228	test-rmse:0.0355032
```
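For reference, here is what the three metrics measure for `binary:logistic`, whose predictions are probabilities. This is a plain numpy sketch of the textbook definitions, not xgboost's implementation: `error` counts a row as positive when its probability is above 0.5, `auc` compares every positive with every negative (ties count as half), and `rmse` is taken between probabilities and 0/1 labels. The function names are our own.

```python
import numpy as np


def error_rate(y_true, prob, threshold=0.5):
    """Fraction of rows whose thresholded prediction disagrees with the label."""
    y_true = np.asarray(y_true)
    pred = (np.asarray(prob) > threshold).astype(int)
    return float(np.mean(pred != y_true))


def rmse(y_true, prob):
    """Root mean squared difference between probabilities and 0/1 labels."""
    diff = np.asarray(prob, dtype=float) - np.asarray(y_true, dtype=float)
    return float(np.sqrt(np.mean(diff ** 2)))


def auc(y_true, prob):
    """ROC AUC: probability that a random positive outranks a random negative."""
    y_true = np.asarray(y_true)
    prob = np.asarray(prob, dtype=float)
    pos, neg = prob[y_true == 1], prob[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUC needs at least one positive and one negative")
    # Compare every positive with every negative; ties count as half a win.
    wins = (pos[:, None] > neg[None, :]).sum() + 0.5 * (pos[:, None] == neg[None, :]).sum()
    return float(wins / (len(pos) * len(neg)))


y = [1, 0, 1, 0]
p = [0.9, 0.2, 0.4, 0.6]
print(error_rate(y, p), auc(y, p))  # 0.5 0.75
print(rmse(y, p))
```

The pairwise AUC builds a positives-by-negatives matrix, so it only suits small arrays; a rank-based formula scales better.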

```json
{
    "train-auc-mean": [
        0.9628270000000001,
        0.9848285999999999,
        0.9976303999999999,
        0.9988453999999999,
        0.9994036000000002,
        0.9995708000000001,
        0.9996430000000001,
        0.9997361999999999,
        0.9999673999999998,
        0.9999824
    ],
    "train-auc-std": [
        0.008448006628785306,
        0.006803294807664875,
        0.0009387171245907945,
        0.0003538839357755705,
        0.0002618454506001579,
        0.00025375531521528084,
        0.00020217319307960657,
        0.00020191027710344742,
        3.975726348731663e-05,
        2.0401960690100178e-05
    ],
    "train-error-mean": [
        0.0506682,
        0.0213034,
        0.009941799999999999,
        0.014125599999999999,
        0.0059878,
        0.0020344,
        0.0012284,
        0.0012284,
        0.0009212,
        0.0006142
    ],
    "train-error-std": [
        0.009200997193782855,
        0.0020556122786167634,
        0.006076479256938181,
        0.0017057689878761427,
        0.0018779069625516596,
        0.001469605198684327,
        0.00026026494193417596,
        0.00026026494193417596,
        0.0005061973528180487,
        0.000506318634853587
    ],
    "train-rmse-mean": [
        0.2305582,
        0.1598156,
        0.11178239999999999,
        0.104002,
        0.07845200000000001,
        0.05541159999999999,
        0.0442974,
        0.04090819999999999,
        0.0325856,
        0.0305786
    ],
    "train-rmse-std": [
        0.0037922719522734682,
        0.01125541159798254,
        0.002488777016930202,
        0.004049203625405865,
        0.0013676392799272755,
        0.0037808676041353262,
        0.0029096398127603355,
        0.002357843285716845,
        0.006475110519520113,
        0.005710078619423729
    ],
    "test-auc-mean": [
        0.9609932000000001,
        0.9847534,
        0.9972156,
        0.9985745999999999,
        0.9990005999999999,
        0.9992364,
        0.9993892000000001,
        0.9995353999999999,
        0.9999202,
        0.9999590000000002
    ],
    "test-auc-std": [
        0.01004388481415432,
        0.0073321434137637925,
        0.001719973674217141,
        0.0008885571675474929,
        0.0007358270448957148,
        0.0007796844489920071,
        0.0006559303011753682,
        0.00040290524940736174,
        9.667760857612425e-05,
        4.69212105555625e-05
    ],
    "test-error-mean": [
        0.055731600000000006,
        0.021188400000000003,
        0.009978599999999999,
        0.0144336,
        0.006294800000000001,
        0.0016885999999999997,
        0.001228,
        0.001228,
        0.001228,
        0.001228
    ],
    "test-error-std": [
        0.015888666194492227,
        0.0036532266614597024,
        0.004827953421482027,
        0.003517125508138713,
        0.0031231752688570006,
        0.0005741844999649501,
        0.0010409403441119958,
        0.0010409403441119958,
        0.0010409403441119958,
        0.0010409403441119958
    ],
    "test-rmse-mean": [
        0.2349068,
        0.1596758,
        0.113232,
        0.10586340000000001,
        0.08146379999999999,
        0.054961800000000005,
        0.04472659999999999,
        0.040870399999999994,
        0.0378632,
        0.0355032
    ],
    "test-rmse-std": [
        0.009347684363520205,
        0.013838324138420807,
        0.006189637210693368,
        0.008622184470306812,
        0.007872053592297248,
        0.009846015506792583,
        0.010218353460318349,
        0.0130493180756697,
        0.013843686595701305,
        0.013998153155327313
    ]
}
```
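The log notes that early stopping watches `test-rmse` with a patience of 3 rounds. The sketch below (a hypothetical `early_stop` helper for a lower-is-better metric, not xgboost's callback) applies that rule to the `test-rmse-mean` values from the cross-validation result. Because the value falls every round, the stop condition never fires and all 10 rounds are kept, which is consistent with the 10 entries in each list.

```python
def early_stop(history, rounds=3):
    """Return (best_iteration, stopped_at) for a lower-is-better metric.

    stopped_at is None when the metric never went `rounds` iterations
    without improving.
    """
    best_iter, best_value = 0, history[0]
    for i, value in enumerate(history[1:], start=1):
        if value < best_value:
            best_iter, best_value = i, value
        elif i - best_iter >= rounds:
            return best_iter, i
    return best_iter, None


# test-rmse-mean copied from the cross-validation result.
test_rmse_mean = [0.2349068, 0.1596758, 0.113232, 0.10586340000000001,
                  0.08146379999999999, 0.054961800000000005, 0.04472659999999999,
                  0.040870399999999994, 0.0378632, 0.0355032]
print(early_stop(test_rmse_mean))  # (9, None)
```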


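Judging by rmse (with error and auc alongside it), xgboost classifies this dataset quite well: over the 10 rounds the mean test rmse falls from 0.235 to 0.0355, the mean test error settles at 0.001228 (about 0.12%) and the mean test AUC reaches about 0.99996.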
