package spark.examples |
import java.util.Random |
import spark.SparkContext |
import spark.SparkContext._ |
import spark.examples.Vector._ |
/**
 * K-means clustering over a text file of points, using Spark.
 *
 * Each input line is one point: whitespace-separated coordinates. Centers
 * are initialized uniformly at random in [-1, 1]^dimensions (fixed seed 42),
 * then refined for a fixed number of iterations. Each iteration assigns every
 * point to its nearest center and moves each center to the mean of its
 * assigned points.
 */
object SparkKMeans {
  /**
   * Parses one input line into a [[Vector]].
   *
   * Coordinates are separated by runs of whitespace. Leading and trailing
   * whitespace is ignored, so inputs such as " 1.0  2.0" parse without
   * producing empty tokens.
   *
   * @param line a line of whitespace-separated numbers
   * @return a vector holding those numbers in order
   * @throws NumberFormatException if a token is not a valid double
   */
  def parseVector(line: String): Vector =
    new Vector(line.trim.split("\\s+").map(_.toDouble))

  /**
   * Finds the center nearest to a point.
   *
   * Distance is measured with `Vector.squaredDist`. On ties, the
   * lowest-indexed center wins, because only a strictly smaller distance
   * replaces the current best.
   *
   * @param p       the point to classify
   * @param centers candidate centers; must be non-empty
   * @return index into `centers` of the closest center
   * @throws IllegalArgumentException if `centers` is empty
   */
  def closestCenter(p: Vector, centers: Array[Vector]): Int = {
    require(centers.nonEmpty, "closestCenter requires at least one center")
    var bestIndex = 0
    var bestDist = p.squaredDist(centers(0))
    // Plain indexed loop: this runs once per point per iteration.
    for (i <- 1 until centers.length) {
      val dist = p.squaredDist(centers(i))
      if (dist < bestDist) {
        bestDist = dist
        bestIndex = i
      }
    }
    bestIndex
  }

  /**
   * Entry point.
   *
   * Arguments: `<master> <file> <dimensions> <k> <iters> [<minSplits>]`.
   * `minSplits` is optional; when it is omitted, `textFile`'s default split
   * count is used.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 5) {
      System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters> [<minSplits>]")
      System.exit(1)
    }
    // Parse and validate arguments before starting the (expensive) SparkContext.
    val dimensions = args(2).toInt // dimensionality of each point
    val k = args(3).toInt          // number of clusters
    val iterations = args(4).toInt // number of refinement iterations
    if (k <= 0) {
      System.err.println("SparkKMeans: <k> must be a positive integer")
      System.exit(1)
    }

    val sc = new SparkContext(args(0), "SparkKMeans")
    val lines =
      if (args.length > 5) sc.textFile(args(1), args(5).toInt)
      else sc.textFile(args(1))
    // One point per line. Cached because the points are re-scanned every iteration.
    val points = lines.map(parseVector(_)).cache()

    // Random initial centers in [-1, 1]^dimensions. The fixed seed keeps runs reproducible.
    val rand = new Random(42)
    val centers = Array.fill(k)(Vector(dimensions, _ => 2 * rand.nextDouble - 1))
    println("Initial centers: " + centers.mkString(", "))

    val time1 = System.currentTimeMillis()
    for (i <- 1 to iterations) {
      println("On iteration " + i)
      // Map each point to the index of its closest center and a (point, 1) pair,
      // which the reduce below turns into a per-cluster (sum, count).
      val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) }
      val newCenters = mappedPoints.reduceByKey {
        case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2)
      }.map {
        // The new center is the mean of the points assigned to that cluster.
        case (id, (sum, count)) => (id, sum / count)
      }.collect()
      // Update centers in place. A cluster that received no points produces no
      // entry here, so it keeps its previous position.
      for ((id, value) <- newCenters) {
        centers(id) = value
      }
    }
    val time2 = System.currentTimeMillis()
    println("Final centers: " + centers.mkString(", ") + ", time: " + (time2 - time1))
    sc.stop()
  }
}