当前位置: 首页 > 工具软件 > tf-coreml > 使用案例 >

TF-IDF + K-Means 中文聚类例子 - scala

牛兴安
2023-12-01

本 Demo 仅供参考。

  • 基于 Spark 1.6(Scala 2.10)

import java.io.{BufferedReader, InputStreamReader}
import java.util.Arrays

import org.ansj.splitWord.analysis.ToAnalysis
import org.apache.hadoop.fs.FSDataInputStream
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.{SparseVector, Vectors}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable.ArrayBuffer
/**
  * TF-IDF + K-Means clustering of Chinese text (Spark 1.6 demo).
  *
  * Pipeline: read one document per line, segment it with Ansj keeping only selected word
  * natures, hash the tokens into term-frequency vectors, rescale them with IDF, cluster the
  * vectors with K-Means, and write `clusterId:tokens` for every input line.
  *
  * Created by Zsh on 1/22 0022.
  */
object tfid {

  /** Tokens matching this (`@name:` mention style) are dropped by [[tokenizer2]]. */
  private val MentionPattern = "@\\w{2,20}:".r

  /** Tokens matching this (plain `http://` URL) are dropped by [[tokenizer2]]. */
  private val UrlPattern = "http://[0-9a-zA-Z/\\?&#%$@\\=\\\\]+".r

  /**
    * First letter of the Ansj nature tags that are kept.
    * Presumably noun / verb / adjective / numeral / time word in the ICTCLAS tag set — confirm.
    */
  private val KeepNatures = Set("n", "v", "a", "m", "t")

  /**
    * Runs the clustering job.
    *
    * @param args optional: `args(0)` input text path (default `/wcc.txt`),
    *             `args(1)` output directory (default `/0123`)
    */
  def main(args: Array[String]): Unit = {
    val inputPath = if (args.length > 0) args(0) else "/wcc.txt"
    val outputPath = if (args.length > 1) args(1) else "/0123"

    // setIfMissing keeps "yarn-client" as the default but lets `spark-submit --master` win.
    val conf = new SparkConf()
      .setAppName("TF-IDF Clustering")
      .setIfMissing("spark.master", "yarn-client")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val numClusters = 10 // number of clusters (k)
    val numIterations = 30
    val runTimes = 3

    // One document per line; the label is a constant 0 because the input is unlabeled.
    val dataFrame = sc.textFile(inputPath)
      .map(line => (0, tokenizer2(line)))
      .persist(StorageLevel.MEMORY_AND_DISK)
      .toDF("label", "sentence")

    val hashingTF = new HashingTF()
      .setInputCol("sentence")
      .setOutputCol("rawFeatures")
      .setNumFeatures(100000)
    val featurizedData = hashingTF.transform(dataFrame)
    featurizedData.show(false)
    println(featurizedData.count())

    val idfModel = new IDF().setInputCol("rawFeatures").setOutputCol("features").fit(featurizedData)
    val rescaledData = idfModel.transform(featurizedData)
    println(rescaledData)
    rescaledData.select("features", "label").show(false)

    // (TF-IDF vector, tokens) pairs for every input line.
    val value = rescaledData.select("features", "label", "sentence")
      .map {
        case Row(features: org.apache.spark.mllib.linalg.Vector, _: Int, sentence) =>
          (features, sentence)
      }

    // K-Means makes many passes over its input, and the same vectors are reused for the cost.
    // Caching them avoids re-running the hashing/IDF transformation on every pass.
    val vectors = value.map(_._1).persist(StorageLevel.MEMORY_AND_DISK)

    // Train.
    val clusters: KMeansModel = KMeans.train(vectors, numClusters, numIterations, runTimes)

    println("Cluster Number:" + clusters.clusterCenters.length)
    println("Cluster Centers Information Overview:")
    clusters.clusterCenters.zipWithIndex.foreach { case (center, clusterIndex) =>
      println("聚类质心点向量:" + clusterIndex + ":")
      println(center)
    }

    // Within-set sum of squared errors for this clustering; lower is better.
    val kMeansCost = clusters.computeCost(vectors)
    println("K-Means Cost: " + kMeansCost)

    // Write "<predicted cluster>:<tokens>" for every input line.
    value.map { case (features, sentence) =>
      clusters.predict(features) + ":" + sentence.toString
    }.saveAsTextFile(outputPath)

    vectors.unpersist()
    sc.stop()
  }

  /**
    * Splits one line into content-word tokens.
    *
    * Keeps the terms selected by [[AnsjSegment]]'s rules, then drops mentions, URLs and
    * stop words.
    *
    * @param line raw input text
    * @return tokens in segmentation order; never contains empty strings
    */
  def tokenizer2(line: String): Seq[String] =
    segmentTerms(line)
      .filterNot(token => MentionPattern.pattern.matcher(token).matches)
      .filterNot(token => UrlPattern.pattern.matcher(token).matches)
      .filterNot(token => stopwordSet.contains(token))

