SparkSession.scala

华宪
2023-12-01

Spark源码之SparkSession,Spark版本号2.2.0

//SparkSession 源码
/**使用数据集和数据框架API编程Spark的入口点。
*在预先创建的环境中(例如命令行、笔记本电脑),使用生成器获取现有会话:
* SparkSession.builder().getOrCreate()
*构建器也可以用来创建一个新的会话:
*   SparkSession.builder
*     .master("local")
*     .appName("Word Count")
*     .config("spark.some.config.option", "some-value")
*     .getOrCreate()
*
*/


@InterfaceStability.Stable
class SparkSession private(
	//与此Spark会话关联的Spark上下文。
    @transient val sparkContext: SparkContext,
    //如果提供,请使用现有的共享状态而不是创建一个新的
    @transient private val existingSharedState: Option[SharedState],
    //如果提供,则从父级继承所有会话状态(即临时视图、SQL配置、UDF等)。
    @transient private val parentSessionState: Option[SessionState],
    /*扩展点:
    * -分析器规则。
	* -检查分析规则
	* -优化器规则。
	* -规划策略。
	* -自定义解析器。
	* -(外部)目录监听器。
	*/
    @transient private[sql] val extensions: SparkSessionExtensions)
  extends Serializable with Closeable with Logging { self =>

//构造器
  private[sql] def this(sc: SparkContext) {
    this(sc, None, None, new SparkSessionExtensions)
  }
//沿着堆栈追踪,直到发现第一个spark方法,还跟踪第一个(最深的)用户方法、文件和行。
  sparkContext.assertNotStopped()


  def version: String = SPARK_VERSION

  /* ----------------------- *
   |  与会话相关的状态  |
   * ----------------------- */

//会话间共享的状态,包括“SparkContext”、缓存数据、侦听器和与外部系统交互的目录。
  @InterfaceStability.Unstable
  @transient
  lazy val sharedState: SharedState = {
    //如果存在共享状态就用,不存在使用括号内新建的
    existingSharedState.getOrElse(new SharedState(sparkContext))
  }

 
 //如果parentSessionState不为空,sessionState将是parentSessionState的拷贝
  @InterfaceStability.Unstable
  @transient
  lazy val sessionState: SessionState = {
    parentSessionState
      .map(_.clone(this))
      .getOrElse {
        SparkSession.instantiateSessionState(
          SparkSession.sessionStateClassName(sparkContext.conf),
          self)
      }
  }

 //包装SQLContext,以便向后兼容
  @transient
  val sqlContext: SQLContext = new SQLContext(this)

  //Spark运行时配置接口
  @transient lazy val conf: RuntimeConfig = new RuntimeConfig(sessionState.conf)

  //监听执行指标
  @Experimental
  @InterfaceStability.Evolving
  def listenerManager: ExecutionListenerManager = sessionState.listenerManager


  //一组方法,这些方法被认为是实验性的,但可用于连接到查询计划器以获得高级功能。
  @Experimental
  @InterfaceStability.Unstable
  def experimental: ExperimentalMethods = sessionState.experimentalMethods

  //用于注册用户定义函数(UDF)的方法集合。
  def udf: UDFRegistration = sessionState.udfRegistration

  //返回允许管理“this”上所有活动的“StreamingQuery”的“StreamingQueryManager”。
  @Experimental
  @InterfaceStability.Unstable
  def streams: StreamingQueryManager = sessionState.streamingQueryManager

  //使用隔离的SQL配置启动新会话,临时表、已注册的函数是隔离的,但共享底层的“SparkContext”和缓存数据。
  def newSession(): SparkSession = {
    new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions)
  }

 //创建此“SparkSession”的相同副本,共享基础“SparkContext”和共享状态。复制此会话的所有状态(即SQL配置、临时表、已注册函数),并使用与此会话相同的共享状态设置克隆的会话。克隆的会话独立于此会话,也就是说,任一会话中的任何非全局更改都不会反映在另一个会话中
 
  private[sql] def cloneSession(): SparkSession = {
    val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions)
    result.sessionState // 强制复制sessiState,因为sessionState是lazy所以一调用就激活了
    result
  }


  /* ---------------------- *
   |  创建 DataFrames 的方法  |
   * --------------------- */

 //返回一个没有行或列的“数据帧”。
  @transient
  lazy val emptyDataFrame: DataFrame = {
    createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil))
  }

  //创建一个类型为T的新[[Dataset]],其中包含0个元素。
  @Experimental
  @InterfaceStability.Evolving
  def emptyDataset[T: Encoder]: Dataset[T] = {
    val encoder = implicitly[Encoder[T]]
    new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder)
  }

  //从产品的RDD(例如case类,元组)中创建一个DataFrame。
  @Experimental
  @InterfaceStability.Evolving
  def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
    SparkSession.setActiveSession(this)
    val encoder = Encoders.product[A]
    Dataset.ofRows(self, ExternalRDD(rdd, self)(encoder))
  }

 //从产品的本地序列创建一个“数据帧”。
  @Experimental
  @InterfaceStability.Evolving
  def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
    SparkSession.setActiveSession(this)
    val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
    val attributeSeq = schema.toAttributes
    Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
  }

