001/* Train a KMeans model in MLlib. 002 * 003 * Copyright (c) 2014 The Regents of the University of California. 004 * All rights reserved. 005 * 006 * '$Author: crawl $' 007 * '$Date: 2018-02-06 19:09:58 +0000 (Tue, 06 Feb 2018) $' 008 * '$Revision: 34656 $' 009 * 010 * Permission is hereby granted, without written agreement and without 011 * license or royalty fees, to use, copy, modify, and distribute this 012 * software and its documentation for any purpose, provided that the above 013 * copyright notice and the following two paragraphs appear in all copies 014 * of this software. 015 * 016 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY 017 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES 018 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 019 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF 020 * SUCH DAMAGE. 021 * 022 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, 023 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 024 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE 025 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF 026 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, 027 * ENHANCEMENTS, OR MODIFICATIONS. 028 * 029 */ 030package org.kepler.spark.mllib; 031 032import java.util.ArrayList; 033import java.util.List; 034import java.util.Map; 035 036import org.apache.spark.api.java.JavaRDD; 037import org.apache.spark.mllib.clustering.KMeans; 038import org.apache.spark.mllib.linalg.Vector; 039import org.apache.spark.rdd.RDD; 040 041import ptolemy.actor.TypedAtomicActor; 042import ptolemy.actor.TypedIOPort; 043import ptolemy.actor.parameters.PortParameter; 044import ptolemy.data.ArrayToken; 045import ptolemy.data.DoubleToken; 046import ptolemy.data.IntToken; 047import ptolemy.data.LongToken; 048import ptolemy.data.ObjectToken; 049import ptolemy.data.Token; 050import ptolemy.data.expr.StringParameter; 051import ptolemy.data.type.ArrayType; 052import ptolemy.data.type.BaseType; 053import ptolemy.data.type.ObjectType; 054import ptolemy.kernel.CompositeEntity; 055import ptolemy.kernel.util.IllegalActionException; 056import ptolemy.kernel.util.NameDuplicationException; 057import ptolemy.kernel.util.SingletonAttribute; 058 059/** Train a KMeans model in Spark MLlib. 060 * 061 * @author Daniel Crawl 062 * @version $Id: KMeansModel.java 34656 2018-02-06 19:09:58Z crawl $ 063 * 064 */ 065public class KMeansModel extends TypedAtomicActor { 066 067 public KMeansModel(CompositeEntity container, String name) 068 throws IllegalActionException, NameDuplicationException { 069 super(container, name); 070 071 data = new TypedIOPort(this, "data", true, false); 072 data.setTypeEquals(new ObjectType(JavaRDD.class)); 073 new SingletonAttribute(data, "_showName"); 074 075 numClusters = new PortParameter(this, "numClusters"); 076 numClusters.setTypeEquals(BaseType.INT); 077 numClusters.getPort().setTypeEquals(BaseType.INT); 078 new SingletonAttribute(numClusters.getPort(), "_showName"); 079 080 numRuns = new PortParameter(this, "numRuns"); 081 numRuns.setTypeEquals(BaseType.INT); 082 numRuns.getPort().setTypeEquals(BaseType.INT); 083 new SingletonAttribute(numRuns.getPort(), "_showName"); 084 numRuns.setToken(IntToken.ONE); 085 086 iterations = new PortParameter(this, "iterations"); 087 iterations.setTypeEquals(BaseType.INT); 088 iterations.getPort().setTypeEquals(BaseType.INT); 089 new SingletonAttribute(iterations.getPort(), "_showName"); 090 iterations.setToken(new IntToken(10)); 091 092 seed = new PortParameter(this, "seed"); 093 seed.setTypeEquals(BaseType.INT); 094 seed.getPort().setTypeEquals(BaseType.INT); 095 new SingletonAttribute(seed.getPort(), "_showName"); 096 097 error = new TypedIOPort(this, "error", false, true); 098 error.setTypeEquals(BaseType.DOUBLE); 099 new SingletonAttribute(error, "_showName"); 100 101 centers = new TypedIOPort(this, "centers", false, true); 102 centers.setTypeEquals(new ArrayType(new ArrayType(BaseType.DOUBLE))); 103 new SingletonAttribute(centers, "_showName"); 104 105 clusterSizes = new TypedIOPort(this, "clusterSizes", false, true); 106 clusterSizes.setTypeEquals(new ArrayType(BaseType.LONG)); 107 new SingletonAttribute(clusterSizes, "_showName"); 108 109 initSteps = new PortParameter(this, "initSteps"); 110 initSteps.setTypeEquals(BaseType.INT); 111 initSteps.getPort().setTypeEquals(BaseType.INT); 112 new SingletonAttribute(initSteps.getPort(), "_showName"); 113 initSteps.setToken(new IntToken(5)); 114 115 initializationMode = new StringParameter(this, "initializationMode"); 116 } 117 118 @Override 119 public void fire() throws IllegalActionException { 120 121 super.fire(); 122 123 iterations.update(); 124 final int numIterations = ((IntToken)iterations.getToken()).intValue(); 125 126 numClusters.update(); 127 final int numClustersVal = ((IntToken)numClusters.getToken()).intValue(); 128 129 numRuns.update(); 130 final int numRunsVal = ((IntToken)numRuns.getToken()).intValue(); 131 132 seed.update(); 133 final int seedVal = ((IntToken)seed.getToken()).intValue(); 134 135 initSteps.update(); 136 final int numInitSteps = ((IntToken)initSteps.getToken()).intValue(); 137 138 final String initMode = initializationMode.getValueAsString(); 139 140 final KMeans kmeans = new KMeans(); 141 kmeans.setMaxIterations(numIterations); 142 kmeans.setK(numClustersVal); 143 kmeans.setSeed(seedVal); 144 kmeans.setRuns(numRunsVal); 145 kmeans.setInitializationMode(initMode); 146 kmeans.setInitializationSteps(numInitSteps); 147 148 final JavaRDD<Vector> javaRDD = (JavaRDD<Vector>) ((ObjectToken)data.get(0)).getValue(); 149 final RDD<Vector> rdd = javaRDD.rdd(); 150 151 _model = kmeans.run(rdd.cache()); 152 153 final double errorVal = _model.computeCost(rdd); 154 error.broadcast(new DoubleToken(errorVal)); 155 156 Vector[] centerVectors = _model.clusterCenters(); 157 ArrayList<ArrayToken> centersArray = new ArrayList<ArrayToken>(centerVectors.length); 158 for(Vector vector : centerVectors) { 159 ArrayList<Token> center = new ArrayList<Token>(vector.size()); 160 for(double val : vector.toArray()) { 161 center.add(new DoubleToken(val)); 162 } 163 centersArray.add(new ArrayToken(center.toArray(new Token[vector.size()]))); 164 } 165 centers.broadcast(new ArrayToken(centersArray.toArray(new ArrayToken[centerVectors.length]))); 166 167 JavaRDD<Integer> prediction =_model.predict(javaRDD); 168 Map<Integer,Long> counts = prediction.countByValue(); 169 List<Token> countsArray = new ArrayList<Token>(); 170 for(Long count : counts.values()) { 171 countsArray.add(new LongToken(count.longValue())); 172 } 173 clusterSizes.broadcast(new ArrayToken(countsArray.toArray(new Token[countsArray.size()]))); 174 175 } 176 177 /** The input vectors. */ 178 public TypedIOPort data; 179 180 /** The number of clusters. */ 181 public PortParameter numClusters; 182 183 /** The number of runs of the algorithm to execute in parallel. */ 184 public PortParameter numRuns; 185 186 /** The maximum number of iterations to run. */ 187 public PortParameter iterations; 188 189 /** The random seed value to use for cluster initialization . */ 190 public PortParameter seed; 191 192 public StringParameter initializationMode; 193 194 public PortParameter initSteps; 195 196 /** The sum of squared distances to their nearest center. */ 197 public TypedIOPort error; 198 199 /** The center of the clusters. */ 200 public TypedIOPort centers; 201 202 /** The size of each cluster. */ 203 public TypedIOPort clusterSizes; 204 205 206 /** The MLlib KMeans model. */ 207 private org.apache.spark.mllib.clustering.KMeansModel _model; 208 209}