GraphX Graph Operations - Aggregation Operations

GraphX provides three aggregation operations: aggregateMessages, collectNeighborIds, and collectNeighbors. aggregateMessages is implemented in GraphImpl, while collectNeighborIds and collectNeighbors are implemented in GraphOps. The sections below describe each of these methods.

1 aggregateMessages

1.1 The aggregateMessages interface

aggregateMessages is one of the most important APIs in GraphX and replaces mapReduceTriplets; in fact, mapReduceTriplets is now implemented on top of aggregateMessages. Its main job is to send messages along each edge to the adjacent vertices, merge the messages received at each vertex, and return the result as a message VertexRDD.
The signature of aggregateMessages is as follows:

    def aggregateMessages[A: ClassTag](
        sendMsg: EdgeContext[VD, ED, A] => Unit,
        mergeMsg: (A, A) => A,
        tripletFields: TripletFields = TripletFields.All)
      : VertexRDD[A] = {
      aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None)
    }

The interface takes three parameters: a function that sends messages, a function that merges messages, and a tripletFields value describing which parts of the triplet the messages depend on. A minimal usage sketch follows this list.

  • sendMsg: the message-sending function, for example (taken from a K-Core style implementation):

    private def sendMsg(ctx: EdgeContext[KCoreVertex, Int, Map[Int, Int]]): Unit = {
      ctx.sendToDst(Map(ctx.srcAttr.preKCore -> -1, ctx.srcAttr.curKCore -> 1))
      ctx.sendToSrc(Map(ctx.dstAttr.preKCore -> -1, ctx.dstAttr.curKCore -> 1))
    }

  • mergeMsg: the message-merging function

In the Map phase this function merges, within each edge partition, the messages received by each vertex; it is used again in the Reduce phase to merge messages from different partitions, combining messages that share the same vertexId.

  • tripletFields: declares which fields of the triplet (the source and/or destination vertex attributes) sendMsg reads, so that only the needed vertex attributes are shipped to the edge partitions.
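To make the three parameters concrete, here is a minimal usage sketch that counts each vertex's in-degree. It assumes, purely for illustration, that a graph named graph of type Graph[Double, Int] is already in scope (such as the one built in the example of section 1.3):

    import org.apache.spark.graphx._

    // Count in-degrees: send the constant 1 to every edge's destination and sum per vertex.
    val inDegrees: VertexRDD[Int] = graph.aggregateMessages[Int](
      sendMsg = ctx => ctx.sendToDst(1),      // message-sending function
      mergeMsg = (a, b) => a + b,             // message-merging function
      tripletFields = TripletFields.None      // sendMsg reads no vertex attributes
    )
    inDegrees.take(5).foreach(println)

Because sendMsg touches neither vertex attribute, passing TripletFields.None tells GraphX that no vertex attributes need to be shipped to the edge partitions. The result matches what the built-in graph.inDegrees property provides.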

1.2 The aggregateMessages processing flow

The aggregateMessages method proceeds in two phases, Map and Reduce, which are described in turn below.

1.2.1 Map phase

The entry point delegates to aggregateMessagesWithActiveSet. This function first uses the VertexRDD[VD] to upgrade the replicatedVertexView, updating only the attr values of the vertex RDD inside it. As described in the chapter on graph construction, replicatedVertexView is a view joining vertices and edges: when vertex attributes change, the vertex attrs replicated into the edge partitions must be updated as well.

    replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
    val view = activeSetOpt match {
      case Some((activeSet, _)) =>
        // return a replicatedVertexView restricted to the active vertex set
        replicatedVertexView.withActiveSet(activeSet)
      case None =>
        replicatedVertexView
    }

The program then runs mapPartitions over the edge RDD of replicatedVertexView; all of the work happens while iterating over each edge partition, as in the following code:

    val preAgg = view.edges.partitionsRDD.mapPartitions(_.flatMap {
      case (pid, edgePartition) =>
        // choose the scan method
        val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
        activeDirectionOpt match {
          case Some(EdgeDirection.Both) =>
            if (activeFraction < 0.8) {
              edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
                EdgeActiveness.Both)
            } else {
              edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
                EdgeActiveness.Both)
            }
          case Some(EdgeDirection.Either) =>
            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
              EdgeActiveness.Either)
          case Some(EdgeDirection.Out) =>
            if (activeFraction < 0.8) {
              edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
                EdgeActiveness.SrcOnly)
            } else {
              edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
                EdgeActiveness.SrcOnly)
            }
          case Some(EdgeDirection.In) =>
            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
              EdgeActiveness.DstOnly)
          case _ => // None
            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
              EdgeActiveness.Neither)
        }
    })

Within a partition, the value of activeFraction decides whether processing goes through aggregateMessagesEdgeScan or aggregateMessagesIndexScan. aggregateMessagesEdgeScan scans all edges sequentially, whereas aggregateMessagesIndexScan first filters by the source-vertex index and only then scans. We focus on aggregateMessagesEdgeScan.

    def aggregateMessagesEdgeScan[A: ClassTag](
        sendMsg: EdgeContext[VD, ED, A] => Unit,
        mergeMsg: (A, A) => A,
        tripletFields: TripletFields,
        activeness: EdgeActiveness): Iterator[(VertexId, A)] = {
      val aggregates = new Array[A](vertexAttrs.length)
      val bitset = new BitSet(vertexAttrs.length)

      var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset)
      var i = 0
      while (i < size) {
        // (activeness checks are omitted here for brevity)
        val localSrcId = localSrcIds(i)
        val srcId = local2global(localSrcId)
        val localDstId = localDstIds(i)
        val dstId = local2global(localDstId)
        val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD]
        val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]
        ctx.set(srcId, dstId, localSrcId, localDstId, srcAttr, dstAttr, data(i))
        sendMsg(ctx)
        i += 1
      }

      bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
    }

This method consists of two steps: obtaining the vertex-related information, and sending the messages.

  • Obtaining vertex information

When edge partitions were introduced earlier, we saw that an EdgePartition holds several important data structures: localSrcIds, localDstIds, data, index, global2local, local2global and vertexAttrs. localSrcIds and localDstIds store, for each edge, the within-partition (local) index of its source and destination vertex. So while iterating over the edges, localSrcIds(i) gives the local index of edge i's source vertex, local2global maps that local index to the global srcId, vertexAttrs yields the vertex attribute, and data(i) yields the edge attribute (a small runnable sketch of these lookups appears at the end of this Map-phase discussion).

  • Sending messages

Before sending, the tripletFields declared in the interface determine which vertex attributes have been populated in the EdgeContext. Sending a message means that, as each edge is visited, sendToSrc/sendToDst write the message into the aggregates array at the local id of the target vertex; if that slot already holds a message (its bit in the bitset is set), the merge function mergeMsg is applied instead.

    override def sendToSrc(msg: A) {
      send(_localSrcId, msg)
    }
    override def sendToDst(msg: A) {
      send(_localDstId, msg)
    }
    @inline private def send(localId: Int, msg: A) {
      if (bitset.get(localId)) {
        aggregates(localId) = mergeMsg(aggregates(localId), msg)
      } else {
        aggregates(localId) = msg
        bitset.set(localId)
      }
    }

Each vertex's messages are sent independently of the others: a vertex simply writes, according to the send direction, into the slot of the aggregates array indexed by the neighboring vertex's localId, so the sends do not interfere and can run in parallel. At the end of the Map phase the messages are returned as an RDD, messages: RDD[(VertexId, VD2)]. A small runnable sketch of these per-partition mechanics follows.
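The sketch below is self-contained and purely illustrative: the arrays and their contents are invented, the names mirror EdgePartition's fields, and scala.collection.mutable.BitSet stands in for Spark's internal BitSet.

    import scala.collection.mutable

    val local2global = Array(100L, 200L, 300L)  // local vertex index -> global VertexId
    val vertexAttrs  = Array(1.0, 2.0, 3.0)     // vertex attributes, indexed by local id
    val localSrcIds  = Array(0, 2)              // per-edge local index of the source vertex
    val localDstIds  = Array(1, 1)              // per-edge local index of the destination vertex
    val data         = Array(10, 20)            // per-edge attribute

    // "send" an Int message to a local vertex id, merging on collision like AggregatingEdgeContext
    val aggregates = new Array[Int](local2global.length)
    val bitset     = new mutable.BitSet(local2global.length)
    def mergeMsg(a: Int, b: Int): Int = a + b
    def send(localId: Int, msg: Int): Unit =
      if (bitset(localId)) aggregates(localId) = mergeMsg(aggregates(localId), msg)
      else { aggregates(localId) = msg; bitset += localId }

    // walk the edges the way aggregateMessagesEdgeScan does, sending each edge attribute to its dst
    for (i <- localSrcIds.indices) {
      val srcId = local2global(localSrcIds(i))  // global source id via local2global
      println(s"edge $i: src=$srcId srcAttr=${vertexAttrs(localSrcIds(i))} edgeAttr=${data(i)}")
      send(localDstIds(i), data(i))             // both edges target local id 1, so their messages merge
    }
    // messages keyed by global vertex id; here: (200,30)
    val messages = bitset.toArray.map(lid => (local2global(lid), aggregates(lid)))
    println(messages.mkString(", "))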

The execution flow of the Map phase is illustrated in the figure below.

[Figure: graphx_aggmsg_map - Map-phase execution flow]

1.2.2 Reduce phase

The Reduce phase is implemented by the following call:

    vertices.aggregateUsingIndex(preAgg, mergeMsg)

    override def aggregateUsingIndex[VD2: ClassTag](
        messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = {
      val shuffled = messages.partitionBy(this.partitioner.get)
      val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) =>
        thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc))
      }
      this.withPartitionsRDD[VD2](parts)
    }

The code above works in two steps:

  • 1 Repartition messages using the VertexRDD's partitioner, then use zipPartitions to pair each message partition with the corresponding vertex partition.

  • 2 Merge the attrs partition by partition, using the mergeMsg function passed in as the aggregation function.

    def aggregateUsingIndex[VD2: ClassTag](
        iter: Iterator[Product2[VertexId, VD2]],
        reduceFunc: (VD2, VD2) => VD2): Self[VD2] = {
      val newMask = new BitSet(self.capacity)
      val newValues = new Array[VD2](self.capacity)
      iter.foreach { product =>
        val vid = product._1
        val vdata = product._2
        val pos = self.index.getPos(vid)
        if (pos >= 0) {
          if (newMask.get(pos)) {
            newValues(pos) = reduceFunc(newValues(pos), vdata)
          } else { // otherwise just store the new value
            newMask.set(pos)
            newValues(pos) = vdata
          }
        }
      }
      this.withValues(newValues).withMask(newMask)
    }

From the arguments we can see that this code iterates over the message partition rather than over every vertex. Since not every vertex receives a message, the message partition is the smaller collection, so iterating over it is faster.

In other words, for each message we look up the position pos of its vertexId in the index; if a value is already stored at that position, we merge the new attr into it with mergeMsg, otherwise we simply store it. A generic sketch of this repartition-and-merge pattern follows.
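The same pattern can be sketched against plain RDDs, outside GraphX's internals. Everything below, including the name aggregateByKeyLike, is a hypothetical illustration; the real aggregateUsingIndex indexes straight into per-partition arrays via the vertex index instead of building a Set and a Map.

    import scala.reflect.ClassTag
    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.RDD

    def aggregateByKeyLike[V: ClassTag](vertices: RDD[(Long, V)],
                                        messages: RDD[(Long, V)],
                                        merge: (V, V) => V): RDD[(Long, V)] = {
      // Step 1: repartition the messages with the vertices' partitioner, so that every
      // message ends up in the same partition as its target vertex.
      val part = vertices.partitioner.getOrElse(new HashPartitioner(vertices.getNumPartitions))
      val shuffled = messages.partitionBy(part)
      // Step 2: merge messages with the same key inside each partition.
      vertices.partitionBy(part).zipPartitions(shuffled, preservesPartitioning = true) {
        (vertIter, msgIter) =>
          val knownIds = vertIter.map(_._1).toSet          // stands in for the vertex index
          val merged = scala.collection.mutable.Map.empty[Long, V]
          msgIter.foreach { case (k, v) =>
            if (knownIds.contains(k)) {                    // keep only messages for known vertices
              merged(k) = merged.get(k).map(merge(_, v)).getOrElse(v)
            }
          }
          merged.iterator
      }
    }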

The Reduce phase is illustrated in the figure below.

[Figure: graphx_aggmsg_map]

1.3 Example

The following example computes, for each user, the average age of the followers who are older than that user.

    // Import random graph generation library
    import org.apache.spark.graphx.util.GraphGenerators
    // Create a graph with "age" as the vertex property. Here we use a random graph for simplicity.
    val graph: Graph[Double, Int] =
      GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble )
    // Compute the number of older followers and their total age
    val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)](
      triplet => { // Map Function
        if (triplet.srcAttr > triplet.dstAttr) {
          // Send message to destination vertex containing counter and age
          triplet.sendToDst((1, triplet.srcAttr))
        }
      },
      // Add counter and age
      (a, b) => (a._1 + b._1, a._2 + b._2) // Reduce Function
    )
    // Divide total age by number of older followers to get average age of older followers
    val avgAgeOfOlderFollowers: VertexRDD[Double] =
      olderFollowers.mapValues( (id, value) => value match { case (count, totalAge) => totalAge / count } )
    // Display the results
    avgAgeOfOlderFollowers.collect.foreach(println(_))

2 collectNeighbors

This method collects, for each vertex, the vertex ids and attributes of its neighboring vertices.

    def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
      val nbrs = edgeDirection match {
        case EdgeDirection.Either =>
          graph.aggregateMessages[Array[(VertexId, VD)]](
            ctx => {
              ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
              ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
            },
            (a, b) => a ++ b, TripletFields.All)
        case EdgeDirection.In =>
          graph.aggregateMessages[Array[(VertexId, VD)]](
            ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))),
            (a, b) => a ++ b, TripletFields.Src)
        case EdgeDirection.Out =>
          graph.aggregateMessages[Array[(VertexId, VD)]](
            ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))),
            (a, b) => a ++ b, TripletFields.Dst)
        case EdgeDirection.Both =>
          throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
            "EdgeDirection.Either instead.")
      }
      graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) =>
        nbrsOpt.getOrElse(Array.empty[(VertexId, VD)])
      }
    }

In the code above, the first step is to choose, based on the EdgeDirection, which aggregateMessages call performs the aggregation. We take the EdgeDirection.Either case as the example. As can be seen, the message-sending function passed to aggregateMessages is:

    ctx => {
      ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
      ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
    },

For each edge, this function sends a message to both the source vertex and the destination vertex; the messages are (destination vertex id, destination vertex attribute) and (source vertex id, source vertex attribute) respectively. Why is it done this way?
Every edge connects two vertices: to record the neighbor relationship we send the destination vertex's information to the source vertex, and likewise send the source vertex's information to the destination vertex.

The merge function is a collection concatenation; it merges all the neighbor information addressed to the same vertex, as shown below:

    (a, b) => a ++ b

After aggregateMessages produces a VertexRDD containing the neighbor information, it is left-joined with the existing vertices, yielding the neighbor information for every vertex (vertices with no neighbors get an empty array). A small usage sketch follows.
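A minimal usage sketch, assuming org.apache.spark.graphx._ is imported and the same graph: Graph[Double, Int] as in the example of section 1.3 is in scope: collect, for every vertex, the ids and attributes of the neighbors reached through incoming edges, then print a few of them.

    val inNbrs: VertexRDD[Array[(VertexId, Double)]] =
      graph.collectNeighbors(EdgeDirection.In)
    inNbrs.take(5).foreach { case (vid, nbrs) =>
      println(s"$vid <- ${nbrs.map(_._1).mkString(", ")}")
    }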

3 collectNeighborIds

This method collects, for each vertex, the vertex ids of its neighboring vertices. Its implementation is very similar to that of collectNeighbors.

    def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = {
      val nbrs =
        if (edgeDirection == EdgeDirection.Either) {
          graph.aggregateMessages[Array[VertexId]](
            ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) },
            _ ++ _, TripletFields.None)
        } else if (edgeDirection == EdgeDirection.Out) {
          graph.aggregateMessages[Array[VertexId]](
            ctx => ctx.sendToSrc(Array(ctx.dstId)),
            _ ++ _, TripletFields.None)
        } else if (edgeDirection == EdgeDirection.In) {
          graph.aggregateMessages[Array[VertexId]](
            ctx => ctx.sendToDst(Array(ctx.srcId)),
            _ ++ _, TripletFields.None)
        } else {
          throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
            "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
        }
      graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
        nbrsOpt.getOrElse(Array.empty[VertexId])
      }
    }

collectNeighbors的实现不同的是,aggregateMessages函数中的sendMsg函数只发送顶点Id到源顶点和目的顶点。其它的实现基本一致。

    ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) }
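Correspondingly, a minimal usage sketch for collectNeighborIds (same assumed graph and imports as above): collect neighbor ids over both edge directions and print a few of them.

    val nbrIds: VertexRDD[Array[VertexId]] =
      graph.collectNeighborIds(EdgeDirection.Either)
    nbrIds.take(5).foreach { case (vid, ids) =>
      println(s"$vid: ${ids.mkString(" ")}")
    }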
