GraphX的图运算操作 - 聚合操作



1 aggregateMessages

1.1 aggregateMessages接口


  1. def aggregateMessages[A: ClassTag](
  2. sendMsg: EdgeContext[VD, ED, A] => Unit,
  3. mergeMsg: (A, A) => A,
  4. tripletFields: TripletFields = TripletFields.All)
  5. : VertexRDD[A] = {
  6. aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None)
  7. }


  • sendMsg: 发消息函数
  1. private def sendMsg(ctx: EdgeContext[KCoreVertex, Int, Map[Int, Int]]): Unit = {
  2. ctx.sendToDst(Map(ctx.srcAttr.preKCore -> -1, ctx.srcAttr.curKCore -> 1))
  3. ctx.sendToSrc(Map(ctx.dstAttr.preKCore -> -1, ctx.dstAttr.curKCore -> 1))
  4. }
  • mergeMsg:合并消息函数。它在Map阶段用于合并每个边分区内发往同一顶点的消息,在Reduce阶段用于合并来自不同分区、顶点id相同的消息。


  • tripletFields:指定sendMsg函数需要访问三元组中的哪些字段(源顶点属性、目的顶点属性等)。只声明实际用到的字段可以提升性能;未声明的顶点属性在EdgeContext中不可用。

1.2 aggregateMessages处理流程


1.2.1 Map阶段

从入口函数进入aggregateMessagesWithActiveSet函数,该函数首先使用VertexRDD[VD]更新replicatedVertexView,只更新其中vertexRDD中的attr对象。如"构建图"一节中介绍的,replicatedVertexView是点和边的视图,点的属性发生变化后,需要同步更新边中所包含顶点的attr。

  1. replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
  2. val view = activeSetOpt match {
  3. case Some((activeSet, _)) =>
  4. //返回只包含活跃顶点的replicatedVertexView
  5. replicatedVertexView.withActiveSet(activeSet)
  6. case None =>
  7. replicatedVertexView
  8. }


  1. val preAgg = view.edges.partitionsRDD.mapPartitions(_.flatMap {
  2. case (pid, edgePartition) =>
  3. // 选择 scan 方法
  4. val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
  5. activeDirectionOpt match {
  6. case Some(EdgeDirection.Both) =>
  7. if (activeFraction < 0.8) {
  8. edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
  9. EdgeActiveness.Both)
  10. } else {
  11. edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
  12. EdgeActiveness.Both)
  13. }
  14. case Some(EdgeDirection.Either) =>
  15. edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
  16. EdgeActiveness.Either)
  17. case Some(EdgeDirection.Out) =>
  18. if (activeFraction < 0.8) {
  19. edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
  20. EdgeActiveness.SrcOnly)
  21. } else {
  22. edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
  23. EdgeActiveness.SrcOnly)
  24. }
  25. case Some(EdgeDirection.In) =>
  26. edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
  27. EdgeActiveness.DstOnly)
  28. case _ => // None
  29. edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
  30. EdgeActiveness.Neither)
  31. }
  32. })


  1. def aggregateMessagesEdgeScan[A: ClassTag](
  2. sendMsg: EdgeContext[VD, ED, A] => Unit,
  3. mergeMsg: (A, A) => A,
  4. tripletFields: TripletFields,
  5. activeness: EdgeActiveness): Iterator[(VertexId, A)] = {
  6. var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset)
  7. var i = 0
  8. while (i < size) {
  9. val localSrcId = localSrcIds(i)
  10. val srcId = local2global(localSrcId)
  11. val localDstId = localDstIds(i)
  12. val dstId = local2global(localDstId)
  13. val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD]
  14. val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]
  15. ctx.set(srcId, dstId, localSrcId, localDstId, srcAttr, dstAttr, data(i))
  16. sendMsg(ctx)
  17. i += 1
  18. }


  • 获取顶点相关信息

在前文介绍edge partition时,我们知道它包含localSrcIds,localDstIds, data, index, global2local, local2global, vertexAttrs这几个重要的数据结构。其中localSrcIds,localDstIds分别表示源顶点、目的顶点在当前分区中的索引。

  • 发送消息


  1. override def sendToSrc(msg: A) {
  2. send(_localSrcId, msg)
  3. }
  4. override def sendToDst(msg: A) {
  5. send(_localDstId, msg)
  6. }
  7. @inline private def send(localId: Int, msg: A) {
  8. if (bitset.get(localId)) {
  9. aggregates(localId) = mergeMsg(aggregates(localId), msg)
  10. } else {
  11. aggregates(localId) = msg
  12. bitset.set(localId)
  13. }
  14. }

各个点在发消息时是相互独立的,即:点只是根据发送方向,向以相邻点的localId为下标的数组位置写入数据,彼此互不依赖,可以并行运行。Map阶段最后返回消息RDD messages: RDD[(VertexId, VD2)]。



1.2.2 Reduce阶段


  1. vertices.aggregateUsingIndex(preAgg, mergeMsg)
  2. override def aggregateUsingIndex[VD2: ClassTag](
  3. messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = {
  4. val shuffled = messages.partitionBy(this.partitioner.get)
  5. val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) =>
  6. thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc))
  7. }
  8. this.withPartitionsRDD[VD2](parts)
  9. }


  • 1 对messages重新分区,分区器使用VertexRDD的partitioner。然后使用zipPartitions合并两个分区。

  • 2 对等合并attr, 聚合函数使用传入的mergeMsg函数

  1. def aggregateUsingIndex[VD2: ClassTag](
  2. iter: Iterator[Product2[VertexId, VD2]],
  3. reduceFunc: (VD2, VD2) => VD2): Self[VD2] = {
  4. val newMask = new BitSet(self.capacity)
  5. val newValues = new Array[VD2](self.capacity)
  6. iter.foreach { product =>
  7. val vid = product._1
  8. val vdata = product._2
  9. val pos = self.index.getPos(vid)
  10. if (pos >= 0) {
  11. if (newMask.get(pos)) {
  12. newValues(pos) = reduceFunc(newValues(pos), vdata)
  13. } else { // otherwise just store the new value
  14. newMask.set(pos)
  15. newValues(pos) = vdata
  16. }
  17. }
  18. }
  19. this.withValues(newValues).withMask(newMask)
  20. }





1.3 举例


  1. // Import random graph generation library
  2. import org.apache.spark.graphx.util.GraphGenerators
  3. // Create a graph with "age" as the vertex property. Here we use a random graph for simplicity.
  4. val graph: Graph[Double, Int] =
  5. GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble )
  6. // Compute the number of older followers and their total age
  7. val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)](
  8. triplet => { // Map Function
  9. if (triplet.srcAttr > triplet.dstAttr) {
  10. // Send message to destination vertex containing counter and age
  11. triplet.sendToDst(1, triplet.srcAttr)
  12. }
  13. },
  14. // Add counter and age
  15. (a, b) => (a._1 + b._1, a._2 + b._2) // Reduce Function
  16. )
  17. // Divide total age by number of older followers to get average age of older followers
  18. val avgAgeOfOlderFollowers: VertexRDD[Double] =
  19. olderFollowers.mapValues( (id, value) => value match { case (count, totalAge) => totalAge / count } )
  20. // Display the results
  21. avgAgeOfOlderFollowers.collect.foreach(println(_))

2 collectNeighbors


  1. def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
  2. val nbrs = edgeDirection match {
  3. case EdgeDirection.Either =>
  4. graph.aggregateMessages[Array[(VertexId, VD)]](
  5. ctx => {
  6. ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
  7. ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
  8. },
  9. (a, b) => a ++ b, TripletFields.All)
  10. case EdgeDirection.In =>
  11. graph.aggregateMessages[Array[(VertexId, VD)]](
  12. ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))),
  13. (a, b) => a ++ b, TripletFields.Src)
  14. case EdgeDirection.Out =>
  15. graph.aggregateMessages[Array[(VertexId, VD)]](
  16. ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))),
  17. (a, b) => a ++ b, TripletFields.Dst)
  18. case EdgeDirection.Both =>
  19. throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
  20. "EdgeDirection.Either instead.")
  21. }
  22. graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) =>
  23. nbrsOpt.getOrElse(Array.empty[(VertexId, VD)])
  24. }
  25. }


  1. ctx => {
  2. ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
  3. ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
  4. },



  1. (a, b) => a ++ b


3 collectNeighborIds


  1. def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = {
  2. val nbrs =
  3. if (edgeDirection == EdgeDirection.Either) {
  4. graph.aggregateMessages[Array[VertexId]](
  5. ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) },
  6. _ ++ _, TripletFields.None)
  7. } else if (edgeDirection == EdgeDirection.Out) {
  8. graph.aggregateMessages[Array[VertexId]](
  9. ctx => ctx.sendToSrc(Array(ctx.dstId)),
  10. _ ++ _, TripletFields.None)
  11. } else if (edgeDirection == EdgeDirection.In) {
  12. graph.aggregateMessages[Array[VertexId]](
  13. ctx => ctx.sendToDst(Array(ctx.srcId)),
  14. _ ++ _, TripletFields.None)
  15. } else {
  16. throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
  17. "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
  18. }
  19. graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
  20. nbrsOpt.getOrElse(Array.empty[VertexId])
  21. }
  22. }


  1. ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) }

