001/* 
002 * Copyright (c) 2016-2017 The Regents of the University of California.
003 * All rights reserved.
004 *
005 * '$Author: crawl $'
006 * '$Date: 2018-02-07 01:00:17 +0000 (Wed, 07 Feb 2018) $' 
007 * '$Revision: 34658 $'
008 * 
009 * Permission is hereby granted, without written agreement and without
010 * license or royalty fees, to use, copy, modify, and distribute this
011 * software and its documentation for any purpose, provided that the above
012 * copyright notice and the following two paragraphs appear in all copies
013 * of this software.
014 *
015 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
016 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
017 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
018 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
019 * SUCH DAMAGE.
020 *
021 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
022 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
023 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
024 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
025 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
026 * ENHANCEMENTS, OR MODIFICATIONS.
027 *
028 */
029
030package org.kepler.spark.mllib;
031
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Set;

import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.clustering.KMeansSummary;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import ptolemy.actor.TypedAtomicActor;
import ptolemy.actor.TypedIOPort;
import ptolemy.actor.parameters.PortParameter;
import ptolemy.data.ArrayToken;
import ptolemy.data.DoubleToken;
import ptolemy.data.IntToken;
import ptolemy.data.LongToken;
import ptolemy.data.ObjectToken;
import ptolemy.data.StringToken;
import ptolemy.data.Token;
import ptolemy.data.expr.StringParameter;
import ptolemy.data.type.ArrayType;
import ptolemy.data.type.BaseType;
import ptolemy.data.type.ObjectType;
import ptolemy.kernel.CompositeEntity;
import ptolemy.kernel.util.IllegalActionException;
import ptolemy.kernel.util.NameDuplicationException;
import ptolemy.kernel.util.SingletonAttribute;
066
067/**
068 * @author Dylan Uys, Jiaxin Li
069 *
070 * This actor calls Spark's KMeans API to perform clustering on the input data
071 * frame, specifically on a standardized feature column specified by the user.
072 * The cluster centers are output as a JSON string on the output port.  
073 */
074public class KMeansClustering extends TypedAtomicActor {
075   
076    public KMeansClustering(CompositeEntity container, String name)
077        throws IllegalActionException, NameDuplicationException {
078        super(container, name);
079
080        
081        data = new TypedIOPort(this, "data", true, false);
082        data.setTypeEquals(new ObjectType(Dataset.class));
083        new SingletonAttribute(data, "_showName");
084        
085        numClusters = new PortParameter(this, "numClusters");
086        numClusters.setTypeEquals(BaseType.INT);
087        numClusters.getPort().setTypeEquals(BaseType.INT);
088        new SingletonAttribute(numClusters.getPort(), "_showName");
089
090        iterations = new PortParameter(this, "iterations");
091        iterations.setTypeEquals(BaseType.INT);
092        iterations.getPort().setTypeEquals(BaseType.INT);
093        new SingletonAttribute(iterations.getPort(), "_showName");
094        iterations.setToken(new IntToken(10));
095        
096        seed = new PortParameter(this, "seed");
097        seed.setTypeEquals(BaseType.INT);
098        seed.getPort().setTypeEquals(BaseType.INT);
099        new SingletonAttribute(seed.getPort(), "_showName");
100        
101        error = new TypedIOPort(this, "error", false, true);
102        error.setTypeEquals(BaseType.DOUBLE);
103        new SingletonAttribute(error, "_showName");
104        
105        centers = new TypedIOPort(this, "centers", false, true);
106        centers.setTypeEquals(BaseType.STRING);
107        new SingletonAttribute(centers, "_showName");
108
109        clusterSizes = new TypedIOPort(this, "clusterSizes", false, true);
110        clusterSizes.setTypeEquals(new ArrayType(BaseType.LONG));
111        new SingletonAttribute(clusterSizes, "_showName");
112
113        initSteps = new PortParameter(this, "initSteps");
114        initSteps.setTypeEquals(BaseType.INT);
115        initSteps.getPort().setTypeEquals(BaseType.INT);
116        new SingletonAttribute(initSteps.getPort(), "_showName");
117        initSteps.setToken(new IntToken(10));
118
119        initializationMode = new StringParameter(this, "initializationMode");
120        initializationMode.addChoice("random");
121        initializationMode.addChoice("k-means||");
122        initializationMode.setExpression("random");
123
124        entryName = new TypedIOPort(this, "entryName", true, false);
125        entryName.setTypeEquals(BaseType.STRING);
126        new SingletonAttribute(entryName, "_showName");
127
128        entryNameCol = new StringParameter(this, "entryNameCol");
129        entryNameCol.setToken("name");
130
131        stdColName = new StringParameter(this, "stdColName");
132        stdColName.setToken("std_vector");
133        
134        outFilepath = new StringParameter(this, "outFilepath");
135        outFilepath.setToken("clusterCenters.json");
136    }
137
138
139    
140    /*
141     * initialize:
142     * instantiate JsonObject to store cluster centers
143     */
144    @Override
145    public void initialize() throws IllegalActionException {
146        super.initialize();
147        // initialize _fileObj
148        _fileObj = new JsonObject();
149    }
150    
151
152    /*
153     * fire:
154     * process the incoming dataframe
155     */ 
156    @Override
157    public void fire() throws IllegalActionException {
158
159        super.fire();
160
161        // get standardized dataframe
162        final Dataset<Row> df = 
163            (Dataset<Row>) ((ObjectToken)data.get(0)).getValue();
164
165        // get entryId for saving current mean/std stats to file
166        // NOTE: disallowing pure whitespace names
167        String entryId = "";
168        if(entryName.getWidth() > 0)
169            entryId = ((StringToken)entryName.get(0)).stringValue().trim(); 
170        else {
171            String col = entryNameCol.stringValue().trim();
172            if(col == "")
173                System.out.println("No entryName/Col, not saving to file!");
174            else try {
175                    Row fr = df.first();
176                    entryId = (String) fr.getString(fr.fieldIndex(col));
177                } catch(IllegalArgumentException e) {
178                    throw new IllegalActionException("Can't find col name!");
179                } catch(UnsupportedOperationException e) {
180                    throw new IllegalActionException("No schema in input df!");
181                } catch(ClassCastException e) {
182                    throw new IllegalActionException("Col data not strings!");
183                }
184        }
185        // TODO: catch empty entryId 
186        entryId = entryId.replaceAll("\\s+", "");  // remove all whitespaces
187                
188
189        // update all actor parameters        
190        iterations.update();
191        final int numIterations = ((IntToken)iterations.getToken()).intValue();
192        numClusters.update();
193        final int numClustersVal=((IntToken)numClusters.getToken()).intValue();
194        seed.update();
195        final int seedVal = ((IntToken)seed.getToken()).intValue();
196        final String initMode = initializationMode.stringValue();
197        initSteps.update();
198        final int numInitSteps = ((IntToken)initSteps.getToken()).intValue();
199
200
201        // set up k-means 
202        final KMeans kmeans = new KMeans();
203        kmeans.setMaxIter(numIterations);
204        kmeans.setK(numClustersVal);
205        kmeans.setSeed(seedVal);
206        kmeans.setInitMode(initMode);
207        kmeans.setInitSteps(numInitSteps);
208        // NOTE: KMeansClustering depends on the standard vector column
209        kmeans.setFeaturesCol(stdColName.stringValue().trim());
210                
211        KMeansModel model = kmeans.fit(df.cache());     
212        double WSSSE = model.computeCost(df);
213        Vector[] cCenters = model.clusterCenters();
214
215        KMeansSummary summary = model.summary();
216        long[] clusterSizesArray = summary.clusterSizes();
217
218
219        // build JsonArray of JsonArray to represent array of Vector
220        JsonArray clusterArray = new JsonArray();
221        for (int i = 0; i < cCenters.length; i++) {
222            JsonArray arr =
223                new JsonArray(Arrays.toString(cCenters[i].toArray()));
224            clusterArray.add(arr);
225        }
226                
227        // output results
228        if(entryId != "") { // if entryId present, save centers to file
229            // save JSON object for current entry
230            JsonObject entryObj = new JsonObject();
231            entryObj.put("clusterCenters", clusterArray);
232            _fileObj.put(entryId, entryObj);
233
234            // broadcast station name and centers array for SAIdentify
235            // NOTE: this depends on entryId.
236            entryObj.put("name", entryId);
237            centers.broadcast(new StringToken(entryObj.encode())); 
238        }
239        else { // else broadcast cluster centers without entry ID
240            centers.broadcast(new StringToken(clusterArray.encode())); 
241        }
242                
243        // broadcast WSSSE
244        if(error.numberOfSinks() > 0)
245            error.broadcast(new DoubleToken(WSSSE));
246
247        // broadcast cluster sizes array
248        if(clusterSizes.numberOfSinks() > 0) {
249            Token[] cSizesTokenArray = new Token[clusterSizesArray.length];
250            for(int i = 0; i < clusterSizesArray.length; i++)
251                cSizesTokenArray[i] = new LongToken(clusterSizesArray[i]);
252            clusterSizes.broadcast(new ArrayToken(cSizesTokenArray)); 
253        }
254                
255        // DEBUG outputs
256        if(_debugging) {
257            //System.err.println(clusterArray.encode());
258            System.out.println("\nDEBUG: WSSSE:" + WSSSE);
259                
260            System.out.println("----- Cluster Sizes: -----");
261            for (long size: clusterSizesArray) {
262                System.out.print(size + " ");
263            }
264            System.out.println();
265                        
266            System.out.println("---- Cluster Centers: -----");
267            for (Vector centroid: cCenters) {
268                System.out.print(centroid + " ");
269            }
270            System.out.println();
271        }
272    }
273
274
275    /* 
276     * wrap-up:
277     * write JsonObject containing all cluster centers to file
278     */
279    @Override
280    public void wrapup() throws IllegalActionException {
281
282        super.wrapup(); 
283
284        PrintWriter writer;
285
286        // try to open existing file and overwrite with new data.
287        // if no existing file found, create new file and write to it.
288        try{
289            // try to open existing file
290            FileInputStream is =
291                new FileInputStream(outFilepath.getValueAsString());
292            BufferedReader reader =
293                new BufferedReader(new InputStreamReader(is));
294
295            // read the file as a single string, for JsonObject(String)
296            StringBuilder sb = new StringBuilder();
297            String line = reader.readLine();
298            while (line != null) {
299                sb.append(line);
300                line = reader.readLine();
301            }
302            reader.close();  // close file for writing
303
304            // file read, build new JSON object for original data
305            JsonObject origFileObj = new JsonObject(sb.toString());
306            // get list of names for stations with new data
307            Set<String> stationList = _fileObj.fieldNames();
308
309            // replace old data in origFileObj with new data
310            for (String station : stationList) {
311                origFileObj.remove(station);  // ignored if station DNE
312                origFileObj.put(station, _fileObj.getJsonObject(station));
313            }
314
315            // write updated data to file
316            try{
317                // open file
318                writer = new PrintWriter(outFilepath.getValueAsString());
319                
320                // write _fileObj to file
321                writer.println(origFileObj.encodePrettily()); // updated data
322        
323                // close the writer
324                writer.flush();
325                writer.close();
326            }catch(Exception ei1){System.err.println("Failed to open file!");}
327
328        } catch (Exception fe) { // no original file, create a new one
329
330            System.out.println("No original data file found.");
331            
332            try{
333                // open file
334                writer = new PrintWriter(outFilepath.getValueAsString());
335                
336                // write _fileObj to file
337                writer.println(_fileObj.encodePrettily());
338        
339                // close the writer
340                writer.flush();
341                writer.close();
342            }catch(Exception ei2){System.err.println("Failed to open file!");}
343        }
344
345    }
346
347
348    
349    /** The input vectors. */ 
350    public TypedIOPort data;
351        
352    /** The number of clusters. */
353    public PortParameter numClusters;
354    
355    /** The number of runs of the algorithm to execute in parallel. */
356    //public PortParameter numRuns;
357
358    /** The maximum number of iterations to run. */
359    public PortParameter iterations;
360    
361    /** The initialization mode for KMeans */
362    public StringParameter initializationMode;
363
364    /** The number of initialization steps for kmeans|| */
365    public PortParameter initSteps;
366
367    /** The random seed value to use for cluster initialization . */
368    public PortParameter seed;
369  
370    /** The sum of squared distances to their nearest center. */
371    public TypedIOPort error;
372    
373    /** The center of the clusters. */
374    public TypedIOPort centers;
375    
376    /** The size of each cluster. */
377    public TypedIOPort clusterSizes;
378
379    /** Name, or ID, for the dataframe being processed */ 
380    public TypedIOPort entryName;
381
382    /** Name of the entry name column */
383    public StringParameter entryNameCol;
384
385    /** Name of the standardized vector column in the input dataframe */
386    public StringParameter stdColName;  
387
388    /** Filepath for JSON output **/
389    public StringParameter outFilepath;
390
391    /* ========== */
392    
393    /** _fileObj to write to file **/
394    private JsonObject _fileObj;
395
396    /** The MLlib KMeans model. */
397    private KMeansModel _model;
398
399}