分类法 Classification - EX 2: Normal and Shrinkage Linear Discriminant Analysis for classification

优质
小牛编辑
141浏览
2023-12-01

分类法/范例二: Normal and Shrinkage Linear Discriminant Analysis for classification

http://scikit-learn.org/stable/auto_examples/classification/plot_lda.html

这个范例用来展示scikit-learn 如何使用Linear Discriminant Analysis (LDA) 线性判别分析来达成资料分类的目的

  1. 利用 sklearn.datasets.make_blobs 产生测试资料
  2. 利用自定义函数 generate_data 产生具有数个特征之资料集,其中仅有一个特征对于资料分料判断有意义
  3. 使用LinearDiscriminantAnalysis来达成资料判别
  4. 比较于LDA演算法中,开启 shrinkage 前后之差异

(一)产生测试资料

从程式码来看,一开始主要为自定义函数generate_data(n_samples, n_features),这个函数的主要目的为产生一组测试资料,总资料列数为n_samples,每一列共有n_features个特征。而其中只有第一个特征得以用来判定资料类别,其他特征则毫无意义。make_blobs负责产生单一特征之资料后,利用`np.random.randn` 乱数产生其他`n_features - 1`个特征,之后利用np.hstack以”水平” (horizontal)方式连接X以及乱数产生之特征资料。

  1. %matplotlib inline
  2. from __future__ import division
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from sklearn.datasets import make_blobs
  6. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
  7. n_train = 20 # samples for training
  8. n_test = 200 # samples for testing
  9. n_averages = 50 # how often to repeat classification
  10. n_features_max = 75 # maximum number of features
  11. step = 4 # step size for the calculation
  12. def generate_data(n_samples, n_features):
  13. X, y = make_blobs(n_samples=n_samples, n_features=1, centers=[[-2], [2]])
  14. # add non-discriminative features
  15. if n_features > 1:
  16. X = np.hstack([X, np.random.randn(n_samples, n_features - 1)])
  17. return X, y

我们可以用以下的程式码来测试自定义函式,结果回传了X (10x5矩阵)及y(10个元素之向量),我们可以使用pandas.DataFrame套件来观察资料

  1. X, y = generate_data(10, 5)
  2. import pandas as pd
  3. pd.set_option('precision',2)
  4. df=pd.DataFrame(np.hstack([y.reshape(10,1),X]))
  5. df.columns = ['y', 'X0', 'X1', 'X2', 'X2', 'X4']
  6. print(df)

结果显示如下。。我们可以看到只有X的第一行特征资料(X0) 与目标数值 y 有一个明确的对应关係,也就是y为1时,数值较大。

  1. y X0 X1 X2 X2 X4
  2. 0 1 0.38 0.35 0.80 -0.97 -0.68
  3. 1 1 2.41 0.31 -1.47 0.10 -1.39
  4. 2 1 1.65 -0.99 -0.12 -0.38 0.18
  5. 3 0 -4.86 0.14 -0.80 1.13 -1.31
  6. 4 1 -0.06 -1.99 -0.70 -1.26 -1.64
  7. 5 0 -1.51 -1.74 -0.83 0.74 -2.07
  8. 6 0 -2.50 0.44 -0.45 -0.55 -0.42
  9. 7 1 1.55 1.38 0.93 -1.44 0.27
  10. 8 0 -1.95 0.32 -0.28 0.02 0.07
  11. 9 0 -0.58 -0.07 -1.01 0.15 -1.84

(二)改变特征数量并测试shrinkage之功能