// 使用给定架构从包含[[行]]的“RDD”创建“DataFrame”。确保所提供RDD的每个[[Row]]的结构与所提供的模式匹配是很重要的。否则,将出现运行时异常。
  @DeveloperApi
  @InterfaceStability.Evolving
  def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
    createDataFrame(rowRDD, schema, needsConversion = true)
  }

 //使用给定架构从包含[[Row]]的“JavaRDD”创建“DataFrame”。确保所提供RDD的每个[[Row]]的结构与所提供的模式匹配是很重要的。否则,将出现运行时异常。
  @DeveloperApi
  @InterfaceStability.Evolving
  def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
    createDataFrame(rowRDD.rdd, schema)
  }

 //从`java.util.List`包含使用给定架构的[[行]]的。确保所提供列表的每个[[Row]]的结构与所提供的模式匹配是很重要的。否则,将出现运行时异常。
  @DeveloperApi
  @InterfaceStability.Evolving
  def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
    Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala))
  }

 //将模式应用于Java bean的RDD。
 //警告:由于在Java Bean中没有保证字段的顺序,
 //SELECT *查询将以未定义的顺序返回列。
  def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
    val attributeSeq: Seq[AttributeReference] = getSchema(beanClass)
    val className = beanClass.getName
    val rowRdd = rdd.mapPartitions { iter =>
    // BeanInfo是不可序列化的,因此我们必须为每个分区远程重新发现它。
      SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq)
    }
    Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self))
  }

  //将模式应用于Java bean的RDD。
  //警告:由于在Java Bean中没有保证字段的顺序,
  //SELECT *查询将以未定义的顺序返回列。
  def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
    createDataFrame(rdd.rdd, beanClass)
  }

  //对Java bean列表应用模式。
  //警告:由于在Java Bean中没有保证字段的顺序,
  //SELECT *查询将以未定义的顺序返回列。 
  def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
    val attrSeq = getSchema(beanClass)
    val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq)
    Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
  }

 //将为外部数据源创建的“BaseRelation”转换为“数据帧”。
  def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
    Dataset.ofRows(self, LogicalRelation(baseRelation))
  }

  /* ------------------------------- *
   |  创建 DataSets  的方法            |
   * ------------------------------- */

 //从给定类型的本地数据序列创建[[Dataset]]。此方法需要一个编码器(用于将类型为“T”的JVM对象与内部Spark SQL表示形式进行转换),该编码器通常通过“SparkSession”的隐式自动创建,或者可以通过调用[[Encoders]]上的静态方法显式创建。
  @Experimental
  @InterfaceStability.Evolving
  def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
    val enc = encoderFor[T]
    val attributes = enc.schema.toAttributes
    val encoded = data.map(d => enc.toRow(d).copy())
    val plan = new LocalRelation(attributes, encoded)
    Dataset[T](self, plan)
  }

  //创建方法的重载:数据源RDD
  @Experimental
  @InterfaceStability.Evolving
  def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
    Dataset[T](self, ExternalRDD(data, self))
  }

 //数据源java.util.List
  @Experimental
  @InterfaceStability.Evolving
  def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
    createDataset(data.asScala)
  }

 //创建一个[[Dataset]],其中有一个名为“id”的“LongType”列,包含范围从0到“end”(独占)的元素,步长值为1。
  @Experimental
  @InterfaceStability.Evolving
  def range(end: Long): Dataset[java.lang.Long] = range(0, end)

 //创建一个[[Dataset]],其中有一个名为“id”的“LongType”列,包含从“start”到“end”(独占)范围内的元素,步长值为1。
  @Experimental
  @InterfaceStability.Evolving
  def range(start: Long, end: Long): Dataset[java.lang.Long] = {
    range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism)
  }

  //创建一个[[Dataset]],其中有一个名为“id”的“LongType”列,该列包含从“start”到“end”(独占)范围内的元素,并带有步长值。
  @Experimental
  @InterfaceStability.Evolving
  def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
    range(start, end, step, numPartitions = sparkContext.defaultParallelism)
  }

 //创建一个[[Dataset]],其中有一个名为“id”的“LongType”列,该列包含从“start”到“end”(独占)范围内的元素,并指定了一个step值和分区号。
  @Experimental
  @InterfaceStability.Evolving
  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
    new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG)
  }

 //从RDD[Row]创建一个DataFrame。
 //用户可以指定是否将输入的行转换为Catalyst行。
  private[sql] def internalCreateDataFrame(
      catalystRows: RDD[InternalRow],
      schema: StructType): DataFrame = {
    // TODO: 当rowRDD是另一个数据帧并被应用时,使用MutableProjection
    // 在任何字段数据类型上,模式都与现有模式不同
    val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
    Dataset.ofRows(self, logicalPlan)
  }

