/*
 * Copyright (c) 2016-2017 The Regents of the University of California.
 * All rights reserved.
 *
 * '$Author: crawl $'
 * '$Date: 2018-02-07 18:53:35 +0000 (Wed, 07 Feb 2018) $'
 * '$Revision: 34661 $'
 *
 * Permission is hereby granted, without written agreement and without
 * license or royalty fees, to use, copy, modify, and distribute this
 * software and its documentation for any purpose, provided that the above
 * copyright notice and the following two paragraphs appear in all copies
 * of this software.
 *
 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
 * ENHANCEMENTS, OR MODIFICATIONS.
 *
 */

package org.kepler.spark.mllib;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.SQLDataTypes;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.StructType;
import org.kepler.spark.actor.SparkSQLActor;

import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import ptolemy.actor.TypedIOPort;
import ptolemy.data.ObjectToken;
import ptolemy.data.StringToken;
import ptolemy.data.expr.StringParameter;
import ptolemy.data.type.BaseType;
import ptolemy.data.type.ObjectType;
import ptolemy.kernel.CompositeEntity;
import ptolemy.kernel.util.IllegalActionException;
import ptolemy.kernel.util.NameDuplicationException;
import ptolemy.kernel.util.SingletonAttribute;

/**
 * @author Dylan Uys, Jiaxin Li
 *
 * Standardizes the values in the incoming DataFrame by subtracting the mean
 * and dividing by the standard deviation.
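 *
 * The per-station means and standard deviations are loaded from a JSON
 * model file in initialize(). A hypothetical example of the expected layout
 * (keys are station codes; values carry the "mean" and "std" arrays parsed
 * below; the numbers here are made up for illustration):
 *
 * <pre>
 * {
 *   "MW": { "mean": [12.1, 0.4], "std": [3.2, 0.1] },
 *   "SC": { "mean": [11.7, 0.6], "std": [2.9, 0.2] }
 * }
 * </pre>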
 */
public class StandardizeApply extends SparkSQLActor {

    public StandardizeApply(CompositeEntity container, String name)
        throws IllegalActionException, NameDuplicationException {
        super(container, name);

        inData = new TypedIOPort(this, "inData", true, false);
        inData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(inData, "_showName");

        outData = new TypedIOPort(this, "outData", false, true);
        outData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(outData, "_showName");

        test = new TypedIOPort(this, "test", false, true);
        test.setTypeEquals(BaseType.STRING);
        new SingletonAttribute(test, "_showName");

        inFilepath = new StringParameter(this, "inFilepath");
        inFilepath.setToken("meanstd.json");

        entryNameCol = new StringParameter(this, "entryNameCol");
        entryNameCol.setToken("name");

        stdCol = new StringParameter(this, "stdCol");
        stdCol.setToken("std_vector");

        inCol = new StringParameter(this, "inCol");
        inCol.setToken("in_vector");
    }

    @Override
    public void initialize() throws IllegalActionException {

        super.initialize();

        // read the entire model file as a single string, for JsonObject;
        // try-with-resources closes the reader even if a read fails
        StringBuilder sb = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                new FileInputStream(inFilepath.stringValue())))) {
            String line = reader.readLine();
            while (line != null) {
                sb.append(line);
                line = reader.readLine();
            }
        } catch (Exception e) {
            throw new IllegalActionException(this, e,
                "ERROR: Cannot read model file: " + inFilepath.stringValue());
        }

        // Create json object from file string, extract station specific data
        JsonObject jsonObj = new JsonObject(sb.toString());
        Set<String> stations = jsonObj.fieldNames();

        // These will be the broadcasted objects. The keys are String station
        // codes, the values are List<Double>s of means/stddevs
        Map<String, List<Double>> meansByStation =
            new HashMap<String, List<Double>>();
        Map<String, List<Double>> stdsByStation =
            new HashMap<String, List<Double>>();

        // Iterate through the station codes present in the JsonObject's keys,
        // adding each station's arrays of means/stddevs to the maps
        // instantiated above
        for (String stationCode : stations) {

            // stationObj format: {"mean": [...], "std": [...]}
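            // NOTE: the getList() calls below return the JsonArray's backing
            // list unchecked; this assumes the model file writes mean/std
            // values as JSON decimals so the entries deserialize as Double.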
            JsonObject stationObj = jsonObj.getJsonObject(stationCode);
            JsonArray means = stationObj.getJsonArray("mean");
            JsonArray stds = stationObj.getJsonArray("std");

            List<Double> meansList = means.getList();
            List<Double> stdsList = stds.getList();

            stdsByStation.put(stationCode, stdsList);
            meansByStation.put(stationCode, meansList);
        }

        // Broadcast the lookup maps so each Spark executor gets a copy
        _bcastMean = _context.broadcast(meansByStation);
        _bcastStddev = _context.broadcast(stdsByStation);
    }


    @Override
    public void fire() throws IllegalActionException {

        super.fire();

        // get name column and input column
        String inColName = inCol.stringValue().trim();
        String nameCol = entryNameCol.stringValue().trim();

        // Read data
        Dataset<Row> inDf =
            (Dataset<Row>)((ObjectToken)inData.get(0)).getValue();
        Encoder<Row> rowEncoder = Encoders.javaSerialization(Row.class);

        // DEBUG
        /*
        if (_debugging) {
            System.out.println("Standardize Before:");
            inDf.printSchema();
        }
        */

        Map<String, List<Double>> mean = _bcastMean.value();
        Map<String, List<Double>> std = _bcastStddev.value();

        // if Standardizer throws an exception, don't output any token
        try {

            Standardizer stder = new Standardizer(mean, std, inColName, nameCol);
            Dataset<Row> outDf = inDf.map(stder, rowEncoder);

            // Adding schema back to dataframe, after map function
            // NOTE: new schema contains one more column, std_vector
            StructType outSchema = inDf.schema();
            outSchema = outSchema.add(stdCol.stringValue(),
                                      SQLDataTypes.VectorType());
            outDf = _sqlContext.createDataFrame(outDf.toJavaRDD(),
                                                outSchema);

            // DEBUG
            /*
            if (_debugging) {
                System.out.println("Standardize After:");
                outDf.printSchema();
            }
            */

            // Output standardized dataframe
            outData.broadcast(new ObjectToken(outDf, Dataset.class));
            if (test.numberOfSinks() > 0) {
                test.broadcast(new StringToken(outDf.first().toString()));
            }

        } catch (Exception e) {
            System.err.println("ERROR: Standardize Apply failed! " +
                               e.toString());
        }

    }

    /*
     * Private helper class for standardizing rows using map()
     */
    private static class Standardizer implements MapFunction<Row, Row> {

        private Map<String, List<Double>> meanList;
        private Map<String, List<Double>> stdList;
        private String inCol;
        private String idCol;

        public Standardizer(Map<String, List<Double>> mean,
                            Map<String, List<Double>> std, String inc,
                            String idc) {
            meanList = mean;
            stdList = std;
            inCol = inc;
            idCol = idc;
        }

        @Override
        public Row call(Row in) throws Exception {
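
            // call() consumes one input Row and returns a copy with one
            // extra trailing field: the standardized version of the row's
            // input vector, using the mean/std entries for this row's ID.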

            // get entry ID (e.g. station ID) from input row
            String name = in.getString(in.fieldIndex(idCol));
            name = name.replaceAll("\\s+", ""); // clean up name string

            // TODO: check for dimension mismatch

            // get all data fields from the row, then standardize
            // NOTE: requires identical vector order b/w model and input vector
            try {
                // get input (raw) vector as double[] from input row
                DenseVector inVector;
                try {
                    inVector = (DenseVector)in.get(in.fieldIndex(inCol));
                } catch (IllegalArgumentException e) {
                    System.err.println("ERROR: Can't find inCol: " + inCol);
                    throw e;
                }
                double[] inVec = inVector.toArray();

                // create output vector and standardize from input vector:
                // out[i] = (in[i] - mean[i]) / std[i]
                double[] out = new double[inVector.size()];
                for (int i = 0; i < inVector.size(); i++) {
                    double m = meanList.get(name).get(i);
                    double s = stdList.get(name).get(i);
                    out[i] = (inVec[i] - m) / s;
                }

                // assemble the standardized feature vector
                Vector stdVector = Vectors.dense(out);

                // Assemble output row
                Object[] outFields = new Object[in.size() + 1];
                for (int i = 0; i < in.size(); i++) // copy existing fields
                    outFields[i] = in.get(i);
                outFields[in.size()] = stdVector; // append vector field

                return RowFactory.create(outFields);

            } catch (NullPointerException e) {
                // catching the case where the station name from the websocket
                // can't be found in the training data
                System.err.println("ERROR: Cannot find mean/std info: " + name);
                throw e;

            } catch (Exception e) {
                System.err.println("ERROR: Standardizer failed! " + e.toString());
                throw e;
            }

        }
    }

    /** Dataframe input */
    public TypedIOPort inData;

    /** Dataframe output */
    public TypedIOPort outData; // Standardized data in Dataset

    /** First row of input dataframe, for testing */
    public TypedIOPort test; // Outputs the first row of new dataset

    /** Path to model file */
    public StringParameter inFilepath;

    /** Name of the input vector column, which contains raw data */
    public StringParameter inCol;

    /** Name of entry name column */
    public StringParameter entryNameCol;

    /** Name of standardized vector column to append */
    public StringParameter stdCol;


    private Broadcast<Map<String, List<Double>>> _bcastMean;
    private Broadcast<Map<String, List<Double>>> _bcastStddev;

}