Neural Networks - Ex 2: Restricted Boltzmann Machine features for digit classification

小牛编辑
2023-12-01

http://scikit-learn.org/stable/auto_examples/neural_networks/plot_rbm_logistic_classification.html

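This example uses BernoulliRBM feature extraction to improve the accuracy of handwritten digit recognition. The Bernoulli Restricted Boltzmann Machine model (`BernoulliRBM`) can perform effective non-linear feature extraction on the data.
To make the trained model more robust, each input image is shifted by one pixel up, left, right, and down, which generates additional training data.
The hyperparameters of the network were chosen with a grid search, but that search takes too long to run, so it is not reproduced here.
The example compares two approaches:
1. Logistic regression on the raw pixel values
2. Logistic regression on features extracted by BernoulliRBM
The results show that the BernoulliRBM features improve classification accuracy.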

(1) Importing libraries and data

    from __future__ import print_function
    print(__doc__)
    # Authors: Yann N. Dauphin, Vlad Niculae, Gabriel Synnaeve
    # License: BSD
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.ndimage import convolve
    from sklearn import linear_model, datasets, metrics
    from sklearn.model_selection import train_test_split
    from sklearn.neural_network import BernoulliRBM
    from sklearn.pipeline import Pipeline

(2) Preprocessing, loading the data, and choosing the models

    def nudge_dataset(X, Y):
        """
        Shift each 8x8 digit image in X by one pixel up, left, right, and down.
        The shifted copies are appended to the originals, which produces more
        training data so that the trained model is more robust.
        """
        direction_vectors = [
            [[0, 1, 0],
             [0, 0, 0],
             [0, 0, 0]],
            [[0, 0, 0],
             [1, 0, 0],
             [0, 0, 0]],
            [[0, 0, 0],
             [0, 0, 1],
             [0, 0, 0]],
            [[0, 0, 0],
             [0, 0, 0],
             [0, 1, 0]]]
        shift = lambda x, w: convolve(x.reshape((8, 8)), mode='constant',
                                      weights=w).ravel()
        X = np.concatenate([X] +
                           [np.apply_along_axis(shift, 1, X, vector)
                            for vector in direction_vectors])
        Y = np.concatenate([Y for _ in range(5)], axis=0)
        return X, Y

    # Load Data
    digits = datasets.load_digits()
    X = np.asarray(digits.data, 'float32')
    X, Y = nudge_dataset(X, digits.target)
    X = (X - np.min(X, 0)) / (np.max(X, 0) + 0.0001)  # scale the grayscale values to [0, 1]
    # Split the data into a training set and a test set
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y,
                                                        test_size=0.2,
                                                        random_state=0)
    # Models we will use
    logistic = linear_model.LogisticRegression()
    rbm = BernoulliRBM(random_state=0, verbose=True)
    classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])
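To see what one of the 3x3 kernels in `nudge_dataset` does, the following minimal sketch applies the first kernel to a toy 8x8 image with a single lit pixel. The toy image and the names `image`, `shift_up_kernel`, and `shifted` are illustrative assumptions and are not part of the original example.

```python
import numpy as np
from scipy.ndimage import convolve

# Toy 8x8 image with one lit pixel at row 3, column 3 (illustrative data).
image = np.zeros((8, 8), dtype="float32")
image[3, 3] = 1.0

# First kernel from direction_vectors: the 1 sits above the kernel centre.
shift_up_kernel = [[0, 1, 0],
                   [0, 0, 0],
                   [0, 0, 0]]

# mode='constant' pads with zeros, so pixels pushed past the border vanish.
shifted = convolve(image, weights=shift_up_kernel, mode="constant")

row, col = np.argwhere(shifted == 1.0)[0]
print("lit pixel moved from (3, 3) to (%d, %d)" % (row, col))
# prints: lit pixel moved from (3, 3) to (2, 3)
```

Because `convolve` flips the kernel before applying it, a 1 placed above the centre moves the image content up by one row. The other three kernels behave the same way for left, right, and down.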

(3) Setting the model parameters and training the models

    # Hyperparameters should be chosen by comparing candidates with cross-validation.
    # These values were found with GridSearchCV. Here we are not performing cross-validation to save time.
    # A grid search fits the model for every combination in a predefined
    # parameter grid and keeps the combination with the best score.
    rbm.learning_rate = 0.06
    rbm.n_iter = 20
    # n_components = 100 means 100 hidden units, i.e. 100 extracted features.
    # More components tend to give better accuracy, but fitting takes longer.
    rbm.n_components = 100
    logistic.C = 6000.0
    # Training RBM-Logistic Pipeline
    classifier.fit(X_train, Y_train)
    # Training Logistic regression
    logistic_classifier = linear_model.LogisticRegression(C=100.0)
    logistic_classifier.fit(X_train, Y_train)
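The grid search itself is not shown in this example. As a rough sketch of how such a search could be set up, the block below runs `GridSearchCV` over an RBM plus logistic regression pipeline. The 300-sample subset, the small grid values, `cv=2`, `n_iter=5`, and `max_iter=1000` are all assumptions chosen so the sketch runs quickly; they are not the settings used to obtain the values above.

```python
from sklearn import datasets, linear_model
from sklearn.model_selection import GridSearchCV
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline

# Small subset of the digits data so the search finishes quickly (assumption).
digits = datasets.load_digits()
X_small = digits.data[:300] / 16.0   # digits pixel values range from 0 to 16
y_small = digits.target[:300]

pipeline = Pipeline(steps=[
    ("rbm", BernoulliRBM(n_iter=5, random_state=0)),
    ("logistic", linear_model.LogisticRegression(max_iter=1000)),
])

# Hypothetical grid; "<step>__<param>" addresses a parameter of a pipeline step.
param_grid = {
    "rbm__learning_rate": [0.01, 0.06],
    "rbm__n_components": [16, 32],
    "logistic__C": [1.0, 100.0],
}

# Every combination is fitted and scored with 2-fold cross-validation.
search = GridSearchCV(pipeline, param_grid, cv=2)
search.fit(X_small, y_small)

print("best parameters:", search.best_params_)
print("best cross-validated accuracy: %.3f" % search.best_score_)
```

A realistic search would use the full augmented training set, a wider grid, and more folds, which is why it is expensive.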

(4) Evaluating classification accuracy

    print()
    print("Logistic regression using RBM features:\n%s\n" % (
        metrics.classification_report(
            Y_test,
            classifier.predict(X_test))))
    print("Logistic regression using raw pixel features:\n%s\n" % (
        metrics.classification_report(
            Y_test,
            logistic_classifier.predict(X_test))))

Figure 1: With the RBM features, the accuracy is 0.95.

Figure 2: Logistic regression on the raw pixels, without any feature extraction, reaches an accuracy of 0.77.
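`classification_report` prints per-class precision, recall, and F1. When a single accuracy figure like the ones quoted in the captions is wanted, `metrics.accuracy_score` can be used. The sketch below runs on made-up toy labels (an assumption for illustration only) and does not reproduce the 0.95 or 0.77 results.

```python
from sklearn import metrics

# Toy labels standing in for Y_test and a classifier's predictions (made up).
y_true = [0, 1, 2, 2, 1, 0, 2, 1]
y_pred = [0, 1, 2, 1, 1, 0, 2, 2]

# Fraction of predictions that match the true labels: 6 of 8 here.
accuracy = metrics.accuracy_score(y_true, y_pred)
print("accuracy: %.2f" % accuracy)

# The same predictions broken down per class.
print(metrics.classification_report(y_true, y_pred))
```

In the example itself, the same call would take `Y_test` and the output of `classifier.predict(X_test)` or `logistic_classifier.predict(X_test)`.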

(5) Plotting the 100 features extracted by the RBM

    plt.figure(figsize=(4.2, 4))
    for i, comp in enumerate(rbm.components_):
        plt.subplot(10, 10, i + 1)
        plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,
                   interpolation='nearest')
        plt.xticks(())
        plt.yticks(())
    plt.suptitle('100 components extracted by RBM', fontsize=16)
    plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
    plt.show()

Figure 3: The features found by the RBM.
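Each small image in the plot is one row of `rbm.components_` reshaped to 8x8. The sketch below, which assumes a much smaller RBM (16 components, 5 iterations) fitted on the first 200 digits purely for speed, shows how the shapes relate: `components_` has one row per hidden unit, and `transform` maps each image to one activation probability per hidden unit.

```python
import numpy as np
from sklearn import datasets
from sklearn.neural_network import BernoulliRBM

# First 200 digit images, scaled to [0, 1] (pixel values range from 0 to 16).
digits = datasets.load_digits()
X_demo = np.asarray(digits.data[:200], "float32") / 16.0

# Deliberately small RBM so the sketch runs quickly (assumed settings).
small_rbm = BernoulliRBM(n_components=16, learning_rate=0.06,
                         n_iter=5, random_state=0)
hidden = small_rbm.fit_transform(X_demo)

# One weight vector of length 64 (an 8x8 image) per hidden unit.
print("components_ shape:", small_rbm.components_.shape)
# One hidden-unit activation probability per component for every image.
print("transformed shape:", hidden.shape)
```

These activation probabilities are what the logistic regression step of the pipeline receives in place of the raw pixels.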

(6) Complete code

    from __future__ import print_function
    print(__doc__)
    # Authors: Yann N. Dauphin, Vlad Niculae, Gabriel Synnaeve
    # License: BSD
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.ndimage import convolve
    from sklearn import linear_model, datasets, metrics
    from sklearn.model_selection import train_test_split
    from sklearn.neural_network import BernoulliRBM
    from sklearn.pipeline import Pipeline

    ###############################################################################
    # Setting up
    def nudge_dataset(X, Y):
        """
        This produces a dataset 5 times bigger than the original one,
        by moving the 8x8 images in X around by 1px to left, right, down, up
        """
        direction_vectors = [
            [[0, 1, 0],
             [0, 0, 0],
             [0, 0, 0]],
            [[0, 0, 0],
             [1, 0, 0],
             [0, 0, 0]],
            [[0, 0, 0],
             [0, 0, 1],
             [0, 0, 0]],
            [[0, 0, 0],
             [0, 0, 0],
             [0, 1, 0]]]
        shift = lambda x, w: convolve(x.reshape((8, 8)), mode='constant',
                                      weights=w).ravel()
        X = np.concatenate([X] +
                           [np.apply_along_axis(shift, 1, X, vector)
                            for vector in direction_vectors])
        Y = np.concatenate([Y for _ in range(5)], axis=0)
        return X, Y

    # Load Data
    digits = datasets.load_digits()
    X = np.asarray(digits.data, 'float32')
    X, Y = nudge_dataset(X, digits.target)
    X = (X - np.min(X, 0)) / (np.max(X, 0) + 0.0001)  # 0-1 scaling
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y,
                                                        test_size=0.2,
                                                        random_state=0)
    # Models we will use
    logistic = linear_model.LogisticRegression()
    rbm = BernoulliRBM(random_state=0, verbose=True)
    classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])

    ###############################################################################
    # Training
    # Hyper-parameters. These were set by cross-validation,
    # using a GridSearchCV. Here we are not performing cross-validation to
    # save time.
    rbm.learning_rate = 0.06
    rbm.n_iter = 20
    # More components tend to give better prediction performance, but larger
    # fitting time
    rbm.n_components = 100
    logistic.C = 6000.0
    # Training RBM-Logistic Pipeline
    classifier.fit(X_train, Y_train)
    # Training Logistic regression
    logistic_classifier = linear_model.LogisticRegression(C=100.0)
    logistic_classifier.fit(X_train, Y_train)

    ###############################################################################
    # Evaluation
    print()
    print("Logistic regression using RBM features:\n%s\n" % (
        metrics.classification_report(
            Y_test,
            classifier.predict(X_test))))
    print("Logistic regression using raw pixel features:\n%s\n" % (
        metrics.classification_report(
            Y_test,
            logistic_classifier.predict(X_test))))

    ###############################################################################
    # Plotting
    plt.figure(figsize=(4.2, 4))
    for i, comp in enumerate(rbm.components_):
        plt.subplot(10, 10, i + 1)
        plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,
                   interpolation='nearest')
        plt.xticks(())
        plt.yticks(())
    plt.suptitle('100 components extracted by RBM', fontsize=16)
    plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
    plt.show()