/*
 * Copyright (c) 2016-2017 The Regents of the University of California.
 * All rights reserved.
 *
 * '$Author: crawl $'
 * '$Date: 2018-02-07 18:53:35 +0000 (Wed, 07 Feb 2018) $'
 * '$Revision: 34661 $'
 *
 * Permission is hereby granted, without written agreement and without
 * license or royalty fees, to use, copy, modify, and distribute this
 * software and its documentation for any purpose, provided that the above
 * copyright notice and the following two paragraphs appear in all copies
 * of this software.
 *
 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
 * ENHANCEMENTS, OR MODIFICATIONS.
 *
 */

package org.kepler.spark.mllib;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.kepler.spark.actor.SparkSQLActor;

import io.vertx.core.json.JsonObject;
import ptolemy.actor.TypedIOPort;
import ptolemy.data.DoubleToken;
import ptolemy.data.IntToken;
import ptolemy.data.ObjectToken;
import ptolemy.data.expr.Parameter;
import ptolemy.data.expr.StringParameter;
import ptolemy.data.type.BaseType;
import ptolemy.data.type.ObjectType;
import ptolemy.kernel.CompositeEntity;
import ptolemy.kernel.util.IllegalActionException;
import ptolemy.kernel.util.NameDuplicationException;
import ptolemy.kernel.util.SingletonAttribute;


/**
 * @author Jiaxin Li
 *
 * An actor written specifically for the WIFIRE project. For each input
 * measurement, the actor finds the cluster center that the measurement
 * most closely matches, based on pre-trained k-means models, and decides
 * whether the measurement reflects a Santa Ana condition.
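 *
 * The SA cluster ID file named by saFilePath is expected to be a JSON
 * object keyed by station name, where each entry holds the station's
 * Santa Ana cluster index in a "saCluster" field (the field name comes
 * from the parsing code in initialize()). An illustrative example, with
 * hypothetical station names:
 *
 * <pre>
 * {
 *   "stationA": { "saCluster": 2 },
 *   "stationB": { "saCluster": 0 }
 * }
 * </pre>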
 */
public class SantaAnaDetect extends SparkSQLActor {

    public SantaAnaDetect(CompositeEntity container, String name)
        throws IllegalActionException, NameDuplicationException {

        super(container, name);

        inData = new TypedIOPort(this, "inData", true, false);
        inData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(inData, "_showName");

        saWindowThrs = new Parameter(this, "saWindowThrs");
        saWindowThrs.setTypeEquals(BaseType.DOUBLE);
        saWindowThrs.setToken(new DoubleToken(0.9));

        saWindowWidth = new Parameter(this, "saWindowWidth");
        saWindowWidth.setTypeEquals(BaseType.INT);
        saWindowWidth.setToken(new IntToken(10));

        // path to the SA cluster ID file
        saFilePath = new StringParameter(this, "saFilePath");
        saFilePath.setToken("saClusters.json");

        entryNameCol = new StringParameter(this, "entryNameCol");
        entryNameCol.setToken("name");

        kmeansCol = new StringParameter(this, "kmeansCol");
        kmeansCol.setToken("kmeans_dist_vec");

        clusterIdCol = new StringParameter(this, "clusterIdCol");
        clusterIdCol.setToken("clusterId");

        isSACol = new StringParameter(this, "isSACol");
        isSACol.setToken("isSA");

        // output dataframe with the classification columns appended
        outData = new TypedIOPort(this, "outData", false, true);
        outData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(outData, "_showName");
    }

    /** Read the SA cluster ID file, broadcast the station-to-cluster map,
     *  and construct the classifier.
     */
    @Override
    public void initialize() throws IllegalActionException {

        super.initialize();

        // initialize _fileObj
        _fileObj = new JsonObject();

        // try to open and parse the SA cluster file; try-with-resources
        // ensures the reader is closed even if parsing fails
        try (BufferedReader reader = new BufferedReader(
                 new InputStreamReader(
                     new FileInputStream(saFilePath.getValueAsString())))) {

            // read the file as a single string, for JsonObject(String)
            StringBuilder sb = new StringBuilder();
            String line = reader.readLine();
            while (line != null) {
                sb.append(line);
                line = reader.readLine();
            }

            // file read; build a new JSON object
            _fileObj = new JsonObject(sb.toString());

            // get the list of station names
            Set<String> stationNames = _fileObj.fieldNames();

            // build a new hashmap of [stationName, clusterId] pairs,
            // then broadcast it to Spark
            Map<String, Integer> saCenters = new HashMap<String, Integer>();
            for (String stationName : stationNames) {
                saCenters.put(stationName,
                              _fileObj.getJsonObject(stationName)
                                      .getInteger("saCluster"));
            }
            _b_saCenters = _context.broadcast(saCenters);

            // get the detection window parameters
            double sawp = ((DoubleToken)saWindowThrs.getToken()).doubleValue();
            int saww = ((IntToken)saWindowWidth.getToken()).intValue();

            // NOTE: the classifier is created here so that its counters
            // persist across firings
            _classifier = new SA_Classifier(_b_saCenters.value(), sawp, saww,
                                            entryNameCol.stringValue(),
                                            kmeansCol.stringValue());

        } catch (Exception fe) {
            // rethrow as a workflow exception, keeping the original cause
            throw new IllegalActionException(this, fe,
                "Cannot open SA cluster ID file: "
                + saFilePath.getValueAsString());
        }
    }
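
    /** Classify each input row and broadcast the input dataframe with
     *  three columns appended: the best-matching cluster id
     *  (clusterIdCol), the boolean Santa Ana flag (isSACol), and the
     *  distance to the station's Santa Ana cluster center ("cdist";
     *  -1.0 when no valid center index is available).
     */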

    @Override
    public void fire() throws IllegalActionException {

        super.fire();

        // get the input Dataset<Row> of cluster distances with station names
        Dataset<Row> inDf =
            (Dataset<Row>)((ObjectToken)inData.get(0)).getValue();

        // Encoder to serialize the row output of the map call
        Encoder<Row> rowEncoder = Encoders.javaSerialization(Row.class);

        // call the _classifier function on the input
        // NOTE: if _classifier throws an exception, no output token is sent
        try {
            // get the classification dataframe via SA_Classifier
            Dataset<Row> outDf = inDf.map(_classifier, rowEncoder);

            StructType outSchema = inDf.schema();
            outSchema = outSchema.add(clusterIdCol.stringValue().trim(),
                                      DataTypes.IntegerType); // clusterId
            outSchema = outSchema.add(isSACol.stringValue().trim(),
                                      DataTypes.BooleanType); // isSA
            // TODO: parameterize the column name below
            outSchema = outSchema.add("cdist", DataTypes.DoubleType); // cdist
            outDf = _sqlContext.createDataFrame(outDf.toJavaRDD(), outSchema);

            /*
            if(_debugging)
                outDf.printSchema(); //DEBUG
            */

            outData.broadcast(new ObjectToken(outDf, Dataset.class));

        } catch (Exception e) {
            System.err.println("ERROR: SA_Classifier failed! " + e.toString());
        }
    }


    // private class implementing the classifier map function
    private static class SA_Classifier implements MapFunction<Row, Row> {

        // TODO: change to static vars by using serializable classes?
        private Map<String, Integer> saCenters;
        private int saCounter;     // count of SA data points in the window
        private boolean[] window;  // circular buffer of per-point SA flags
        private double sawp;       // SA window percentage threshold
        private int saww;          // SA window width
        private int windowIndex;   // current position in the window
        private String idCol, kmeansCol;

        public SA_Classifier(Map<String, Integer> centers, double sawp,
                             int saww, String idCol, String kmeansCol) {
            this.saCenters = centers;
            this.sawp = sawp;
            this.saww = saww;
            this.idCol = idCol;
            this.kmeansCol = kmeansCol;
            saCounter = 0;
            window = new boolean[saww];
            windowIndex = 0;
        }

        public void updateWindowParam(double sawp, int saww) {
            this.sawp = sawp;
            this.saww = saww;
        }

        @Override
        public Row call(Row datum) throws Exception {

            // get the entry (station) name
            String name;
            try {
                name = datum.getString(datum.fieldIndex(idCol));
                name = name.replaceAll("\\s+", ""); // clean up the name string
            } catch (IllegalArgumentException e) {
                System.err.println("ERROR: Can't find entry column: " + idCol);
                throw e;
            }

            // get the Santa Ana cluster ID for the current row's station
            int saCenterId;
            try {
                saCenterId = saCenters.get(name).intValue();
            } catch (NullPointerException e) {
                System.err.println("ERROR: No SA center info for station: "
                                   + name);
                throw e;
            }

            // parse the row's kmeans distance vector to find the closest
            // cluster (minimum distance)
            double[] distArray;
            try {
                distArray = ((DenseVector)datum
                             .get(datum.fieldIndex(kmeansCol))).toArray();
            } catch (IllegalArgumentException e) {
                System.err.println("ERROR: Can't find kmeansCol: "
                                   + kmeansCol);
                throw e;
            }
            int rowClusterId = -1; // cluster with min distance in this row
            double minDist = Double.MAX_VALUE;
            for (int i = 0; i < distArray.length; i++) {
                if (distArray[i] < minDist) {
                    rowClusterId = i;
                    minDist = distArray[i];
                }
            }
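
            /*
             * Sliding-window vote: record whether this point falls in the
             * station's Santa Ana cluster, then advance the circular index.
             * With the default parameters (saWindowWidth = 10,
             * saWindowThrs = 0.9), isSA becomes true only when at least
             * (int)(10 * 0.9) = 9 of the last 10 points were in the SA
             * cluster and the current point itself is one of them.
             */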

            // compare the row's closest cluster with the trained SA center
            /*
            System.out.printf("DEBUG: saCenterId=%d, rowClusterId=%d\n",
                              saCenterId, rowClusterId);
            */
            window[windowIndex] = (saCenterId == rowClusterId);
            windowIndex = windowIndex >= saww - 1 ? 0 : windowIndex + 1;

            // count the SA data points in the current window
            saCounter = 0;
            for (boolean point : window)
                saCounter = point ? saCounter + 1 : saCounter;
            // gate on the number of points needed to identify an SA condition
            boolean isSA = saCounter >= (int)(saww * sawp);

            // catch case: the window criterion is met, but the current
            // point itself is not SA
            if (saCenterId != rowClusterId)
                isSA = false;

            // assemble the output row: append clusterId, isSA, and cdist
            // (guard against an out-of-range SA center index)
            Object[] outFields = new Object[datum.size() + 3];
            for (int i = 0; i < datum.size(); i++)
                outFields[i] = datum.get(i);
            outFields[datum.size()] = rowClusterId;
            outFields[datum.size() + 1] = isSA;
            if (saCenterId > -1 && saCenterId < distArray.length)
                outFields[datum.size() + 2] = distArray[saCenterId];
            else
                outFields[datum.size() + 2] = -1.0;

            return RowFactory.create(outFields);
        }
    }


    /** Dataframe input */
    public TypedIOPort inData;

    /** Dataframe output */
    public TypedIOPort outData;

    /** Fraction threshold for the sliding window filter */
    public Parameter saWindowThrs;

    /** Width of the sliding window filter */
    public Parameter saWindowWidth;

    /** Path to the SA cluster ID file */
    public StringParameter saFilePath;

    /** Name of the entry name column in the input dataframe */
    public StringParameter entryNameCol;

    /** Name of the kmeans vector column in the input dataframe */
    public StringParameter kmeansCol;

    /** Name of the cluster ID column to append */
    public StringParameter clusterIdCol;

    /** Name of the boolean 'is Santa Ana condition' column to append */
    public StringParameter isSACol;


    private JsonObject _fileObj;
    private Broadcast<Map<String, Integer>> _b_saCenters;
    private SA_Classifier _classifier;

}