当前位置: 首页 > 工具软件 > ml-workspace > 使用案例 >

Spark ML 特征工程之 One-Hot Encoding

司寇阳朔
2023-12-01

1.什么是One-Hot Encoding

One-Hot Encoding 也就是独热码,直观来说就是有多少个状态就有多少比特,而且只有一个比特为1,其他全为0的一种码制。在机器学习(Logistic Regression,SVM等)中对于离散型的分类型的数据,需要对其进行数字化比如说性别这一属性,只能有男性或者女性或者其他这三种值,如何对这三个值进行数字化表达?一种简单的方式就是男性为0,女性为1,其他为2,这样做有什么问题?
使用上面简单的序列对分类值进行表示后,进行模型训练时可能会产生一个问题就是特征的因为数字值得不同影响模型的训练效果,在模型训练的过程中不同的值使得同一特征在样本中的权重可能发生变化,假如直接编码成1000,是不是比编码成1对模型的的影响更大。为了解决上述的问题,使训练过程中不受到因为分类值表示的问题对模型产生的负面影响,引入独热码对分类型的特征进行独热码编码。

2.One-Hot Encoding在Spark中的应用

测试数据地址

2.1 数据集预览

数据中字段含义如下:
affairs:Double //是否有婚外情
gender:String //性别 
age:Double //年龄 
yearsmarried:Double //婚龄 
children:String //是否有小孩 
religiousness:Double //宗教信仰程度(5分制,1分表示反对,5分表示非常信仰)
education:Double //学历
occupation:Double //职业(逆向编号的戈登7种分类) 
rating:Double //对婚姻的自我评分(5分制,1表示非常不幸福,5表示非常幸福)

2.2 加载数据集

    val conf = new SparkConf().setMaster("local[4]").setAppName(getClass.getSimpleName).set("spark.testing.memory", "2147480000")
    val sparkContext = new SparkContext(conf)
    val sqlContext = new HiveContext(sparkContext)
    val colArray2 = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")
    val logPath = "E:\\spark_workspace\\spark-study\\src\\main\\files\\lr_test03.json"
    import sqlContext.implicits._

    val dataDF = sqlContext.read.json(logPath).select($"affairs", $"gender", $"age", $"yearsmarried", $"children", $"religiousness", $"education", $"occupation", $"rating")
    

2.3 使用OneHotEncoder处理数据集

    /**要进行OneHotEncoder编码的字段*/
    val categoricalColumns = Array("gender", "children")
    /**采用Pileline方式处理机器学习流程*/
    val stagesArray = new ListBuffer[PipelineStage]()
    for (cate <- categoricalColumns) {
      /**使用StringIndexer 建立类别索引*/
      val indexer = new StringIndexer().setInputCol(cate).setOutputCol(s"${cate}Index")
      /**使用OneHotEncoder将分类变量转换为二进制稀疏向量*/
      val encoder = new OneHotEncoder().setInputCol(indexer.getOutputCol).setOutputCol(s"${cate}classVec")
      stagesArray.append(indexer,encoder)
    }

2.4 使用VectorAssembler合并所有特征为单个向量

    val numericCols = Array("affairs", "age", "yearsmarried", "religiousness", "education", "occupation", "rating")
    val assemblerInputs = categoricalColumns.map(_ + "classVec") ++ numericCols
    /**使用VectorAssembler将所有特征转换为一个向量*/
    val assembler = new VectorAssembler().setInputCols(assemblerInputs).setOutputCol("features")
    stagesArray.append(assembler)

2.5 以Pipeline的形式运行各个PipelineStage

    val pipeline = new Pipeline()
    pipeline.setStages(stagesArray.toArray)
    /**fit() 根据需要计算特征统计信息*/
    val pipelineModel = pipeline.fit(dataDF)
    /**transform() 真实转换特征*/
    val dataset = pipelineModel.transform(dataDF)
    dataset.show(false)

One-Hot Encoding 之后的数据集结果如下图:

+-------+------+----+------------+--------+-------------+---------+----------+------+-----------+--------------+-------------+----------------+----------------------------------------+
|affairs|gender|age |yearsmarried|children|religiousness|education|occupation|rating|genderIndex|genderclassVec|childrenIndex|childrenclassVec|features                                |
+-------+------+----+------------+--------+-------------+---------+----------+------+-----------+--------------+-------------+----------------+----------------------------------------+
|0.0    |male  |37.0|10.0        |no      |3.0          |18.0     |7.0       |4.0   |1.0        |(1,[],[])     |1.0          |(1,[],[])       |[0.0,0.0,0.0,37.0,10.0,3.0,18.0,7.0,4.0]|
|0.0    |female|27.0|4.0         |no      |4.0          |14.0     |6.0       |4.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,27.0,4.0,4.0,14.0,6.0,4.0] |
|0.0    |female|32.0|15.0        |yes     |1.0          |12.0     |1.0       |4.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,32.0,15.0,1.0,12.0,1.0,4.0]|
|0.0    |male  |57.0|15.0        |yes     |5.0          |18.0     |6.0       |5.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,57.0,15.0,5.0,18.0,6.0,5.0]|
|0.0    |male  |22.0|0.75        |no      |2.0          |17.0     |6.0       |3.0   |1.0        |(1,[],[])     |1.0          |(1,[],[])       |[0.0,0.0,0.0,22.0,0.75,2.0,17.0,6.0,3.0]|
|0.0    |female|32.0|1.5         |no      |2.0          |17.0     |5.0       |5.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,32.0,1.5,2.0,17.0,5.0,5.0] |
|0.0    |female|22.0|0.75        |no      |2.0          |12.0     |1.0       |3.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,22.0,0.75,2.0,12.0,1.0,3.0]|
|0.0    |male  |57.0|15.0        |yes     |2.0          |14.0     |4.0       |4.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,57.0,15.0,2.0,14.0,4.0,4.0]|
|0.0    |female|32.0|15.0        |yes     |4.0          |16.0     |1.0       |2.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,32.0,15.0,4.0,16.0,1.0,2.0]|
|0.0    |male  |22.0|1.5         |no      |4.0          |14.0     |4.0       |5.0   |1.0        |(1,[],[])     |1.0          |(1,[],[])       |[0.0,0.0,0.0,22.0,1.5,4.0,14.0,4.0,5.0] |
|0.0    |male  |37.0|15.0        |yes     |2.0          |20.0     |7.0       |2.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,37.0,15.0,2.0,20.0,7.0,2.0]|
|0.0    |male  |27.0|4.0         |yes     |4.0          |18.0     |6.0       |4.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,27.0,4.0,4.0,18.0,6.0,4.0] |
|0.0    |male  |47.0|15.0        |yes     |5.0          |17.0     |6.0       |4.0   |1.0        |(1,[],[])     |0.0          |(1,[0],[1.0])   |[0.0,1.0,0.0,47.0,15.0,5.0,17.0,6.0,4.0]|
|0.0    |female|22.0|1.5         |no      |2.0          |17.0     |5.0       |4.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,22.0,1.5,2.0,17.0,5.0,4.0] |
|0.0    |female|27.0|4.0         |no      |4.0          |14.0     |5.0       |4.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,27.0,4.0,4.0,14.0,5.0,4.0] |
|0.0    |female|37.0|15.0        |yes     |1.0          |17.0     |5.0       |5.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,37.0,15.0,1.0,17.0,5.0,5.0]|
|0.0    |female|37.0|15.0        |yes     |2.0          |18.0     |4.0       |3.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,37.0,15.0,2.0,18.0,4.0,3.0]|
|0.0    |female|22.0|0.75        |no      |3.0          |16.0     |5.0       |4.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,22.0,0.75,3.0,16.0,5.0,4.0]|
|0.0    |female|22.0|1.5         |no      |2.0          |16.0     |5.0       |5.0   |0.0        |(1,[0],[1.0]) |1.0          |(1,[],[])       |[1.0,0.0,0.0,22.0,1.5,2.0,16.0,5.0,5.0] |
|0.0    |female|27.0|10.0        |yes     |2.0          |14.0     |1.0       |5.0   |0.0        |(1,[0],[1.0]) |0.0          |(1,[0],[1.0])   |[1.0,1.0,0.0,27.0,10.0,2.0,14.0,1.0,5.0]|
+-------+------+----+------------+--------+-------------+---------+----------+------+-----------+--------------+-------------+----------------+----------------------------------------+

2.6 训练和评估模型

    /**随机分割测试集和训练集数据,指定seed可以固定数据分配*/
    val Array(trainingDF, testDF) = dataset.randomSplit(Array(0.6, 0.4), seed = 12345)
    println(s"trainingDF size=${trainingDF.count()},testDF size=${testDF.count()}")
    val lrModel = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features").fit(trainingDF)
    val predictions = lrModel.transform(testDF).select($"affairs".as("label"), $"features", $"rawPrediction", $"probability", $"prediction")
    predictions.show(false)
    /**使用BinaryClassificationEvaluator来评价我们的模型。在metricName参数中设置度量。*/
    val evaluator = new BinaryClassificationEvaluator()
    evaluator.setMetricName("areaUnderROC")
    val auc= evaluator.evaluate(predictions)
    println(s"areaUnderROC=$auc")

