001/* Train a KMeans model in MLlib.
002 * 
003 * Copyright (c) 2014 The Regents of the University of California.
004 * All rights reserved.
005 *
006 * '$Author: crawl $'
007 * '$Date: 2018-02-06 19:09:58 +0000 (Tue, 06 Feb 2018) $' 
008 * '$Revision: 34656 $'
009 * 
010 * Permission is hereby granted, without written agreement and without
011 * license or royalty fees, to use, copy, modify, and distribute this
012 * software and its documentation for any purpose, provided that the above
013 * copyright notice and the following two paragraphs appear in all copies
014 * of this software.
015 *
016 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
017 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
018 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
019 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
020 * SUCH DAMAGE.
021 *
022 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
023 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
024 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
025 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
026 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
027 * ENHANCEMENTS, OR MODIFICATIONS.
028 *
029 */
030package org.kepler.spark.mllib;
031
032import java.util.ArrayList;
033import java.util.List;
034import java.util.Map;
035
036import org.apache.spark.api.java.JavaRDD;
037import org.apache.spark.mllib.clustering.KMeans;
038import org.apache.spark.mllib.linalg.Vector;
039import org.apache.spark.rdd.RDD;
040
041import ptolemy.actor.TypedAtomicActor;
042import ptolemy.actor.TypedIOPort;
043import ptolemy.actor.parameters.PortParameter;
044import ptolemy.data.ArrayToken;
045import ptolemy.data.DoubleToken;
046import ptolemy.data.IntToken;
047import ptolemy.data.LongToken;
048import ptolemy.data.ObjectToken;
049import ptolemy.data.Token;
050import ptolemy.data.expr.StringParameter;
051import ptolemy.data.type.ArrayType;
052import ptolemy.data.type.BaseType;
053import ptolemy.data.type.ObjectType;
054import ptolemy.kernel.CompositeEntity;
055import ptolemy.kernel.util.IllegalActionException;
056import ptolemy.kernel.util.NameDuplicationException;
057import ptolemy.kernel.util.SingletonAttribute;
058
059/** Train a KMeans model in Spark MLlib.
060 * 
061 *  @author Daniel Crawl
062 *  @version $Id: KMeansModel.java 34656 2018-02-06 19:09:58Z crawl $
063 * 
064 */
065public class KMeansModel extends TypedAtomicActor {
066    
067    public KMeansModel(CompositeEntity container, String name)
068    throws IllegalActionException, NameDuplicationException {
069        super(container, name);
070        
071        data = new TypedIOPort(this, "data", true, false);
072        data.setTypeEquals(new ObjectType(JavaRDD.class));
073        new SingletonAttribute(data, "_showName");
074        
075        numClusters = new PortParameter(this, "numClusters");
076        numClusters.setTypeEquals(BaseType.INT);
077        numClusters.getPort().setTypeEquals(BaseType.INT);
078        new SingletonAttribute(numClusters.getPort(), "_showName");
079
080        numRuns = new PortParameter(this, "numRuns");
081        numRuns.setTypeEquals(BaseType.INT);
082        numRuns.getPort().setTypeEquals(BaseType.INT);
083        new SingletonAttribute(numRuns.getPort(), "_showName");
084        numRuns.setToken(IntToken.ONE);
085
086        iterations = new PortParameter(this, "iterations");
087        iterations.setTypeEquals(BaseType.INT);
088        iterations.getPort().setTypeEquals(BaseType.INT);
089        new SingletonAttribute(iterations.getPort(), "_showName");
090        iterations.setToken(new IntToken(10));
091        
092        seed = new PortParameter(this, "seed");
093        seed.setTypeEquals(BaseType.INT);
094        seed.getPort().setTypeEquals(BaseType.INT);
095        new SingletonAttribute(seed.getPort(), "_showName");
096        
097        error = new TypedIOPort(this, "error", false, true);
098        error.setTypeEquals(BaseType.DOUBLE);
099        new SingletonAttribute(error, "_showName");
100        
101        centers = new TypedIOPort(this, "centers", false, true);
102        centers.setTypeEquals(new ArrayType(new ArrayType(BaseType.DOUBLE)));
103        new SingletonAttribute(centers, "_showName");
104
105        clusterSizes = new TypedIOPort(this, "clusterSizes", false, true);
106        clusterSizes.setTypeEquals(new ArrayType(BaseType.LONG));
107        new SingletonAttribute(clusterSizes, "_showName");
108
109        initSteps = new PortParameter(this, "initSteps");
110        initSteps.setTypeEquals(BaseType.INT);
111        initSteps.getPort().setTypeEquals(BaseType.INT);
112        new SingletonAttribute(initSteps.getPort(), "_showName");
113        initSteps.setToken(new IntToken(5));
114
115        initializationMode = new StringParameter(this, "initializationMode");
116    }
117    
118    @Override
119    public void fire() throws IllegalActionException {
120
121        super.fire();
122        
123        iterations.update();
124        final int numIterations = ((IntToken)iterations.getToken()).intValue();
125        
126        numClusters.update();
127        final int numClustersVal = ((IntToken)numClusters.getToken()).intValue();
128
129        numRuns.update();
130        final int numRunsVal = ((IntToken)numRuns.getToken()).intValue();
131        
132        seed.update();
133        final int seedVal = ((IntToken)seed.getToken()).intValue();
134
135            initSteps.update();
136            final int numInitSteps = ((IntToken)initSteps.getToken()).intValue();
137        
138            final String initMode = initializationMode.getValueAsString();
139        
140        final KMeans kmeans = new KMeans();
141        kmeans.setMaxIterations(numIterations);
142        kmeans.setK(numClustersVal);
143        kmeans.setSeed(seedVal);
144        kmeans.setRuns(numRunsVal);
145        kmeans.setInitializationMode(initMode);
146            kmeans.setInitializationSteps(numInitSteps);
147
148        final JavaRDD<Vector> javaRDD = (JavaRDD<Vector>) ((ObjectToken)data.get(0)).getValue();
149        final RDD<Vector> rdd = javaRDD.rdd();
150        
151        _model = kmeans.run(rdd.cache());
152        
153        final double errorVal = _model.computeCost(rdd);        
154        error.broadcast(new DoubleToken(errorVal));
155
156        Vector[] centerVectors = _model.clusterCenters();
157        ArrayList<ArrayToken> centersArray = new ArrayList<ArrayToken>(centerVectors.length);
158        for(Vector vector : centerVectors) {
159            ArrayList<Token> center = new ArrayList<Token>(vector.size());
160            for(double val : vector.toArray()) {
161                center.add(new DoubleToken(val));
162            }
163            centersArray.add(new ArrayToken(center.toArray(new Token[vector.size()])));
164        }
165        centers.broadcast(new ArrayToken(centersArray.toArray(new ArrayToken[centerVectors.length])));
166        
167        JavaRDD<Integer> prediction =_model.predict(javaRDD);
168        Map<Integer,Long> counts = prediction.countByValue();
169        List<Token> countsArray = new ArrayList<Token>();
170        for(Long count : counts.values()) {
171            countsArray.add(new LongToken(count.longValue()));
172        }
173        clusterSizes.broadcast(new ArrayToken(countsArray.toArray(new Token[countsArray.size()])));
174        
175    }
176    
177    /** The input vectors. */ 
178    public TypedIOPort data;
179        
180    /** The number of clusters. */
181    public PortParameter numClusters;
182    
183    /** The number of runs of the algorithm to execute in parallel. */
184    public PortParameter numRuns;
185    
186    /** The maximum number of iterations to run. */
187    public PortParameter iterations;
188    
189    /** The random seed value to use for cluster initialization . */
190    public PortParameter seed;
191
192    public StringParameter initializationMode;  
193
194    public PortParameter initSteps; 
195  
196    /** The sum of squared distances to their nearest center. */
197    public TypedIOPort error;
198    
199    /** The center of the clusters. */
200    public TypedIOPort centers;
201    
202    /** The size of each cluster. */
203    public TypedIOPort clusterSizes;
204    
205
206    /** The MLlib KMeans model. */
207    private org.apache.spark.mllib.clustering.KMeansModel _model;
208
209}