//从RDD[Row]创建一个DataFrame。
//用户可以指定是否应该将输入行转换为Catalyst行。
  private[sql] def createDataFrame(
      rowRDD: RDD[Row],
      schema: StructType,
      needsConversion: Boolean) = {
    // TODO: 当rowRDD是另一个数据帧并被应用时,使用MutableProjection
    // 在任何字段数据类型上,模式都与现有模式不同。
    val catalystRows = if (needsConversion) {
      val encoder = RowEncoder(schema)
      rowRDD.map(encoder.toRow)
    } else {
      rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)}
    }
    val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
    Dataset.ofRows(self, logicalPlan)
  }


  /* ------------------------- *
   |  目录相关的方法              |
   * ------------------------- */

 //用户可以通过该接口创建、删除、更改或查询底层数据库、表、函数等。
  @transient lazy val catalog: Catalog = new CatalogImpl(self)

  //以“数据帧”的形式返回指定的表/视图。
  //tableName 参数 : 是指定表或视图的限定或非限定名称。如果指定了数据库,它将从数据库中标识表/视图。否则,它首先尝试查找具有给定名称的临时视图,然后匹配当前数据库中的表/视图。注意,全局临时视图数据库在这里也是有效的。
  def table(tableName: String): DataFrame = {
    table(sessionState.sqlParser.parseTableIdentifier(tableName))
  }

  private[sql] def table(tableIdent: TableIdentifier): DataFrame = {
    Dataset.ofRows(self, sessionState.catalog.lookupRelation(tableIdent))
  }

  /* ----------------- *
   |  一些其他的         |
   * ----------------- */

