/*
 * Copyright (c) 2016-2017 The Regents of the University of California.
 * All rights reserved.
 *
 * '$Author: crawl $'
 * '$Date: 2018-02-07 18:53:35 +0000 (Wed, 07 Feb 2018) $'
 * '$Revision: 34661 $'
 *
 * Permission is hereby granted, without written agreement and without
 * license or royalty fees, to use, copy, modify, and distribute this
 * software and its documentation for any purpose, provided that the above
 * copyright notice and the following two paragraphs appear in all copies
 * of this software.
 *
 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
 * ENHANCEMENTS, OR MODIFICATIONS.
 *
 */

package org.kepler.spark.mllib;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.kepler.spark.actor.SparkSQLActor;

import io.vertx.core.json.JsonObject;
import ptolemy.actor.TypedIOPort;
import ptolemy.data.DoubleToken;
import ptolemy.data.IntToken;
import ptolemy.data.ObjectToken;
import ptolemy.data.expr.Parameter;
import ptolemy.data.expr.StringParameter;
import ptolemy.data.type.BaseType;
import ptolemy.data.type.ObjectType;
import ptolemy.kernel.CompositeEntity;
import ptolemy.kernel.util.IllegalActionException;
import ptolemy.kernel.util.NameDuplicationException;
import ptolemy.kernel.util.SingletonAttribute;


/**
 * @author Jiaxin Li
 *
 * An actor written for the WIFIRE project. Based on pre-trained k-means
 * models, the actor finds the cluster center that each input measurement
 * matches most closely and decides, using a sliding-window vote over
 * recent data points, whether the measurement reflects a Santa Ana
 * condition.
 */
public class SantaAnaDetect extends SparkSQLActor {

    public SantaAnaDetect(CompositeEntity container, String name)
        throws IllegalActionException, NameDuplicationException {

        super(container, name);

        inData = new TypedIOPort(this, "inData", true, false);
        inData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(inData, "_showName");

        saWindowThrs = new Parameter(this, "saWindowThrs");
        saWindowThrs.setTypeEquals(BaseType.DOUBLE);
        saWindowThrs.setToken(new DoubleToken(0.9));

        saWindowWidth = new Parameter(this, "saWindowWidth");
        saWindowWidth.setTypeEquals(BaseType.INT);
        saWindowWidth.setToken(new IntToken(10));

        // path to the SA cluster ID file
        saFilePath = new StringParameter(this, "saFilePath");
        saFilePath.setToken("saClusters.json");

        entryNameCol = new StringParameter(this, "entryNameCol");
        entryNameCol.setToken("name");

        kmeansCol = new StringParameter(this, "kmeansCol");
        kmeansCol.setToken("kmeans_dist_vec");

        clusterIdCol = new StringParameter(this, "clusterIdCol");
        clusterIdCol.setToken("clusterId");

        isSACol = new StringParameter(this, "isSACol");
        isSACol.setToken("isSA");

        // output port for the annotated dataframe
        outData = new TypedIOPort(this, "outData", false, true);
        outData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(outData, "_showName");
    }

    /** Initialize: load the SA cluster ID file and broadcast the
     *  station-to-cluster map to Spark.
     */
    @Override
    public void initialize() throws IllegalActionException {

        super.initialize();

        // initialize _fileObj
        _fileObj = new JsonObject();

        // try to open the SA cluster file
        try {
            // open the file
            FileInputStream is =
                new FileInputStream(saFilePath.getValueAsString());
            BufferedReader reader =
                new BufferedReader(new InputStreamReader(is));

            // read the file into a single string for JsonObject(String)
            StringBuilder sb = new StringBuilder();
            String line = reader.readLine();
            while (line != null) {
                sb.append(line);
                line = reader.readLine();
            }
            reader.close();

            // file read; build a new JSON object
            _fileObj = new JsonObject(sb.toString());
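
            // Expected file layout (inferred from the lookups below, where
            // each top-level field name is a station name), e.g.:
            //   { "STATION_A": { "saCluster": 3 },
            //     "STATION_B": { "saCluster": 1 } }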

            // get the list of station names
            Set<String> stationNames = _fileObj.fieldNames();

            // build a hashmap of [stationName, clusterId] pairs,
            // then broadcast it to Spark
            Map<String,Integer> saCenters = new HashMap<String,Integer>();
            for(String stationName : stationNames) {
                saCenters.put(stationName,
                              _fileObj.getJsonObject(stationName)
                              .getInteger("saCluster"));
            }
            _b_saCenters = _context.broadcast(saCenters);
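            // NOTE: the broadcast above gives each Spark executor a
            // read-only copy of the map, so the classifier can look up
            // stations without shipping the map with every task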

            // get the detection window parameters
            double sawp = ((DoubleToken)saWindowThrs.getToken()).doubleValue();
            int saww = ((IntToken)saWindowWidth.getToken()).intValue();

            // NOTE: create _classifier here so that its counters persist
            //       across firings
            _classifier = new SA_Classifier(_b_saCenters.value(), sawp, saww,
                                            entryNameCol.stringValue(),
                                            kmeansCol.stringValue());

        } catch (Exception fe) {
            throw new IllegalActionException(this, fe,
                "Cannot open SA cluster ID file: " +
                saFilePath.getValueAsString());
        }
    }


    @Override
    public void fire() throws IllegalActionException {

        super.fire();

        // get the input Dataset<Row> of cluster IDs with station names
        Dataset<Row> inDf =
            (Dataset<Row>)((ObjectToken)inData.get(0)).getValue();

        // Encoder to serialize the row output of the map call
        Encoder<Row> rowEncoder = Encoders.javaSerialization(Row.class);
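        // NOTE: generic Rows carry no compile-time schema, so plain Java
        // serialization is used here; the proper schema is re-applied
        // below via createDataFrame()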

        // call the _classifier function on the input
        // NOTE: if _classifier throws an exception, do nothing
        try {
            // get the classification dataframe via SA_Classifier
            Dataset<Row> outDf = inDf.map(_classifier, rowEncoder);

            StructType outSchema = inDf.schema();
            outSchema = outSchema.add(clusterIdCol.stringValue().trim(),
                                      DataTypes.IntegerType); // clusterId
            outSchema = outSchema.add(isSACol.stringValue().trim(),
                                      DataTypes.BooleanType); // isSA
            // TODO: parameterize the column name below
            outSchema = outSchema.add("cdist", DataTypes.DoubleType); // cdist
            outDf = _sqlContext.createDataFrame(outDf.toJavaRDD(), outSchema);

            /*
            if(_debugging)
                outDf.printSchema(); //DEBUG
            */

            outData.broadcast(new ObjectToken(outDf, Dataset.class));

        } catch (Exception e) {
            System.err.println("ERROR: SA_Classifier failed! " + e.toString());
        }
    }


    // private class implementing the classifier map function
    private static class SA_Classifier implements MapFunction<Row,Row> {

        // TODO: change to static vars by using serializable classes?
        private Map<String,Integer> saCenters; // station name -> SA cluster ID
        private int saCounter;  // count of SA data points in the current window
        private boolean[] window;  // circular buffer of per-point SA flags
        private double sawp;  // SA window percentage (fraction threshold)
        private int saww, windowIndex;  // SA window width, window index
        private String idCol, kmeansCol;  // input column names

        public SA_Classifier(Map<String,Integer> centers, double sawp, int saww,
                             String idCol, String kmeansCol) {
            this.saCenters = centers;
            this.sawp = sawp;
            this.saww = saww;
            this.idCol = idCol;
            this.kmeansCol = kmeansCol;
            saCounter = 0;
            window = new boolean[saww];
            windowIndex = 0;
        }

        // update the window parameters; the window buffer is resized (and
        // cleared) so it stays consistent with the new width
        public void updateWindowParam(double sawp, int saww) {
            this.sawp = sawp;
            this.saww = saww;
            window = new boolean[saww];
            windowIndex = 0;
        }

        public Row call(Row datum) throws Exception {

            // get the entry name
            String name;
            try {
                name = datum.getString(datum.fieldIndex(idCol));
                name = name.replaceAll("\\s+", "");  // clean up the name string
            } catch(IllegalArgumentException e) {
                System.err.println("ERROR: Can't find entry column: " + idCol);
                throw e;
            }

            // get the Santa Ana center ID for the current row's station
            int saCenterId;
            try {
                saCenterId = saCenters.get(name).intValue();
            } catch(NullPointerException e) {
                System.err.println("ERROR: No SA center info for " +
                                   "station: " + name);
                throw e;
            }

            // parse the input row's kmeans distance vector to find the
            // cluster ID (minimum distance)
            double[] distArray;
            try {
                distArray = ((DenseVector)datum
                             .get(datum.fieldIndex(kmeansCol))).toArray();
            } catch(IllegalArgumentException e) {
                System.err.println("ERROR: Can't find kmeansCol: " + kmeansCol);
                throw e;
            }
            int rowClusterId = -1;  // cluster with min dist in the input row
            double minDist = Double.MAX_VALUE;
            for(int i = 0; i < distArray.length; i++) {
                if(distArray[i] < minDist) {
                    rowClusterId = i;
                    minDist = distArray[i];
                }
            }

            // compare the current cluster center index with the training data
            /*System.out.printf("DEBUG: saCenterId=%d, rowClusterId=%d\n",
              saCenterId, rowClusterId);*/
            window[windowIndex] = (saCenterId == rowClusterId);  // is SA cluster?
            // advance the circular window index, wrapping at the window width
            windowIndex = windowIndex >= saww-1 ? 0 : windowIndex+1;

            // count the SA data points in the current window
            saCounter = 0;
            for(boolean point : window)
                if(point) saCounter++;

            // gate on the number of SA points needed to identify an SA condition
            boolean isSA = saCounter >= (int)(saww * sawp);

            // edge case: window criterion met but the current point itself is not SA
            if (saCenterId != rowClusterId) isSA = false;
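
            // (with the defaults saww = 10 and sawp = 0.9, the gate above
            //  requires at least (int)(10 * 0.9) = 9 of the last 10 points
            //  to match the SA cluster)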

            // assemble the output row: append the clusterId, isSA, and cdist
            // fields (must stay in sync with the schema extensions in fire())
            Object[] outFields = new Object[datum.size()+3];
            for(int i = 0; i < datum.size(); i++)
                outFields[i] = datum.get(i);
            outFields[datum.size()] = rowClusterId;
            outFields[datum.size()+1] = isSA;
            if(saCenterId > -1)
                outFields[datum.size()+2] = distArray[saCenterId];
            else
                outFields[datum.size()+2] = -1.0;

            return RowFactory.create(outFields);
        }
    }


    /** Dataframe input */
    public TypedIOPort inData;

    /** Dataframe output */
    public TypedIOPort outData;

    /** Fraction threshold for the sliding window filter */
    public Parameter saWindowThrs;

    /** Width of the sliding window filter */
    public Parameter saWindowWidth;

    /** Path to the SA cluster ID file */
    public StringParameter saFilePath;

    /** Name of the entry name column in the input dataframe */
    public StringParameter entryNameCol;

    /** Name of the kmeans vector column in the input dataframe */
    public StringParameter kmeansCol;

    /** Name of the cluster ID column to append */
    public StringParameter clusterIdCol;

    /** Name of the boolean 'is Santa Ana condition' column to append */
    public StringParameter isSACol;


    private JsonObject _fileObj;
    private Broadcast<Map<String,Integer>> _b_saCenters;
    private SA_Classifier _classifier;

}