K-means algorithm (Spark demo)

2013-04-28 Author: 神马

[java] code library

package spark.examples

import java.util.Random
import spark.SparkContext
import spark.SparkContext._
import spark.examples.Vector._

object SparkKMeans {
    /**
     * Parse one line of space-separated numbers into a Vector.
     */
    def parseVector(line: String): Vector = {
        new Vector(line.split(' ').map(_.toDouble))
    }

    /**
     * Return the index of the center closest to point p.
     */
    def closestCenter(p: Vector, centers: Array[Vector]): Int = {
        var bestIndex = 0
        var bestDist = p.squaredDist(centers(0)) // sum of squared differences
        for (i <- 1 until centers.length) {
            val dist = p.squaredDist(centers(i))
            if (dist < bestDist) {
                bestDist = dist
                bestIndex = i
            }
        }
        bestIndex
    }

    def main(args: Array[String]): Unit = {
        // Six arguments are read below. args(5) is passed to textFile as the
        // number of input splits; the placeholder name <slices> is an assumption.
        if (args.length < 6) {
            System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters> <slices>")
            System.exit(1)
        }
        val sc = new SparkContext(args(0), "SparkKMeans")
        val lines = sc.textFile(args(1), args(5).toInt)
        // Each line of the text file is one point; convert each line to a Vector
        val points = lines.map(parseVector(_)).cache()
        val dimensions = args(2).toInt // dimensionality of each point
        val k = args(3).toInt // number of clusters
        val iterations = args(4).toInt // number of iterations

        // Randomly initialize k centers with coordinates in [-1, 1)
        val rand = new Random(42)
        var centers = new Array[Vector](k)
        for (i <- 0 until k)
            centers(i) = Vector(dimensions, _ => 2 * rand.nextDouble - 1)
        println("Initial centers: " + centers.mkString(", "))
        val time1 = System.currentTimeMillis()
        for (i <- 1 to iterations) {
            println("On iteration " + i)

            // Map each point to the index of its closest center and a (point, 1) pair
            // that we will use to compute an average later
            val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) }

            val newCenters = mappedPoints.reduceByKey {
                // (vector sum, count sum)
                case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2)
            }.map {
                // Recompute each center as the mean of the points assigned to it
                case (id, (sum, count)) => (id, sum / count)
            }.collect

            // Update the centers
            for ((id, value) <- newCenters) {
                centers(id) = value
            }
        }
        val time2 = System.currentTimeMillis()
        println("Final centers: " + centers.mkString(", ") + ", time: " + (time2 - time1))
    }
}
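
For readers without a Spark cluster, the same assign-then-average iteration can be sketched with plain Scala collections. This is only an illustration under assumptions, not part of the original demo: the object name LocalKMeans, the Point alias, and the helper step are hypothetical, and points are plain Array[Double] values rather than the demo's Vector class. Here groupBy plays the role of the map/reduceByKey pair, and a center that receives no points keeps its old position, as in the listing above.

```scala
// Hypothetical single-machine sketch of one k-means iteration (no Spark).
object LocalKMeans {
  type Point = Array[Double]

  /** Squared Euclidean distance between two points of equal dimension. */
  def squaredDist(a: Point, b: Point): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  /** Index of the center closest to p. */
  def closestCenter(p: Point, centers: Array[Point]): Int =
    centers.indices.minBy(i => squaredDist(p, centers(i)))

  /** One iteration: assign each point to its closest center, then average each cluster. */
  def step(points: Seq[Point], centers: Array[Point]): Array[Point] = {
    val updated = centers.clone() // centers with no assigned points stay unchanged
    points.groupBy(p => closestCenter(p, centers)).foreach { case (id, members) =>
      // Component-wise sum of all points in the cluster
      val sum = members.reduce((a, b) => a.zip(b).map { case (x, y) => x + y })
      updated(id) = sum.map(_ / members.size)
    }
    updated
  }
}
```

Calling LocalKMeans.step repeatedly, feeding each result back in as the new centers, mirrors the iteration loop in main.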

