/*
 * Copyright (c) 2016-2017 The Regents of the University of California.
 * All rights reserved.
 *
 * '$Author: crawl $'
 * '$Date: 2018-02-07 18:53:35 +0000 (Wed, 07 Feb 2018) $'
 * '$Revision: 34661 $'
 *
 * Permission is hereby granted, without written agreement and without
 * license or royalty fees, to use, copy, modify, and distribute this
 * software and its documentation for any purpose, provided that the above
 * copyright notice and the following two paragraphs appear in all copies
 * of this software.
 *
 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
 * ENHANCEMENTS, OR MODIFICATIONS.
 *
 */

package org.kepler.spark.mllib;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.SQLDataTypes;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.StructType;
import org.kepler.spark.actor.SparkSQLActor;

import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import ptolemy.actor.TypedIOPort;
import ptolemy.data.ObjectToken;
import ptolemy.data.StringToken;
import ptolemy.data.expr.StringParameter;
import ptolemy.data.type.BaseType;
import ptolemy.data.type.ObjectType;
import ptolemy.kernel.CompositeEntity;
import ptolemy.kernel.util.IllegalActionException;
import ptolemy.kernel.util.NameDuplicationException;
import ptolemy.kernel.util.SingletonAttribute;

/**
 * Standardizes the values in the incoming DataFrame by subtracting the mean
 * and dividing by the standard deviation. Per-station means and standard
 * deviations are loaded from a JSON model file and broadcast to the workers.
 *
 * @author Dylan Uys, Jiaxin Li
 */
public class StandardizeApply extends SparkSQLActor {

    public StandardizeApply(CompositeEntity container, String name)
        throws IllegalActionException, NameDuplicationException {
        super(container, name);

        inData = new TypedIOPort(this, "inData", true, false);
        inData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(inData, "_showName");

        outData = new TypedIOPort(this, "outData", false, true);
        outData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(outData, "_showName");

        test = new TypedIOPort(this, "test", false, true);
        test.setTypeEquals(BaseType.STRING);
        new SingletonAttribute(test, "_showName");

        inFilepath = new StringParameter(this, "inFilepath");
        inFilepath.setToken("meanstd.json");

        entryNameCol = new StringParameter(this, "entryNameCol");
        entryNameCol.setToken("name");

        stdCol = new StringParameter(this, "stdCol");
        stdCol.setToken("std_vector");

        inCol = new StringParameter(this, "inCol");
        inCol.setToken("in_vector");
    }

    @Override
    public void initialize() throws IllegalActionException {

        super.initialize();

        // Read the entire model file into a single string for JsonObject.
        String modelPath = inFilepath.stringValue();
        StringBuilder sb = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(
                 new InputStreamReader(new FileInputStream(modelPath)))) {
            String line = reader.readLine();
            while (line != null) {
                sb.append(line);
                line = reader.readLine();
            }
        } catch (IOException e) {
            throw new IllegalActionException(this, e,
                "Cannot read model file: " + modelPath);
        }

        // Create a json object from the file string and extract the
        // station-specific data.
        JsonObject jsonObj = new JsonObject(sb.toString());
        Set<String> stations = jsonObj.fieldNames();
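
        // The model file is expected to look like this (an illustrative
        // sketch: the "mean"/"std" field names are what the loop below
        // reads; the station codes and numbers are hypothetical):
        //
        //   {
        //     "MW": { "mean": [12.3, 0.57], "std": [3.1, 0.12] },
        //     "SC": { "mean": [11.8, 0.49], "std": [2.9, 0.10] }
        //   }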

        // These will be the broadcast objects. The keys are String station
        // codes; the values are List<Double>s of means/stddevs.
        Map<String, List<Double>> meansByStation =
            new HashMap<String, List<Double>>();
        Map<String, List<Double>> stdsByStation =
            new HashMap<String, List<Double>>();

        // Iterate through the station codes present in the JsonObject's
        // keys, adding each respective array of means/stddevs to the Maps
        // instantiated above.
        for (String stationCode : stations) {

            // stationObj format:  "MW": {"mean": [...], "std": [...]}
            JsonObject stationObj = jsonObj.getJsonObject(stationCode);
            JsonArray means = stationObj.getJsonArray("mean");
            JsonArray stds = stationObj.getJsonArray("std");

            // Copy the values out through Number so that whole-number JSON
            // entries (which parse as Integer) still become Doubles.
            List<Double> meansList = new ArrayList<Double>();
            for (Object o : means.getList())
                meansList.add(((Number) o).doubleValue());
            List<Double> stdsList = new ArrayList<Double>();
            for (Object o : stds.getList())
                stdsList.add(((Number) o).doubleValue());

            stdsByStation.put(stationCode, stdsList);
            meansByStation.put(stationCode, meansList);
        }

        // Broadcast the maps so each executor receives one read-only copy
        // instead of one copy per task.
        _bcastMean = _context.broadcast(meansByStation);
        _bcastStddev = _context.broadcast(stdsByStation);
    }

    @Override
    public void fire() throws IllegalActionException {

        super.fire();

        // get the name column and the input column
        String inColName = inCol.stringValue().trim();
        String nameCol = entryNameCol.stringValue().trim();

        // Read data
        Dataset<Row> inDf =
            (Dataset<Row>)((ObjectToken)inData.get(0)).getValue();
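        // NOTE: the Java-serialization encoder stores each Row as an opaque
        // byte blob, so the DataFrame's schema does not survive the map()
        // call below; it is re-attached afterwards via createDataFrame().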
        Encoder<Row> rowEncoder = Encoders.javaSerialization(Row.class);

        // DEBUG
        /*
        if (_debugging) {
            System.out.println("Standardize Before:");
            inDf.printSchema();
        }
        */

        Map<String, List<Double>> mean = _bcastMean.value();
        Map<String, List<Double>> std = _bcastStddev.value();

        // If the Standardizer throws an exception, don't output any token.
        try {

            Standardizer stder = new Standardizer(mean, std, inColName,
                                                  nameCol);
            Dataset<Row> outDf = inDf.map(stder, rowEncoder);

            // Add the schema back to the dataframe after the map function.
            // NOTE: the new schema contains one more column (stdCol,
            // "std_vector" by default).
            StructType outSchema = inDf.schema();
            outSchema = outSchema.add(stdCol.stringValue(),
                                      SQLDataTypes.VectorType());
            outDf = _sqlContext.createDataFrame(outDf.toJavaRDD(),
                                                outSchema);

            // DEBUG
            /*
            if (_debugging) {
                System.out.println("Standardize After:");
                outDf.printSchema();
            }
            */

            // Output the standardized dataframe
            outData.broadcast(new ObjectToken(outDf, Dataset.class));
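            // outDf.first() triggers a Spark job, so only compute it when
            // something is actually connected to the test port.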
            if (test.numberOfSinks() > 0) {
                test.broadcast(new StringToken(outDf.first().toString()));
            }

        } catch (Exception e) {
            System.err.println("ERROR: StandardizeApply failed! " +
                               e.toString());
        }
    }

    /**
     * Private helper class for standardizing rows using map().
     */
    private static class Standardizer implements MapFunction<Row, Row> {

        private Map<String, List<Double>> meanList;
        private Map<String, List<Double>> stdList;
        private String inCol;
        private String idCol;

        public Standardizer(Map<String, List<Double>> mean,
                            Map<String, List<Double>> std, String inc,
                            String idc) {
            meanList = mean;
            stdList = std;
            inCol = inc;
            idCol = idc;
        }

        @Override
        public Row call(Row in) throws Exception {

            // get the entry ID (e.g. station ID) from the input row
            String name = in.getString(in.fieldIndex(idCol));
            name = name.replaceAll("\\s+", "");  // clean up name string

            // TODO: check for dimension mismatch

            // Get all data fields from the row, then standardize.
            // NOTE: requires identical vector ordering between the model
            // and the input vector.
            try {
                // get the input (raw) vector from the input row
                DenseVector inVector;
                try {
                    inVector = (DenseVector)in.get(in.fieldIndex(inCol));
                } catch (IllegalArgumentException e) {
                    System.err.println("ERROR: Can't find inCol: " + inCol);
                    throw e;
                }
                double[] inVec = inVector.toArray();

                // create the output vector from the input vector:
                // out[i] = (in[i] - mean[i]) / std[i]
                double[] out = new double[inVector.size()];
                for (int i = 0; i < inVector.size(); i++) {
                    double m = meanList.get(name).get(i);
                    double s = stdList.get(name).get(i);
                    out[i] = (inVec[i] - m) / s;
                }

                // assemble the standardized feature vector
                Vector stdVector = Vectors.dense(out);

                // assemble the output row
                Object[] outFields = new Object[in.size() + 1];
                for (int i = 0; i < in.size(); i++) // copy existing fields
                    outFields[i] = in.get(i);
                outFields[in.size()] = stdVector; // append vector field

                return RowFactory.create(outFields);

            } catch (NullPointerException e) {
                // The station name from the websocket can't be found in
                // the training data.
                System.err.println("ERROR: Cannot find mean/std info: "
                                   + name);
                throw e;

            } catch (Exception e) {
                System.err.println("ERROR: Standardizer failed! "
                                   + e.toString());
                throw e;
            }

        }
    }

    /** Dataframe input */
    public TypedIOPort inData;

    /** Dataframe output: the standardized data in a Dataset */
    public TypedIOPort outData;

    /** First row of the output dataframe, for testing */
    public TypedIOPort test;

    /** Path to the model file */
    public StringParameter inFilepath;

    /** Name of the input vector column, which contains the raw data */
    public StringParameter inCol;

    /** Name of the entry name column */
    public StringParameter entryNameCol;

    /** Name of the standardized vector column to append */
    public StringParameter stdCol;


    private Broadcast<Map<String, List<Double>>> _bcastMean;
    private Broadcast<Map<String, List<Double>>> _bcastStddev;

}