//使用Spark执行SQL查询,并将结果作为“DataFrame”返回。
//用于SQL解析的dialect可以用'spark.sql.dialect'配置。
  def sql(sqlText: String): DataFrame = {
    Dataset.ofRows(self, sessionState.sqlParser.parsePlan(sqlText))
  }

 //返回一个[[DataFrameReader]],可用于将非流式数据作为“DataFrame”读入。
  def read: DataFrameReader = new DataFrameReader(self)

 //返回一个' DataStreamReader ',它可以作为' DataFrame '读取流数据。
  @InterfaceStability.Evolving
  def readStream: DataStreamReader = new DataStreamReader(self)

  //执行一些代码块,并将执行该块所用的时间打印到标准输出。这仅在Scala中可用,主要用于交互式测试和调试。
  def time[T](f: => T): T = {
    val start = System.nanoTime()
    val ret = f
    val end = System.nanoTime()
    // scalastyle:off println
    println(s"Time taken: ${(end - start) / 1000 / 1000} ms")
    // scalastyle:on println
    ret
  }

  //禁用样式检查器,使"implicit "对象可以以小写i开头
  @Experimental
  @InterfaceStability.Evolving
  object implicits extends SQLImplicits with Serializable {
    protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext
  }
 
  //停止底层的“SparkContext”。
  def stop(): Unit = {
    sparkContext.stop()
  }

  //“stop()”的同义词
  override def close(): Unit = stop()

  //解析内部字符串表示中的数据类型。数据类型字符串的格式应与scala中的“toString”生成的格式相同,只有PysSpark才使用。
  protected[sql] def parseDataType(dataTypeString: String): DataType = {
    DataType.fromJson(dataTypeString)
  }

  //将schemaString定义的模式应用到RDD上。它只被PySpark使用。
  private[sql] def applySchemaToPythonRDD(
      rdd: RDD[Array[Any]],
      schemaString: String): DataFrame = {
    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
    applySchemaToPythonRDD(rdd, schema)
  }

 //将该模式定义的模式应用到RDD上。它只被PySpark使用。
  private[sql] def applySchemaToPythonRDD(
      rdd: RDD[Array[Any]],
      schema: StructType): DataFrame = {
    val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
    Dataset.ofRows(self, LogicalRDD(schema.toAttributes, rowRdd)(self))
  }

 //返回给定java bean类的Catalyst模式。
  private def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
    val (dataType, _) = JavaTypeInference.inferDataType(beanClass)
    dataType.asInstanceOf[StructType].fields.map { f =>
      AttributeReference(f.name, f.dataType, f.nullable)()
    }
  }

}


@InterfaceStability.Stable
object SparkSession {

  
  @InterfaceStability.Stable
  class Builder extends Logging {

    private[this] val options = new scala.collection.mutable.HashMap[String, String]

    private[this] val extensions = new SparkSessionExtensions

    private[this] var userSuppliedContext: Option[SparkContext] = None

    private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
      userSuppliedContext = Option(sparkContext)
      this
    }

   //设置应用程序的名称,该名称将显示在Spark web UI中。
   //如果没有设置应用程序名称,将使用随机生成的名称。
    def appName(name: String): Builder = config("spark.app.name", name)

    def config(key: String, value: String): Builder = synchronized {
      options += key -> value
      this
    }

    def config(key: String, value: Long): Builder = synchronized {
      options += key -> value.toString
      this
    }

    def config(key: String, value: Double): Builder = synchronized {
      options += key -> value.toString
      this
    }

    def config(key: String, value: Boolean): Builder = synchronized {
      options += key -> value.toString
      this
    }

    def config(conf: SparkConf): Builder = synchronized {
      conf.getAll.foreach { case (k, v) => options += k -> v }
      this
    }

    def master(master: String): Builder = config("spark.master", master)

    //启用配置单元支持,包括连接到永久性配置单元元存储、支持配置单元序列和配置单元用户定义函数。
    def enableHiveSupport(): Builder = synchronized {
      if (hiveClassesArePresent) {
        config(CATALOG_IMPLEMENTATION.key, "hive")
      } else {
        throw new IllegalArgumentException(
          "Unable to instantiate SparkSession with Hive support because " +
            "Hive classes are not found.")
      }
    }
      
	//将扩展注入[[SparkSession]]。这允许用户添加分析器规则、优化器规则、规划策略或自定义解析器。
    def withExtensions(f: SparkSessionExtensions => Unit): Builder = {
      f(extensions)
      this
    }