  /**
    * Segments a line with Ansj and returns the kept terms joined with ",".
    *
    * A term is kept when its nature tag starts with one of [[KeepNatures]] and its text is at
    * least 2 characters long.
    *
    * @param line raw input text
    * @return comma-joined kept terms, or "" if none are kept
    */
  def AnsjSegment(line: String): String = segmentTerms(line).mkString(",")

  /**
    * Shared segmentation for [[tokenizer2]] and [[AnsjSegment]].
    *
    * Returns the kept terms as a list instead of a joined string. Callers therefore never split
    * on "," again, so no empty token is produced and a term containing "," is not broken apart.
    */
  private def segmentTerms(line: String): Seq[String] = {
    val terms = ToAnalysis.parse(line)
    val kept = ArrayBuffer[String]()
    for (i <- 0 until terms.size()) {
      val term = terms.get(i)
      val name = term.getName
      // take(1) instead of substring(0, 1): an empty or null nature must not throw.
      val natureKept = Option(term.getNatureStr).exists(nature => KeepNatures.contains(nature.take(1)))
      if (natureKept && name != null && name.length >= 2)
        kept += name
    }
    kept.toList
  }

  /**
    * Stop words used by [[tokenizer2]].
    *
    * Evaluated when this object is first initialized in a JVM. It is never null: it is empty if
    * loading failed.
    */
  var stopwordSet: Set[String] = getStopFile()

  /**
    * Loads the stop-word file named by `Constants.STOP_WORDS_CN` via `HDFSUtil`, one word per line.
    *
    * Lines are trimmed and blank lines are skipped. On success the set is also stored in
    * [[stopwordSet]]. Loading is best effort: on a non-fatal error the stack trace is printed and
    * the previously loaded set is returned, or an empty set if nothing has been loaded yet.
    *
    * @return the stop-word set (never null)
    */
  def getStopFile(): Set[String] = {
    var inputStream: FSDataInputStream = null
    var bufferedReader: BufferedReader = null
    try {
      val stopWordsCn = ConfigurationManager.getProperty(Constants.STOP_WORDS_CN)
      inputStream = HDFSUtil.getFSDataInputStream(stopWordsCn)
      // Decode explicitly: the platform default charset on cluster nodes may not be UTF-8.
      // Assumes the stop-word file is UTF-8 encoded.
      bufferedReader = new BufferedReader(new InputStreamReader(inputStream, "UTF-8"))
      val words = ArrayBuffer[String]()
      var line = bufferedReader.readLine()
      while (line != null) {
        val word = line.trim
        if (word.nonEmpty) words += word
        line = bufferedReader.readLine()
      }
      stopwordSet = words.toSet
    } catch {
      // Best effort: clustering can still run without stop words.
      case scala.util.control.NonFatal(e) => e.printStackTrace()
    } finally {
      if (bufferedReader != null) {
        bufferedReader.close()
      }
      if (inputStream != null) {
        HDFSUtil.close(inputStream)
      }
    }
    // While the object is still initializing, stopwordSet is null if loading failed above.
    Option(stopwordSet).getOrElse(Set.empty[String])
  }
}

