/*
 * Copyright (c) 2016-2017 The Regents of the University of California.
 * All rights reserved.
 *
 * '$Author: crawl $'
 * '$Date: 2018-02-07 18:53:35 +0000 (Wed, 07 Feb 2018) $'
 * '$Revision: 34661 $'
 *
 * Permission is hereby granted, without written agreement and without
 * license or royalty fees, to use, copy, modify, and distribute this
 * software and its documentation for any purpose, provided that the above
 * copyright notice and the following two paragraphs appear in all copies
 * of this software.
 *
 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
 * ENHANCEMENTS, OR MODIFICATIONS.
 *
 */

package org.kepler.spark.mllib;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.SQLDataTypes;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.StructType;
import org.kepler.spark.actor.SparkSQLActor;

import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import ptolemy.actor.TypedIOPort;
import ptolemy.data.ObjectToken;
import ptolemy.data.StringToken;
import ptolemy.data.expr.StringParameter;
import ptolemy.data.type.BaseType;
import ptolemy.data.type.ObjectType;
import ptolemy.kernel.CompositeEntity;
import ptolemy.kernel.util.IllegalActionException;
import ptolemy.kernel.util.NameDuplicationException;
import ptolemy.kernel.util.SingletonAttribute;

/**
 * @author Dylan Uys, Jiaxin Li
 *
 * Takes a column of standardized vectors from the input dataframe, and
 * calculates the kmeans distance space vector (i.e. the distance from the
 * current row's measurement to all cluster centers in the model file).
 * Outputs results by appending a vector to the dataframe.
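 *
 * As a rough illustration (the station name and values below are
 * hypothetical), the model file read in initialize() is expected to map each
 * entry name to an object holding a "clusterCenters" array of numeric
 * arrays, e.g.:
 *   { "MW": { "clusterCenters": [[0.12, -0.34], [1.05, 0.76]] } }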
 */
public class KMeansApply extends SparkSQLActor {

    public KMeansApply(CompositeEntity container, String name)
            throws IllegalActionException, NameDuplicationException {

        super(container, name);

        inData = new TypedIOPort(this, "data", true, false);
        inData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(inData, "_showName");

        outData = new TypedIOPort(this, "outData", false, true);
        outData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(outData, "_showName");

        inFilepath = new StringParameter(this, "inFilepath");
        inFilepath.setToken("clusterCenters.json");

        entryNameCol = new StringParameter(this, "entryNameCol");
        entryNameCol.setToken("name");

        stdCol = new StringParameter(this, "stdCol");
        stdCol.setToken("std_vector");

        kmeansCol = new StringParameter(this, "kmeansCol");
        kmeansCol.setToken("kmeans_dist_vec");

        debugOutput = new TypedIOPort(this, "debugOutput", false, true);
        debugOutput.setTypeEquals(BaseType.STRING);
        new SingletonAttribute(debugOutput, "_showName");

    }

    @Override
    public void initialize() throws IllegalActionException {

        super.initialize();

        // Read entire JSON file into one string, to pass to JsonObject ctor
        StringBuilder sb = new StringBuilder(512);
        try {
            InputStream is = new FileInputStream(inFilepath.getValueAsString());
            Reader isr = new InputStreamReader(is, "UTF-8");
            try {
                int c;
                while ((c = isr.read()) != -1) {
                    sb.append((char) c);
                }
            } finally {
                isr.close();
            }
        } catch (FileNotFoundException e) {
            throw new IllegalActionException(this, e, "ERROR: File not found: "
                    + inFilepath.getValueAsString()); // workflow exception
        } catch (UnsupportedEncodingException e) {
            throw new IllegalActionException(this, e,
                    "ERROR: Problem with file's encoding");
        } catch (IOException e) {
            throw new IllegalActionException(this, e,
                    "ERROR: Could not read file: "
                    + inFilepath.getValueAsString());
        }

        // Create a JsonObject from the string, then get all the station keys
        JsonObject jsonObj = new JsonObject(sb.toString());
        Set<String> stations = jsonObj.fieldNames();

        // This will be the broadcasted object. The keys are String station
        // codes, the values are double[][] of cluster centers.
        Map<String, double[][]> centersByStation =
                new HashMap<String, double[][]>();

        // Iterate through the station codes present in the JsonObject's keys,
        // adding each respective array of cluster centers to the Map
        // instantiated above
        for (String stationCode : stations) {

            // stationObj format: {"clusterCenters": [[...], [...], ...]}
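            // Each entry becomes a double[numClusters][vectorDim]; the vector
            // dimension is read from the first center, so every center for a
            // station is assumed to have the same length.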
            JsonObject stationObj = jsonObj.getJsonObject(stationCode);
            JsonArray centers = stationObj.getJsonArray("clusterCenters");

            double[][] centersList =
                    new double[centers.size()][centers.getJsonArray(0).size()];
            for (int i = 0; i < centers.size(); i++) {
                JsonArray currCenter = centers.getJsonArray(i);

                for (int j = 0; j < currCenter.size(); j++) {
                    centersList[i][j] = currCenter.getDouble(j);
                }
            }

            centersByStation.put(stationCode, centersList);
        }

        // broadcast
        _bcastCenters = _context.broadcast(centersByStation);

    }

    @Override
    public void fire() throws IllegalActionException {

        super.fire();

        Dataset<Row> inDf =
                (Dataset<Row>) ((ObjectToken) inData.get(0)).getValue();

        //DEBUG
        /*
        if(_debugging)
            inDf.printSchema();
        */
        //inDf.show();

        // Encoder to serialize the row output of the map call
        Encoder<Row> rowEncoder = Encoders.javaSerialization(Row.class);

        // Get broadcast var to pass to ClusterClassifier
        // TODO: do we need to have this in fire()?
        Map<String, double[][]> stationCenters = _bcastCenters.value();

        // Classifier
        ClusterClassifier classifier =
                new ClusterClassifier(stationCenters,
                                      entryNameCol.stringValue().trim(),
                                      stdCol.stringValue().trim());

        // Calculate the Euclidean distance to every cluster center for each
        // row in the Dataset
        // NOTE: If exceptions are encountered, no tokens are output
        try {
            // use map function to compute the distance-space vector per row
            Dataset<Row> outDf = inDf.map(classifier, rowEncoder);

            // add schema column for kmeans distance-space vector
            StructType outSchema =
                    inDf.schema().add(kmeansCol.stringValue(),
                                      SQLDataTypes.VectorType());
            outDf = _sqlContext.createDataFrame(outDf.toJavaRDD(), outSchema);

            // output results
            outData.broadcast(new ObjectToken(outDf, Dataset.class));

            //DEBUG
            /*
            if(_debugging)
                outDf.printSchema();
            */
            if (debugOutput.numberOfSinks() > 0)
                debugOutput.broadcast(new StringToken(outDf.first().toString()));

        } catch (Exception e) {
            System.err.println("ERROR: KMeansApply failed! " + e.toString());
        }
    }


    /*
     * Private helper class for computing k-means distance-space vectors
     * TODO: change class name
     */
    private static class ClusterClassifier implements MapFunction<Row, Row> {

        private Map<String, double[][]> clusterCenters;
        private String idCol, stdCol;

        public ClusterClassifier(Map<String, double[][]> centers,
                String id, String std) {
            clusterCenters = centers;
            idCol = id;
            stdCol = std;
        }

        @Override
        public Row call(Row datum) throws Exception {

            EuclideanDistance ed = new EuclideanDistance();

            // first, get array of cluster centers from the map
            String name;
            try {
                // get entry name
                name = datum.getString(datum.fieldIndex(idCol));
                name = name.replaceAll("\\s+", ""); // clean up name string
            } catch (IllegalArgumentException e) {
                System.err.println("ERROR: Can't find entry column: " + idCol);
                throw e;
            }

            // get trained cluster data for current station
            int numClusters, vectorDim;
            double[][] centers;
            try {
                centers = clusterCenters.get(name);
                // get number of clusters per entry, and cluster dimension
                numClusters = centers.length;
                vectorDim = centers[0].length;
            } catch (NullPointerException e) {
                System.err.println("ERROR: No cluster center info for " +
                                   "station: " + name);
                throw e;
            }

            // get standardized vector as double[] from input row
            DenseVector stdVector;
            try {
                stdVector = (DenseVector) datum.get(datum.fieldIndex(stdCol));
            } catch (IllegalArgumentException e) {
                System.err.println("ERROR: Can't find stdCol: " + stdCol);
                throw e;
            }
            double[] d = stdVector.toArray();

            // calculate ED to all centers, assemble distance-space vector
            double[] distArray = new double[numClusters];
            for (int i = 0; i < numClusters; i++) {
                double[] c = centers[i];
                distArray[i] = ed.compute(c, d);
            }
            Vector distVector = Vectors.dense(distArray);

            // Assemble the output row by appending the dist vector field
            Object[] outFields = new Object[datum.size() + 1];
            for (int i = 0; i < datum.size(); i++)
                outFields[i] = datum.get(i);
            outFields[datum.size()] = distVector; // dist vector

            return RowFactory.create(outFields);
        }
    }

    /** Dataframe input */
    public TypedIOPort inData;

    /** Dataframe output w/ kmeans distance-space vector */
    public TypedIOPort outData;

    /** Path to model file */
    public StringParameter inFilepath;

    /** Name of entry name column in input dataframe */
    public StringParameter entryNameCol;

    /** Name of standard vector column */
    public StringParameter stdCol;

    /** Name of kmeans vector column (to append to dataframe) */
    public StringParameter kmeansCol;

    /** Debug output port */
    public TypedIOPort debugOutput;


    /* broadcasted cluster centers map */
    private Broadcast<Map<String, double[][]>> _bcastCenters;

}