    def getOrCreate(): SparkSession = synchronized {
      // Get the session from current thread's active session.
      var session = activeThreadSession.get()
      if ((session ne null) && !session.sparkContext.isStopped) {
        options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
        if (options.nonEmpty) {
          logWarning("Using an existing SparkSession; some configuration may not take effect.")
        }
        return session
      }

      // 全局同步,因此我们将只设置默认会话一次。
      SparkSession.synchronized {
        // 如果当前线程没有活动会话,则从全局会话获取它。
        session = defaultSession.get()
        if ((session ne null) && !session.sparkContext.isStopped) {
          options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
          if (options.nonEmpty) {
            logWarning("Using an existing SparkSession; some configuration may not take effect.")
          }
          return session
        }

        // 没有活跃或者全局会话,创建一个
        val sparkContext = userSuppliedContext.getOrElse {
          // 创建app name 如果没有
          val randomAppName = java.util.UUID.randomUUID().toString
          val sparkConf = new SparkConf()
          options.foreach { case (k, v) => sparkConf.set(k, v) }
          if (!sparkConf.contains("spark.app.name")) {
            sparkConf.setAppName(randomAppName)
          }
          val sc = SparkContext.getOrCreate(sparkConf)
          // 可能这是一个现有的SparkContext,更新它的SparkConf,它可能被SparkSession使用
          options.foreach { case (k, v) => sc.conf.set(k, v) }
          if (!sc.conf.contains("spark.app.name")) {
            sc.conf.setAppName(randomAppName)
          }
          sc
        }

        // 如果用户定义了配置器类,则初始化扩展。
        val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
        if (extensionConfOption.isDefined) {
          val extensionConfClassName = extensionConfOption.get
          try {
            val extensionConfClass = Utils.classForName(extensionConfClassName)
            val extensionConf = extensionConfClass.newInstance()
              .asInstanceOf[SparkSessionExtensions => Unit]
            extensionConf(extensions)
          } catch {
            // 如果找不到类或者类的类型错误,则忽略该错误。
            case e @ (_: ClassCastException |
                      _: ClassNotFoundException |
                      _: NoClassDefFoundError) =>
              logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
          }
        }

        session = new SparkSession(sparkContext, None, None, extensions)
        options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
        defaultSession.set(session)

       
        sparkContext.addSparkListener(new SparkListener {
          override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
            defaultSession.set(null)
            sqlListener.set(null)
          }
        })
      }

      return session
    }
  }

//创建一个[[SparkSession。用于构建一个[[SparkSession]]。
  def builder(): Builder = new Builder


  def setActiveSession(session: SparkSession): Unit = {
    activeThreadSession.set(session)
  }

  def clearActiveSession(): Unit = {
    activeThreadSession.remove()
  }

  def setDefaultSession(session: SparkSession): Unit = {
    defaultSession.set(session)
  }

  def clearDefaultSession(): Unit = {
    defaultSession.set(null)
  }

  def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)

  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)

  /** A global SQL listener used for the SQL UI. */
  private[sql] val sqlListener = new AtomicReference[SQLListener]()

  
  // Private methods from now on
  

  /** The active SparkSession for the current thread. */
  private val activeThreadSession = new InheritableThreadLocal[SparkSession]

  /** Reference to the root SparkSession. */
  private val defaultSession = new AtomicReference[SparkSession]

  private val HIVE_SESSION_STATE_BUILDER_CLASS_NAME =
    "org.apache.spark.sql.hive.HiveSessionStateBuilder"

  private def sessionStateClassName(conf: SparkConf): String = {
    conf.get(CATALOG_IMPLEMENTATION) match {
      case "hive" => HIVE_SESSION_STATE_BUILDER_CLASS_NAME
      case "in-memory" => classOf[SessionStateBuilder].getCanonicalName
    }
  }

  private def instantiateSessionState(
      className: String,
      sparkSession: SparkSession): SessionState = {
    try {
      // invoke `new [Hive]SessionStateBuilder(SparkSession, Option[SessionState])`
      val clazz = Utils.classForName(className)
      val ctor = clazz.getConstructors.head
      ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build()
    } catch {
      case NonFatal(e) =>
        throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
    }
  }

  /**
   * @return true if Hive classes can be loaded, otherwise false.
   */
  private[spark] def hiveClassesArePresent: Boolean = {
    try {
      Utils.classForName(HIVE_SESSION_STATE_BUILDER_CLASS_NAME)
      Utils.classForName("org.apache.hadoop.hive.conf.HiveConf")
      true
    } catch {
      case _: ClassNotFoundException | _: NoClassDefFoundError => false
    }
  }

}

 类似资料:

相关阅读

相关文章

相关问答