以下为完整的 Maven 配置文件(pom.xml),请自行提取所需的依赖配置。

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.izhonghong</groupId>
    <artifactId>mission-center-new</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Java level; kept equal to the explicit source/target (1.8) of maven-compiler-plugin,
             which takes precedence over these properties. -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <encoding>UTF-8</encoding>
        <!-- Standard encoding property read by the compiler/resources plugins (the sources contain Chinese text). -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <scala.tools.version>2.10</scala.tools.version>
        <scala.version>2.10.6</scala.version>
        <hbase.version>1.2.2</hbase.version>
    </properties>

    <dependencies>
       <!-- <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.1.0</version>
        </dependency>-->
        <!--<dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>1.6.0</version>
        </dependency>-->
       <!-- <dependency>
            <groupId>com.hankcs</groupId>
            <artifactId>hanlp</artifactId>
            <version>portable-1.5.0</version>
        </dependency>-->
        <!-- All Spark artifacts share one version (1.6.2, same as spark-core/sql/streaming/hive below);
             mixing 1.6.0 and 1.6.2 Spark jars on one classpath can cause conflicts. -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.10</artifactId>
            <version>1.6.2</version>
        </dependency>

        <!-- Ansj Chinese word segmentation (ToAnalysis). -->
        <dependency>
            <groupId>org.ansj</groupId>
            <artifactId>ansj_seg</artifactId>
            <version>5.0.4</version>
        </dependency>

        <dependency>
            <groupId>org.apache.kafka</groupId>
            <artifactId>kafka-clients</artifactId>
            <version>0.10.0.0</version>
        </dependency>



        <dependency>
            <groupId>net.sf.json-lib</groupId>
            <classifier>jdk15</classifier>
            <artifactId>json-lib</artifactId>
            <version>2.4</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka_2.10</artifactId>
            <version>1.6.2</version>
        </dependency>

        <!-- <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming-kafka-0-10_2.10</artifactId>
            <version>2.1.1</version> </dependency> -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.10</artifactId>
            <version>1.6.2</version>
            <exclusions>
                <exclusion>
                    <artifactId>scala-library</artifactId>
                    <groupId>org.scala-lang</groupId>
                </exclusion>
            </exclusions>
        </dependency>

        <!-- <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming_2.10</artifactId>
            <version>2.1.1</version> <scope>provided</scope> </dependency> -->
        <dependency>
            <groupId>com.huaban</groupId>
            <artifactId>jieba-analysis</artifactId>
            <version>1.0.2</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.14</version>
        </dependency>


        <dependency>
            <groupId>redis.clients</groupId>
            <artifactId>jedis</artifactId>
            <version>2.9.0</version>
        </dependency>
        <!-- Scala runtime, declared once; the version comes from the scala.version property. -->
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-server</artifactId>
            <version>${hbase.version}</version>
            <exclusions>
                <exclusion>
                    <artifactId>servlet-api-2.5</artifactId>
                    <groupId>org.mortbay.jetty</groupId>
                </exclusion>
            </exclusions>
        </dependency>
      <!--  <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.18</version>
        </dependency>-->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.10</artifactId>
            <version>1.6.2</version>
            <!-- <version>2.1.1</version> -->
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-client</artifactId>
            <version>2.7.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>2.7.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-hdfs</artifactId>
            <version>2.7.0</version>
            <exclusions>
                <exclusion>
                    <groupId>javax.servlet.jsp</groupId>
                    <artifactId>*</artifactId>
                </exclusion>
                <exclusion>
                    <artifactId>servlet-api</artifactId>
                    <groupId>javax.servlet</groupId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.10</artifactId>
            <version>1.6.2</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.10</artifactId>
            <version>1.6.2</version>
        </dependency>

        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.39</version>
        </dependency>
        <!--<dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-server</artifactId>
            <version>1.2.2</version>
        </dependency>-->

        <!-- Test -->
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.11</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.specs2</groupId>
            <artifactId>specs2_${scala.tools.version}</artifactId>
            <version>1.13</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.scalatest</groupId>
            <artifactId>scalatest_${scala.tools.version}</artifactId>
            <version>2.0.M6-SNAP8</version>
            <scope>test</scope>
        </dependency>
    </dependencies>
    <!-- Build: compile the Scala sources, produce a jar whose manifest Class-Path points at lib/,
         and copy all dependency jars into lib/ next to it during package. -->
    <build>
        <plugins>
            <!-- Compiles the Scala main and test sources. -->
            <plugin>
                <groupId>net.alchim31.maven</groupId>
                <artifactId>scala-maven-plugin</artifactId>
                <version>3.2.0</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <!-- Adds a Class-Path manifest entry with the lib/ prefix, matching the directory
                 the maven-dependency-plugin execution below copies jars into. -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-jar-plugin</artifactId>
                <configuration>
                    <archive>
                        <manifest>
                            <addClasspath>true</addClasspath>
                            <classpathPrefix>lib/</classpathPrefix>
                            <!-- NOTE(review): mainClass is empty, so presumably no Main-Class entry is written;
                                 set it to the fully-qualified entry object if the jar should run via java -jar. -->
                            <mainClass></mainClass>
                        </manifest>
                    </archive>
                </configuration>
            </plugin>
            <!-- Java source/target level; these explicit values take precedence over the
                 maven.compiler.source/target properties. -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <!-- During the package phase, copies every dependency jar to ${project.build.directory}/lib. -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-dependency-plugin</artifactId>
                <executions>
                    <execution>
                        <id>copy</id>
                        <phase>package</phase>
                        <goals>
                            <goal>copy-dependencies</goal>
                        </goals>
                        <configuration>
                            <outputDirectory>${project.build.directory}/lib</outputDirectory>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
    <!-- <build> <plugins> <plugin> <artifactId>maven-assembly-plugin</artifactId>
        <configuration> <archive> <manifest> 这里要替换成jar包main方法所在类 <mainClass>com.sf.pps.client.IntfClientCall</mainClass>
        </manifest> <manifestEntries> <Class-Path>.</Class-Path> </manifestEntries>
        </archive> <descriptorRefs> <descriptorRef>jar-with-dependencies</descriptorRef>
        </descriptorRefs> </configuration> <executions> <execution> <id>make-assembly</id>
        this is used for inheritance merges <phase>package</phase> 指定在打包节点执行jar包合并操作
        <goals> <goal>single</goal> </goals> </execution> </executions> </plugin>
        </plugins> </build> -->

</project>

 

 类似资料: