import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag

/**
 * Returns the subgraph of `graph` induced by `vertices` together with their
 * immediate neighbors.
 *
 * A vertex counts as a neighbor if it shares an edge with a selected vertex in
 * either direction (as source or as destination). An edge is kept whenever both
 * of its endpoints are kept, so an edge between two neighbors also survives.
 * The original vertex and edge attributes of `graph` are preserved.
 *
 * @tparam VD vertex attribute type of `graph`
 * @tparam ED edge attribute type of `graph`
 * @tparam A  attribute type carried by `vertices`; only the vertex ids are used,
 *            the attribute values are ignored
 * @param graph    the graph to extract from
 * @param vertices the selected vertices; ids that are not vertices of `graph`
 *                 have no effect
 * @return the subgraph of `graph` restricted to `vertices` and their neighbors
 */
def subgraphWithNeighbors[VD, ED: ClassTag, A: ClassTag](
    graph: Graph[VD, ED],
    vertices: RDD[(VertexId, A)]): Graph[VD, ED] = {
  // Label every vertex of `graph` with whether it appears in `vertices`.
  val labeledGraph: Graph[Boolean, ED] =
    graph.outerJoinVertices(vertices) { (_, _, selected) => selected.isDefined }

  // Send `true` across each edge from a selected endpoint to an unselected one.
  // Endpoints that are already selected are already labeled `true`, so they
  // receive no message (this keeps message traffic down).
  val neighbors: VertexRDD[Boolean] = labeledGraph.aggregateMessages[Boolean](
    ctx => {
      if (ctx.srcAttr && !ctx.dstAttr) ctx.sendToDst(true)
      else if (ctx.dstAttr && !ctx.srcAttr) ctx.sendToSrc(true)
    },
    _ || _)

  // A vertex is kept if it was selected or received at least one message.
  val fullLabeledGraph: Graph[Boolean, ED] =
    labeledGraph.outerJoinVertices(neighbors) { (_, selected, isNeighbor) =>
      selected || isNeighbor.getOrElse(false)
    }

  // Drop every vertex labeled `false`, together with any edge that touches one.
  val kept = fullLabeledGraph.subgraph(vpred = (_, keepVertex) => keepVertex)

  // `kept` carries Boolean labels. Masking `graph` with it restores the
  // original vertex and edge attributes.
  graph.mask(kept)
}

// Example (spark-shell; assumes `sc` is the shell's SparkContext).
// The graph is a directed ring 0 -> 1 -> ... -> 99 -> 0, with every vertex labeled "v".
val n = 100L
val graph = Graph.fromEdgeTuples(
  sc.parallelize((0L until n).map(x => (x, (x + 1) % n))),
  "v")

val randomSet = graph.vertices.sample(withReplacement = false, fraction = 0.1, seed = 1)
randomSet.collect().sortBy(_._1).foreach(println)
// Sample output (the exact set may vary with Spark version / partitioning):
// (29,v)
// (46,v)
// (56,v)
// (61,v)
// (65,v)
// (71,v)

val subgraph = subgraphWithNeighbors(graph, randomSet)
subgraph.vertices.collect().sortBy(_._1).foreach(println)
// Each sampled vertex appears together with its ring predecessor and successor:
// (28,v)
// (29,v)
// (30,v)
// (45,v)
// (46,v)
// (47,v)
// (55,v)
// (56,v)
// (57,v)
// (60,v)
// (61,v)
// (62,v)
// (64,v)
// (65,v)
// (66,v)
// (70,v)
// (71,v)
// (72,v)