图算法实现 - 三角计数

优质
小牛编辑
170浏览
2023-12-01
  1. import scala.reflect.ClassTag
  2. import org.apache.spark.graphx._
  /**
   * Compute the number of triangles passing through each vertex.
   *
   * The algorithm is relatively straightforward and can be computed in three steps:
   *
   * <ul>
   * <li> Compute the set of neighbors for each vertex.</li>
   * <li> For each edge compute the intersection of the two endpoints' neighbor sets and send the
   *      count to both vertices.</li>
   * <li> Compute the sum at each vertex and divide by two since each triangle is counted twice.</li>
   * </ul>
   *
   * There are two implementations. The default `TriangleCount.run` implementation first removes
   * self cycles and canonicalizes the graph to ensure that the following conditions hold:
   * <ul>
   * <li> There are no self edges</li>
   * <li> All edges are oriented consistently (`convertToCanonicalEdges` rewrites each edge so that
   *      `srcId < dstId`). The orientation itself does not matter to the counting, because
   *      neighbors are collected with `EdgeDirection.Either`.</li>
   * <li> There are no duplicate edges</li>
   * </ul>
   * However, the canonicalization procedure is costly as it requires repartitioning the graph.
   * If the input data is already in "canonical form" with self cycles removed then
   * `TriangleCount.runPreCanonicalized` should be used instead.
   *
   * {{{
   * val canonicalGraph = graph.mapEdges(e => 1).removeSelfEdges().convertToCanonicalEdges()
   * val counts = TriangleCount.runPreCanonicalized(canonicalGraph).vertices
   * }}}
   *
   * NOTE(review): `VertexSet` is not defined in this file. In Spark it is a `private[graphx]`
   * type alias, so this object presumably has to live under the `org.apache.spark.graphx`
   * package to compile. Confirm the enclosing package.
   */
  object TriangleCount {

    /**
     * Counts triangles on an arbitrary graph.
     *
     * First, it replaces edge attributes with a cheap `Boolean`. Next, it drops self edges and
     * canonicalizes the edges. It then delegates to [[runPreCanonicalized]]. Finally, it joins the
     * counts back onto the original `graph`.
     *
     * @param graph input graph; may contain self edges and duplicate or bidirectional edges
     * @tparam VD vertex attribute type of the input graph
     * @tparam ED edge attribute type of the input graph
     * @return `graph` with every vertex attribute replaced by the number of triangles passing
     *         through that vertex (0 for vertices that received no count)
     */
    def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = {
      // Transform the edge data into something cheap to shuffle, then canonicalize.
      val canonicalGraph = graph.mapEdges(e => true).removeSelfEdges().convertToCanonicalEdges()
      // Get the triangle counts.
      val counters = runPreCanonicalized(canonicalGraph).vertices
      // Join them back with the original graph; vertices missing from `counters` get 0.
      graph.outerJoinVertices(counters) { (_, _, optCounter: Option[Int]) =>
        optCounter.getOrElse(0)
      }
    }

    /**
     * Counts triangles on a graph that the caller guarantees is already canonical.
     *
     * Self edges are tolerated: a vertex is never added to its own neighbor set. Duplicate edges
     * are NOT handled. Each copy is visited separately by `aggregateMessages`, which inflates the
     * counts or trips the evenness check below.
     *
     * @param graph canonical input graph (each undirected edge present exactly once)
     * @tparam VD vertex attribute type of the input graph
     * @tparam ED edge attribute type of the input graph
     * @return `graph` with every vertex attribute replaced by its triangle count
     * @throws IllegalArgumentException if a vertex accumulates an odd (double-)count. The check
     *         runs lazily inside the vertex join, so it surfaces only when the result is computed.
     */
    def runPreCanonicalized[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = {
      // Construct set representations of the neighborhoods.
      val nbrSets: VertexRDD[VertexSet] =
        graph.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) =>
          val set = new VertexSet(nbrs.length)
          var i = 0
          while (i < nbrs.length) {
            // Skip the vertex itself so self cycles never contribute to an intersection.
            if (nbrs(i) != vid) {
              set.add(nbrs(i))
            }
            i += 1
          }
          set
        }

      // Join the sets with the graph. A vertex without a set gets an empty one instead of null,
      // so edgeFunc can never dereference a missing neighborhood.
      val setGraph: Graph[VertexSet, ED] = graph.outerJoinVertices(nbrSets) {
        (_, _, optSet) => optSet.getOrElse(new VertexSet(0))
      }

      // Edge function: walk the smaller neighbor set and probe the larger one. The number of
      // lookups then equals the size of the smaller set.
      def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]): Unit = {
        val (smallSet, largeSet) =
          if (ctx.srcAttr.size < ctx.dstAttr.size) (ctx.srcAttr, ctx.dstAttr)
          else (ctx.dstAttr, ctx.srcAttr)
        val iter = smallSet.iterator
        var counter: Int = 0
        while (iter.hasNext) {
          val vid = iter.next()
          // A common neighbor other than the edge's own endpoints closes a triangle.
          if (vid != ctx.srcId && vid != ctx.dstId && largeSet.contains(vid)) {
            counter += 1
          }
        }
        ctx.sendToSrc(counter)
        ctx.sendToDst(counter)
      }

      // Compute the intersection along edges and sum the results at each vertex.
      val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _)

      // Merge counters with the graph and divide by two since each triangle is counted twice.
      graph.outerJoinVertices(counters) { (_, _, optCounter: Option[Int]) =>
        val dblCount = optCounter.getOrElse(0)
        // A triangle through v is reported once by each of v's two edges in that triangle,
        // so a correct per-vertex total is always even.
        require(dblCount % 2 == 0, "Triangle count resulted in an invalid number of triangles.")
        dblCount / 2
      }
    }
  }