/*
 * Copyright (c) 2016-2017 The Regents of the University of California.
 * All rights reserved.
 *
 * '$Author: crawl $'
 * '$Date: 2018-02-07 01:00:17 +0000 (Wed, 07 Feb 2018) $'
 * '$Revision: 34658 $'
 *
 * Permission is hereby granted, without written agreement and without
 * license or royalty fees, to use, copy, modify, and distribute this
 * software and its documentation for any purpose, provided that the above
 * copyright notice and the following two paragraphs appear in all copies
 * of this software.
 *
 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
 * ENHANCEMENTS, OR MODIFICATIONS.
 *
 */

package org.kepler.spark.mllib;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Set;

import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.kepler.spark.actor.SparkSQLActor;

import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import ptolemy.actor.TypedIOPort;
import ptolemy.data.ObjectToken;
import ptolemy.data.StringToken;
import ptolemy.data.expr.StringParameter;
import ptolemy.data.type.BaseType;
import ptolemy.data.type.ObjectType;
import ptolemy.kernel.CompositeEntity;
import ptolemy.kernel.util.IllegalActionException;
import ptolemy.kernel.util.NameDuplicationException;
import ptolemy.kernel.util.SingletonAttribute;

/**
 * @author Dylan Uys, Jiaxin Li
 *
 * Standardizes the values in the incoming DataFrame by subtracting the mean
 * and dividing by the standard deviation. StandardScaler operates on a
 * vector column, i.e. a column whose Vector entries hold the feature values
 * of their respective Rows. This input column (named by the inColName
 * parameter) is expected to already exist in the incoming DataFrame,
 * typically assembled upstream with VectorAssembler. Once StandardScaler
 * transforms the DataFrame, a new vector column (named by the stdColName
 * parameter) is appended, containing the standardized values of the input
 * vectors; the input column itself is then dropped.
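 *
 * <p>A minimal sketch of the assumed upstream step (the column names are
 * this actor's parameter defaults; the assembled input columns are purely
 * illustrative):</p>
 *
 * <pre>{@code
 * VectorAssembler assembler = new VectorAssembler()
 *     .setInputCols(new String[] {"temperature", "humidity"})
 *     .setOutputCol("in_vector");            // inColName's default
 * Dataset<Row> withVectors = assembler.transform(rawDf);
 * // After this actor fires, the output DataFrame contains a
 * // "std_vector" column (stdColName's default) and "in_vector" is gone.
 * }</pre>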
 */
public class Standardize extends SparkSQLActor {

    public Standardize(CompositeEntity container, String name)
        throws IllegalActionException, NameDuplicationException {
        super(container, name);

        inData = new TypedIOPort(this, "inData", true, false);
        inData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(inData, "_showName");

        outData = new TypedIOPort(this, "outData", false, true);
        outData.setTypeEquals(new ObjectType(Dataset.class));
        new SingletonAttribute(outData, "_showName");

        firstPort = new TypedIOPort(this, "firstPort", false, true);
        firstPort.setTypeEquals(new ObjectType(Row.class));
        new SingletonAttribute(firstPort, "_showName");

        inColName = new StringParameter(this, "inColName");
        inColName.setToken("in_vector");

        stdColName = new StringParameter(this, "stdColName");
        stdColName.setToken("std_vector");

        entryName = new TypedIOPort(this, "entryName", true, false);
        entryName.setTypeEquals(BaseType.STRING);
        new SingletonAttribute(entryName, "_showName");

        entryNameCol = new StringParameter(this, "entryNameCol");
        entryNameCol.setToken("name");

        outFilepath = new StringParameter(this, "outFilePath");
        outFilepath.setToken("meanstd.json");
    }


    @Override
    public void initialize() throws IllegalActionException {

        super.initialize();

        // initialize the JsonObject that persists throughout the actor's lifetime
        _fileObj = new JsonObject();
    }


    /*
     * Standardizes the values in the input Dataset and outputs a Dataset
     * with an appended column (named by the stdColName parameter) that
     * contains vectors of the standardized feature values.
     */
    @Override
    public void fire() throws IllegalActionException {

        super.fire();

        // Read the input data frame
        Dataset<Row> df =
            (Dataset<Row>)((ObjectToken)inData.get(0)).getValue();


        // Get the entryId for saving the current mean/std stats to file.
        // NOTE: pure-whitespace names are disallowed.
        String entryId = "";
        if (entryName.getWidth() > 0) {
            // the entryName port overrides entryNameCol
            entryId = ((StringToken)entryName.get(0)).stringValue().trim();
        } else {
            String col = entryNameCol.stringValue().trim();
            if (col.isEmpty()) {
                System.out.println("No entryName/Col, not saving to file!");
            } else {
                try {
                    Row fr = df.first();
                    entryId = fr.getString(fr.fieldIndex(col));
                } catch (IllegalArgumentException e) {
                    throw new IllegalActionException(
                        "Can't find column '" + col + "' in the input DataFrame!");
                } catch (UnsupportedOperationException e) {
                    throw new IllegalActionException(
                        "The input DataFrame has no schema!");
                } catch (ClassCastException e) {
                    throw new IllegalActionException(
                        "Column '" + col + "' does not contain strings!");
                }
            }
        }
        entryId = entryId.replaceAll("\\s+", "");  // remove all whitespace


        // create the StandardScaler
        StandardScaler scaler = new StandardScaler()
            .setWithStd(true)
            .setWithMean(true)
            .setInputCol(inColName.stringValue().trim())
            .setOutputCol(stdColName.stringValue().trim());
        // Perform standardization by subtracting the mean and dividing by
        // the standard deviation. The transform appends the output vector
        // column (stdColName) to the DataFrame; each of its Vectors holds
        // the standardized values of the corresponding input Vector from
        // the inColName column.
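        // Concretely, with withMean and withStd both enabled, every feature
        // value is mapped as x' = (x - mean) / std, where mean and std are
        // computed per vector component across all rows.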
        StandardScalerModel scalerModel = scaler.fit(df.cache());
        df = scalerModel.transform(df);
        df = df.drop(inColName.stringValue().trim()); // drop the input feature column


        // Output the standardized dataframe
        outData.broadcast(new ObjectToken(df, Dataset.class));


        /* Save (Vector) scalerModel.mean() and scalerModel.std() as JSON */
        if (!entryId.isEmpty()) {
            JsonObject entryObj = new JsonObject();
            JsonArray meanArr =
                new JsonArray(Arrays.toString(scalerModel.mean().toArray()));
            JsonArray stdArr =
                new JsonArray(Arrays.toString(scalerModel.std().toArray()));
            entryObj.put("mean", meanArr);
            entryObj.put("std", stdArr);
            _fileObj.put(entryId, entryObj);
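            // The entry just added looks roughly like this (the entry name
            // and values are purely illustrative):
            //   "station1": { "mean": [12.3, 4.5], "std": [1.2, 0.8] }
            // wrapup() merges all such entries into one top-level object
            // and writes it to outFilePath.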

            //TODO:
            //scalerModel.save("TODO");
        }
    }


    /*
     * Wrap-up phase:
     * writes _fileObj to the buffer, then flushes and closes the file.
     * NOTE: the output file is opened only after the workflow completes.
     */
    @Override
    public void wrapup() throws IllegalActionException {

        super.wrapup();

        PrintWriter writer;

        // Try to read in the existing file and overwrite it with new data.
        // If no original file exists, create a new file.
        try {

            // read in the existing file, if possible
            FileInputStream is =
                new FileInputStream(outFilepath.getValueAsString());
            BufferedReader reader =
                new BufferedReader(new InputStreamReader(is));

            // read in the entire file as a single string, for JsonObject
            StringBuilder sb = new StringBuilder();
            String line = reader.readLine();
            while (line != null) {
                sb.append(line);
                line = reader.readLine();
            }
            reader.close(); // close the file so it can be rewritten

            // file read in; construct a JSON object from the original contents
            JsonObject origFileObj = new JsonObject(sb.toString());
            // get all station names from the new JSON object holding the new data
            Set<String> stationList = _fileObj.fieldNames();

            // replace old station data with new data
            for (String station : stationList) {
                origFileObj.remove(station);  // no-op if station is absent
                origFileObj.put(station, _fileObj.getJsonObject(station));
            }

            // write the updated origFileObj to the file
            try {
                writer = new PrintWriter(outFilepath.getValueAsString());
                writer.println(origFileObj.encodePrettily()); // updated data
                writer.flush();
                writer.close();
            } catch (Exception ei1) {
                System.err.println("Failed to write output file!");
            }

        } catch (Exception fe) {  // on error, create and write a new file

            System.out.println("No original data file found.");

            try {
                writer = new PrintWriter(outFilepath.getValueAsString());
                writer.println(_fileObj.encodePrettily());
                writer.flush();
                writer.close();
            } catch (Exception ei2) {
                System.err.println("Failed to write output file!");
            }
        }

    }


    /** Input DataFrame. */
    public TypedIOPort inData;

    /** Output DataFrame with the standardized vector column appended. */
    public TypedIOPort outData;

    /** Outputs the first row of the new dataset, for testing purposes only. */
    public TypedIOPort firstPort;

    /** Name of the input column containing the feature vectors
     *  (dropped after standardization). */
    public StringParameter inColName;

    /** Name of the output column containing the standardized vectors. */
    public StringParameter stdColName;

    /** User-specified entry name (i.e. the "ID" of the current data frame). */
    public TypedIOPort entryName;

    /** Name of the column containing the entry name. */
    public StringParameter entryNameCol;

    /** Path of the output JSON file. */
    public StringParameter outFilepath;

    /** Top-level JsonObject, accumulated across firings and written in wrapup(). */
    private JsonObject _fileObj;

}