使用model 预测后的数据如下图所示:

+-----+-----------------------------------------+----------------------------------------+-------------------------------------------+----------+
|label|features                                 |rawPrediction                           |probability                                |prediction|
+-----+-----------------------------------------+----------------------------------------+-------------------------------------------+----------+
|0.0  |[1.0,0.0,0.0,22.0,0.125,4.0,14.0,4.0,5.0]|[24.24907721362884,-24.24907721362884]  |[0.999999999970572,2.942792055040055E-11]  |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.417,1.0,17.0,6.0,4.0]|[21.290119589459323,-21.290119589459323]|[0.9999999994326925,5.673075233382041E-10] |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.417,5.0,14.0,1.0,4.0]|[24.17979109657276,-24.17979109657276]  |[0.9999999999684608,3.1539162239002745E-11]|0.0       |
|0.0  |[1.0,1.0,0.0,22.0,0.417,3.0,14.0,3.0,5.0]|[22.67775610810491,-22.67775610810491]  |[0.9999999998583633,1.4163665456478983E-10]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.75,2.0,12.0,1.0,3.0] |[18.511403509878832,-18.511403509878832]|[0.9999999908672915,9.13270857267764E-9]   |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.75,4.0,16.0,1.0,5.0] |[25.35929557565844,-25.35929557565844]  |[0.999999999990304,9.69611742832185E-12]   |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.75,5.0,14.0,3.0,5.0] |[25.260012900022847,-25.260012900022847]|[0.9999999999892919,1.070818300382037E-11] |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,0.75,5.0,18.0,1.0,5.0] |[27.56176640273893,-27.56176640273893]  |[0.9999999999989282,1.0717091528412073E-12]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,1.5,2.0,14.0,4.0,5.0]  |[21.806773356131036,-21.806773356131036]|[0.9999999996615936,3.3840647423836113E-10]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,1.5,2.0,16.0,5.0,5.0]  |[22.87962909201085,-22.87962909201085]  |[0.9999999998842548,1.1574529263994485E-10]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,1.5,2.0,16.0,5.0,5.0]  |[22.87962909201085,-22.87962909201085]  |[0.9999999998842548,1.1574529263994485E-10]|0.0       |
|0.0  |[1.0,0.0,0.0,22.0,1.5,4.0,16.0,5.0,3.0]  |[22.617887847315348,-22.617887847315348]|[0.9999999998496247,1.5037516453560028E-10]|0.0       |
|0.0  |[1.0,1.0,0.0,22.0,1.5,3.0,16.0,5.0,5.0]  |[23.505953663596607,-23.505953663596607]|[0.9999999999381279,6.187198251529256E-11] |0.0       |
|0.0  |[1.0,0.0,0.0,22.0,4.0,4.0,17.0,5.0,5.0]  |[25.142053761516753,-25.142053761516753]|[0.9999999999879512,1.2048827525325212E-11]|0.0       |
|0.0  |[1.0,0.0,0.0,27.0,1.5,2.0,16.0,6.0,5.0]  |[23.342953469838886,-23.342953469838886]|[0.9999999999271745,7.282560759398736E-11] |0.0       |
|0.0  |[1.0,0.0,0.0,27.0,1.5,2.0,18.0,6.0,5.0]  |[24.454819713457812,-24.454819713457812]|[0.9999999999760445,2.3955582882827004E-11]|0.0       |
|0.0  |[1.0,0.0,0.0,27.0,1.5,3.0,18.0,5.0,2.0]  |[21.920009187230548,-21.920009187230548]|[0.9999999996978233,3.021766947986581E-10] |0.0       |
|0.0  |[1.0,0.0,0.0,27.0,4.0,2.0,18.0,5.0,5.0]  |[24.01911260197023,-24.01911260197023]  |[0.9999999999629634,3.703667040712842E-11] |0.0       |
|0.0  |[1.0,0.0,0.0,27.0,4.0,3.0,16.0,5.0,4.0]  |[22.776375736003562,-22.776375736003562]|[0.9999999998716649,1.2833517289922962E-10]|0.0       |
|0.0  |[1.0,1.0,0.0,27.0,4.0,2.0,18.0,6.0,1.0]  |[18.629921259118063,-18.629921259118063]|[0.999999991887999,8.112000996701378E-9]   |0.0       |
+-----+-----------------------------------------+----------------------------------------+-------------------------------------------+----------+
 类似资料: