001/* 002 * Copyright (c) 2016-2017 The Regents of the University of California. 003 * All rights reserved. 004 * 005 * '$Author: crawl $' 006 * '$Date: 2018-02-07 01:00:17 +0000 (Wed, 07 Feb 2018) $' 007 * '$Revision: 34658 $' 008 * 009 * Permission is hereby granted, without written agreement and without 010 * license or royalty fees, to use, copy, modify, and distribute this 011 * software and its documentation for any purpose, provided that the above 012 * copyright notice and the following two paragraphs appear in all copies 013 * of this software. 014 * 015 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY 016 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES 017 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 018 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF 019 * SUCH DAMAGE. 020 * 021 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, 022 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 023 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE 024 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF 025 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, 026 * ENHANCEMENTS, OR MODIFICATIONS. 027 * 028 */ 029 030package org.kepler.spark.mllib; 031 032import java.io.BufferedReader; 033import java.io.FileInputStream; 034import java.io.InputStreamReader; 035import java.io.PrintWriter; 036import java.util.Arrays; 037import java.util.Set; 038 039import org.apache.spark.ml.clustering.KMeans; 040import org.apache.spark.ml.clustering.KMeansModel; 041import org.apache.spark.ml.clustering.KMeansSummary; 042import org.apache.spark.ml.linalg.Vector; 043import org.apache.spark.sql.Dataset; 044import org.apache.spark.sql.Row; 045 046import io.vertx.core.json.JsonArray; 047import io.vertx.core.json.JsonObject; 048import ptolemy.actor.TypedAtomicActor; 049import ptolemy.actor.TypedIOPort; 050import ptolemy.actor.parameters.PortParameter; 051import ptolemy.data.ArrayToken; 052import ptolemy.data.DoubleToken; 053import ptolemy.data.IntToken; 054import ptolemy.data.LongToken; 055import ptolemy.data.ObjectToken; 056import ptolemy.data.StringToken; 057import ptolemy.data.Token; 058import ptolemy.data.expr.StringParameter; 059import ptolemy.data.type.ArrayType; 060import ptolemy.data.type.BaseType; 061import ptolemy.data.type.ObjectType; 062import ptolemy.kernel.CompositeEntity; 063import ptolemy.kernel.util.IllegalActionException; 064import ptolemy.kernel.util.NameDuplicationException; 065import ptolemy.kernel.util.SingletonAttribute; 066 067/** 068 * @author Dylan Uys, Jiaxin Li 069 * 070 * This actor calls Spark's KMeans API to perform clustering on the input data 071 * frame, specifically on a standardized feature column specified by the user. 072 * The cluster centers are output as a JSON string on the output port. 073 */ 074public class KMeansClustering extends TypedAtomicActor { 075 076 public KMeansClustering(CompositeEntity container, String name) 077 throws IllegalActionException, NameDuplicationException { 078 super(container, name); 079 080 081 data = new TypedIOPort(this, "data", true, false); 082 data.setTypeEquals(new ObjectType(Dataset.class)); 083 new SingletonAttribute(data, "_showName"); 084 085 numClusters = new PortParameter(this, "numClusters"); 086 numClusters.setTypeEquals(BaseType.INT); 087 numClusters.getPort().setTypeEquals(BaseType.INT); 088 new SingletonAttribute(numClusters.getPort(), "_showName"); 089 090 iterations = new PortParameter(this, "iterations"); 091 iterations.setTypeEquals(BaseType.INT); 092 iterations.getPort().setTypeEquals(BaseType.INT); 093 new SingletonAttribute(iterations.getPort(), "_showName"); 094 iterations.setToken(new IntToken(10)); 095 096 seed = new PortParameter(this, "seed"); 097 seed.setTypeEquals(BaseType.INT); 098 seed.getPort().setTypeEquals(BaseType.INT); 099 new SingletonAttribute(seed.getPort(), "_showName"); 100 101 error = new TypedIOPort(this, "error", false, true); 102 error.setTypeEquals(BaseType.DOUBLE); 103 new SingletonAttribute(error, "_showName"); 104 105 centers = new TypedIOPort(this, "centers", false, true); 106 centers.setTypeEquals(BaseType.STRING); 107 new SingletonAttribute(centers, "_showName"); 108 109 clusterSizes = new TypedIOPort(this, "clusterSizes", false, true); 110 clusterSizes.setTypeEquals(new ArrayType(BaseType.LONG)); 111 new SingletonAttribute(clusterSizes, "_showName"); 112 113 initSteps = new PortParameter(this, "initSteps"); 114 initSteps.setTypeEquals(BaseType.INT); 115 initSteps.getPort().setTypeEquals(BaseType.INT); 116 new SingletonAttribute(initSteps.getPort(), "_showName"); 117 initSteps.setToken(new IntToken(10)); 118 119 initializationMode = new StringParameter(this, "initializationMode"); 120 initializationMode.addChoice("random"); 121 initializationMode.addChoice("k-means||"); 122 initializationMode.setExpression("random"); 123 124 entryName = new TypedIOPort(this, "entryName", true, false); 125 entryName.setTypeEquals(BaseType.STRING); 126 new SingletonAttribute(entryName, "_showName"); 127 128 entryNameCol = new StringParameter(this, "entryNameCol"); 129 entryNameCol.setToken("name"); 130 131 stdColName = new StringParameter(this, "stdColName"); 132 stdColName.setToken("std_vector"); 133 134 outFilepath = new StringParameter(this, "outFilepath"); 135 outFilepath.setToken("clusterCenters.json"); 136 } 137 138 139 140 /* 141 * initialize: 142 * instantiate JsonObject to store cluster centers 143 */ 144 @Override 145 public void initialize() throws IllegalActionException { 146 super.initialize(); 147 // initialize _fileObj 148 _fileObj = new JsonObject(); 149 } 150 151 152 /* 153 * fire: 154 * process the incoming dataframe 155 */ 156 @Override 157 public void fire() throws IllegalActionException { 158 159 super.fire(); 160 161 // get standardized dataframe 162 final Dataset<Row> df = 163 (Dataset<Row>) ((ObjectToken)data.get(0)).getValue(); 164 165 // get entryId for saving current mean/std stats to file 166 // NOTE: disallowing pure whitespace names 167 String entryId = ""; 168 if(entryName.getWidth() > 0) 169 entryId = ((StringToken)entryName.get(0)).stringValue().trim(); 170 else { 171 String col = entryNameCol.stringValue().trim(); 172 if(col == "") 173 System.out.println("No entryName/Col, not saving to file!"); 174 else try { 175 Row fr = df.first(); 176 entryId = (String) fr.getString(fr.fieldIndex(col)); 177 } catch(IllegalArgumentException e) { 178 throw new IllegalActionException("Can't find col name!"); 179 } catch(UnsupportedOperationException e) { 180 throw new IllegalActionException("No schema in input df!"); 181 } catch(ClassCastException e) { 182 throw new IllegalActionException("Col data not strings!"); 183 } 184 } 185 // TODO: catch empty entryId 186 entryId = entryId.replaceAll("\\s+", ""); // remove all whitespaces 187 188 189 // update all actor parameters 190 iterations.update(); 191 final int numIterations = ((IntToken)iterations.getToken()).intValue(); 192 numClusters.update(); 193 final int numClustersVal=((IntToken)numClusters.getToken()).intValue(); 194 seed.update(); 195 final int seedVal = ((IntToken)seed.getToken()).intValue(); 196 final String initMode = initializationMode.stringValue(); 197 initSteps.update(); 198 final int numInitSteps = ((IntToken)initSteps.getToken()).intValue(); 199 200 201 // set up k-means 202 final KMeans kmeans = new KMeans(); 203 kmeans.setMaxIter(numIterations); 204 kmeans.setK(numClustersVal); 205 kmeans.setSeed(seedVal); 206 kmeans.setInitMode(initMode); 207 kmeans.setInitSteps(numInitSteps); 208 // NOTE: KMeansClustering depends on the standard vector column 209 kmeans.setFeaturesCol(stdColName.stringValue().trim()); 210 211 KMeansModel model = kmeans.fit(df.cache()); 212 double WSSSE = model.computeCost(df); 213 Vector[] cCenters = model.clusterCenters(); 214 215 KMeansSummary summary = model.summary(); 216 long[] clusterSizesArray = summary.clusterSizes(); 217 218 219 // build JsonArray of JsonArray to represent array of Vector 220 JsonArray clusterArray = new JsonArray(); 221 for (int i = 0; i < cCenters.length; i++) { 222 JsonArray arr = 223 new JsonArray(Arrays.toString(cCenters[i].toArray())); 224 clusterArray.add(arr); 225 } 226 227 // output results 228 if(entryId != "") { // if entryId present, save centers to file 229 // save JSON object for current entry 230 JsonObject entryObj = new JsonObject(); 231 entryObj.put("clusterCenters", clusterArray); 232 _fileObj.put(entryId, entryObj); 233 234 // broadcast station name and centers array for SAIdentify 235 // NOTE: this depends on entryId. 236 entryObj.put("name", entryId); 237 centers.broadcast(new StringToken(entryObj.encode())); 238 } 239 else { // else broadcast cluster centers without entry ID 240 centers.broadcast(new StringToken(clusterArray.encode())); 241 } 242 243 // broadcast WSSSE 244 if(error.numberOfSinks() > 0) 245 error.broadcast(new DoubleToken(WSSSE)); 246 247 // broadcast cluster sizes array 248 if(clusterSizes.numberOfSinks() > 0) { 249 Token[] cSizesTokenArray = new Token[clusterSizesArray.length]; 250 for(int i = 0; i < clusterSizesArray.length; i++) 251 cSizesTokenArray[i] = new LongToken(clusterSizesArray[i]); 252 clusterSizes.broadcast(new ArrayToken(cSizesTokenArray)); 253 } 254 255 // DEBUG outputs 256 if(_debugging) { 257 //System.err.println(clusterArray.encode()); 258 System.out.println("\nDEBUG: WSSSE:" + WSSSE); 259 260 System.out.println("----- Cluster Sizes: -----"); 261 for (long size: clusterSizesArray) { 262 System.out.print(size + " "); 263 } 264 System.out.println(); 265 266 System.out.println("---- Cluster Centers: -----"); 267 for (Vector centroid: cCenters) { 268 System.out.print(centroid + " "); 269 } 270 System.out.println(); 271 } 272 } 273 274 275 /* 276 * wrap-up: 277 * write JsonObject containing all cluster centers to file 278 */ 279 @Override 280 public void wrapup() throws IllegalActionException { 281 282 super.wrapup(); 283 284 PrintWriter writer; 285 286 // try to open existing file and overwrite with new data. 287 // if no existing file found, create new file and write to it. 288 try{ 289 // try to open existing file 290 FileInputStream is = 291 new FileInputStream(outFilepath.getValueAsString()); 292 BufferedReader reader = 293 new BufferedReader(new InputStreamReader(is)); 294 295 // read the file as a single string, for JsonObject(String) 296 StringBuilder sb = new StringBuilder(); 297 String line = reader.readLine(); 298 while (line != null) { 299 sb.append(line); 300 line = reader.readLine(); 301 } 302 reader.close(); // close file for writing 303 304 // file read, build new JSON object for original data 305 JsonObject origFileObj = new JsonObject(sb.toString()); 306 // get list of names for stations with new data 307 Set<String> stationList = _fileObj.fieldNames(); 308 309 // replace old data in origFileObj with new data 310 for (String station : stationList) { 311 origFileObj.remove(station); // ignored if station DNE 312 origFileObj.put(station, _fileObj.getJsonObject(station)); 313 } 314 315 // write updated data to file 316 try{ 317 // open file 318 writer = new PrintWriter(outFilepath.getValueAsString()); 319 320 // write _fileObj to file 321 writer.println(origFileObj.encodePrettily()); // updated data 322 323 // close the writer 324 writer.flush(); 325 writer.close(); 326 }catch(Exception ei1){System.err.println("Failed to open file!");} 327 328 } catch (Exception fe) { // no original file, create a new one 329 330 System.out.println("No original data file found."); 331 332 try{ 333 // open file 334 writer = new PrintWriter(outFilepath.getValueAsString()); 335 336 // write _fileObj to file 337 writer.println(_fileObj.encodePrettily()); 338 339 // close the writer 340 writer.flush(); 341 writer.close(); 342 }catch(Exception ei2){System.err.println("Failed to open file!");} 343 } 344 345 } 346 347 348 349 /** The input vectors. */ 350 public TypedIOPort data; 351 352 /** The number of clusters. */ 353 public PortParameter numClusters; 354 355 /** The number of runs of the algorithm to execute in parallel. */ 356 //public PortParameter numRuns; 357 358 /** The maximum number of iterations to run. */ 359 public PortParameter iterations; 360 361 /** The initialization mode for KMeans */ 362 public StringParameter initializationMode; 363 364 /** The number of initialization steps for kmeans|| */ 365 public PortParameter initSteps; 366 367 /** The random seed value to use for cluster initialization . */ 368 public PortParameter seed; 369 370 /** The sum of squared distances to their nearest center. */ 371 public TypedIOPort error; 372 373 /** The center of the clusters. */ 374 public TypedIOPort centers; 375 376 /** The size of each cluster. */ 377 public TypedIOPort clusterSizes; 378 379 /** Name, or ID, for the dataframe being processed */ 380 public TypedIOPort entryName; 381 382 /** Name of the entry name column */ 383 public StringParameter entryNameCol; 384 385 /** Name of the standardized vector column in the input dataframe */ 386 public StringParameter stdColName; 387 388 /** Filepath for JSON output **/ 389 public StringParameter outFilepath; 390 391 /* ========== */ 392 393 /** _fileObj to write to file **/ 394 private JsonObject _fileObj; 395 396 /** The MLlib KMeans model. */ 397 private KMeansModel _model; 398 399}