接下来程式码裏有两段迴圈,外圈改变特征数量。内圈则多次尝试LDA之以求精准度。使用LinearDiscriminantAnalysis来训练分类器,过程中以shrinkage='auto'以及shrinkage=None来控制shrinkage之开关,将分类器分别以clf1以及clf2储存。之后再产生新的测试资料将准确度加入score_clf1score_clf2裏,离开内迴圈之后除以总数以求平均。

  1. acc_clf1, acc_clf2 = [], []
  2. n_features_range = range(1, n_features_max + 1, step)
  3. for n_features in n_features_range:
  4. score_clf1, score_clf2 = 0, 0
  5. for _ in range(n_averages):
  6. X, y = generate_data(n_train, n_features)
  7. clf1 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto').fit(X, y)
  8. clf2 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage=None).fit(X, y)
  9. X, y = generate_data(n_test, n_features)
  10. score_clf1 += clf1.score(X, y)
  11. score_clf2 += clf2.score(X, y)
  12. acc_clf1.append(score_clf1 / n_averages)
  13. acc_clf2.append(score_clf2 / n_averages)

(三)显示LDA判别结果

这个范例主要希望能得知shrinkage的功能,因此画出两条分类准确度的曲线。纵轴代表平均的分类准确度,而横轴代表的是features_samples_ratio 顾名思义,它是模拟资料中,特征数量与训练资料列数的比例。当特征数量为75且训练资料列数仅有20笔时,features_samples_ratio = 3.75 由于资料列数过少,导致准确率下降。而此时shrinkage演算法能有效维持LDA演算法的准确度。

  1. features_samples_ratio = np.array(n_features_range) / n_train
  2. fig = plt.figure(figsize=(10,6), dpi=300)
  3. plt.plot(features_samples_ratio, acc_clf1, linewidth=2,
  4. label="Linear Discriminant Analysis with shrinkage", color='r')
  5. plt.plot(features_samples_ratio, acc_clf2, linewidth=2,
  6. label="Linear Discriminant Analysis", color='g')
  7. plt.xlabel('n_features / n_samples')
  8. plt.ylabel('Classification accuracy')
  9. plt.legend(loc=1, prop={'size': 10})
  10. plt.show()

png

(四)完整程式码

Python source code: plot_lda.py

  1. from __future__ import division
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from sklearn.datasets import make_blobs
  5. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
  6. n_train = 20 # samples for training
  7. n_test = 200 # samples for testing
  8. n_averages = 50 # how often to repeat classification
  9. n_features_max = 75 # maximum number of features
  10. step = 4 # step size for the calculation
  11. def generate_data(n_samples, n_features):
  12. """Generate random blob-ish data with noisy features.
  13. This returns an array of input data with shape `(n_samples, n_features)`
  14. and an array of `n_samples` target labels.
  15. Only one feature contains discriminative information, the other features
  16. contain only noise.
  17. """
  18. X, y = make_blobs(n_samples=n_samples, n_features=1, centers=[[-2], [2]])
  19. # add non-discriminative features
  20. if n_features > 1:
  21. X = np.hstack([X, np.random.randn(n_samples, n_features - 1)])
  22. return X, y
  23. acc_clf1, acc_clf2 = [], []
  24. n_features_range = range(1, n_features_max + 1, step)
  25. for n_features in n_features_range:
  26. score_clf1, score_clf2 = 0, 0
  27. for _ in range(n_averages):
  28. X, y = generate_data(n_train, n_features)
  29. clf1 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto').fit(X, y)
  30. clf2 = LinearDiscriminantAnalysis(solver='lsqr', shrinkage=None).fit(X, y)
  31. X, y = generate_data(n_test, n_features)
  32. score_clf1 += clf1.score(X, y)
  33. score_clf2 += clf2.score(X, y)
  34. acc_clf1.append(score_clf1 / n_averages)
  35. acc_clf2.append(score_clf2 / n_averages)
  36. features_samples_ratio = np.array(n_features_range) / n_train
  37. plt.plot(features_samples_ratio, acc_clf1, linewidth=2,
  38. label="Linear Discriminant Analysis with shrinkage", color='r')
  39. plt.plot(features_samples_ratio, acc_clf2, linewidth=2,
  40. label="Linear Discriminant Analysis", color='g')
  41. plt.xlabel('n_features / n_samples')
  42. plt.ylabel('Classification accuracy')
  43. plt.legend(loc=1, prop={'size': 12})
  44. plt.suptitle('Linear Discriminant Analysis vs. \
  45. shrinkage Linear Discriminant Analysis (1 discriminative feature)')
  46. plt.show()