001/* 
002 * Copyright (c) 2016-2017 The Regents of the University of California.
003 * All rights reserved.
004 *
005 * '$Author: crawl $'
006 * '$Date: 2018-02-07 18:53:35 +0000 (Wed, 07 Feb 2018) $' 
007 * '$Revision: 34661 $'
008 * 
009 * Permission is hereby granted, without written agreement and without
010 * license or royalty fees, to use, copy, modify, and distribute this
011 * software and its documentation for any purpose, provided that the above
012 * copyright notice and the following two paragraphs appear in all copies
013 * of this software.
014 *
015 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
016 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
017 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
018 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
019 * SUCH DAMAGE.
020 *
021 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
022 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
023 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
024 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
025 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
026 * ENHANCEMENTS, OR MODIFICATIONS.
027 *
028 */
029
030package org.kepler.spark.mllib;
031
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
042
043import org.apache.commons.math3.ml.distance.EuclideanDistance;
044import org.apache.spark.api.java.function.MapFunction;
045import org.apache.spark.broadcast.Broadcast;
046import org.apache.spark.ml.linalg.DenseVector;
047import org.apache.spark.ml.linalg.SQLDataTypes;
048import org.apache.spark.ml.linalg.Vector;
049import org.apache.spark.ml.linalg.Vectors;
050import org.apache.spark.sql.Dataset;
051import org.apache.spark.sql.Encoder;
052import org.apache.spark.sql.Encoders;
053import org.apache.spark.sql.Row;
054import org.apache.spark.sql.RowFactory;
055import org.apache.spark.sql.types.StructType;
056import org.kepler.spark.actor.SparkSQLActor;
057
058import io.vertx.core.json.JsonArray;
059import io.vertx.core.json.JsonObject;
060import ptolemy.actor.TypedIOPort;
061import ptolemy.data.ObjectToken;
062import ptolemy.data.StringToken;
063import ptolemy.data.expr.StringParameter;
064import ptolemy.data.type.BaseType;
065import ptolemy.data.type.ObjectType;
066import ptolemy.kernel.CompositeEntity;
067import ptolemy.kernel.util.IllegalActionException;
068import ptolemy.kernel.util.NameDuplicationException;
069import ptolemy.kernel.util.SingletonAttribute;
070
071/**
072 * @author Dylan Uys, Jiaxin Li
073 *
074 * Takes a column of standardized vectors from the input dataframe, and 
075 * calculates the kmeans distance space vector (i.e. the distance from the 
076 * current row's measurement to all cluster centers in the model file). 
077 * Outputs results by appending a vector to the dataframe.  
078 */
079public class KMeansApply extends SparkSQLActor {
080
081    public KMeansApply(CompositeEntity container, String name)
082        throws IllegalActionException, NameDuplicationException {
083        
084        super(container, name);
085        
086        inData = new TypedIOPort(this, "data", true, false);
087        inData.setTypeEquals(new ObjectType(Dataset.class));
088        new SingletonAttribute(inData, "_showName");
089
090        outData = new TypedIOPort(this, "outData", false, true);  
091        outData.setTypeEquals(new ObjectType(Dataset.class));
092        new SingletonAttribute(outData, "_showName");  
093        
094        inFilepath = new StringParameter(this, "inFilepath");
095        inFilepath.setToken("clusterCenters.json");
096
097        entryNameCol = new StringParameter(this, "entryNameCol");
098        entryNameCol.setToken("name");
099
100        stdCol = new StringParameter(this, "stdCol");
101        stdCol.setToken("std_vector");
102
103        kmeansCol = new StringParameter(this, "kmeansCol");
104        kmeansCol.setToken("kmeans_dist_vec");
105
106        debugOutput = new TypedIOPort(this, "debugOutput", false, true);
107        debugOutput.setTypeEquals(BaseType.STRING);
108        new SingletonAttribute(debugOutput, "_showName");
109        
110    }
111
112    @Override
113    public void initialize() throws IllegalActionException  {
114
115        super.initialize();
116
117        InputStream is;
118        Reader isr;
119        StringBuilder sb;
120
121        // Read entire JSON file into one string, to pass to JsonObject ctor
122        try {
123            is = new FileInputStream(inFilepath.getValueAsString());
124            try {
125                isr = new InputStreamReader(is, "UTF-8");
126            } catch (UnsupportedEncodingException e) {
127                System.err.println("ERROR: Problem with file's encoding");
128                return;
129            }
130
131            // use StringBuilder to construct JSON string
132            sb = new StringBuilder(512);
133            try {
134                int c = 0;
135                while ((c = isr.read()) != -1) {
136                    sb.append((char) c);
137                }
138            } catch (IOException e) {
139                throw new RuntimeException(e);
140            }
141
142        } catch (FileNotFoundException e) {
143            throw new IllegalActionException(this, e, "ERROR: File not found:"
144                        + inFilepath.getValueAsString());  // workflow exception
145        }
146
147        // Create a JsonObject from the string, then get all the station keys
148        JsonObject jsonObj = new JsonObject(sb.toString());
149        Set<String> stations = jsonObj.fieldNames();
150
151        // This will be the broadcasted object. The keys are String station
152        // codes, the values are double[][] of cluster centers.
153        Map<String, double[][]> centersByStation =
154            new HashMap<String, double[][]>();
155
156        // Iterate through the stationCodes present in the JsonObject's keys,
157        // adding each respective array of cluster centers to the Map
158        // instantiated above
159        for (String stationCode: stations) {
160
161            // stationObj format:  "MW": [[...], [...], ...]
162            JsonObject stationObj = jsonObj.getJsonObject(stationCode);
163            JsonArray centers = stationObj.getJsonArray("clusterCenters");
164
165            double[][] centersList;
166            centersList = new double[centers.size()][centers.getJsonArray(0).size()];
167            for (int i = 0; i < centers.size(); i++) {
168                JsonArray currCenter = centers.getJsonArray(i);
169
170                for (int j = 0; j < currCenter.size(); j++) {
171                    centersList[i][j] = currCenter.getDouble(j);
172                }
173            }
174            
175            centersByStation.put(stationCode, centersList);
176        }
177
178        // broadcast
179        _bcastCenters = _context.broadcast(centersByStation);
180
181    }
182
183    @Override
184    public void fire() throws IllegalActionException  {
185        
186        super.fire();
187
188        Dataset<Row>inDf=(Dataset<Row>)((ObjectToken)inData.get(0)).getValue();
189
190        //DEBUG
191        /*
192        if(_debugging)
193            inDf.printSchema();
194        */
195        //inDf.show();
196
197        // Encoder to serialize the row output of the map call
198        Encoder<Row> rowEncoder = Encoders.javaSerialization(Row.class);
199
200        // Get broadcast var to pass to ClusterClassifier
201        // TODO: do we need to have this in fire()? 
202        Map<String, double[][]> stationCenters = _bcastCenters.value();
203
204        // Classifier
205        ClusterClassifier classifier =
206            new ClusterClassifier(stationCenters,
207                                  entryNameCol.stringValue().trim(),
208                                  stdCol.stringValue().trim());
209
210        // Calculate nearest cluster center by Euclidean distance for each row
211        // in the Dataset
212        // NOTE: If exceptions are encountered, won't output any tokens
213        try{
214            // use map function to identify center ID for each row
215            Dataset<Row> outDf = inDf.map(classifier, rowEncoder);
216
217            // add schema column for kmeans distance-space vector 
218            StructType outSchema =
219                inDf.schema().add(kmeansCol.stringValue(),
220                                  SQLDataTypes.VectorType());
221            outDf = _sqlContext.createDataFrame(outDf.toJavaRDD(),
222                                                outSchema);
223
224            // output results
225            outData.broadcast(new ObjectToken(outDf, Dataset.class));
226
227            //DEBUG
228            /*
229            if(_debugging)
230                outDf.printSchema();
231                */
232            if (debugOutput.numberOfSinks() > 0)
233                debugOutput.broadcast(new StringToken(outDf.first().toString()));
234            
235        } catch (Exception e) {
236            System.err.println("ERROR: KMeansApply failed! " + e.toString());
237        }
238    }
239
240
241    /*
242     * Private helper class for computing k-means distance space vectors
243     * TODO: change class name
244     */     
245    private static class ClusterClassifier implements MapFunction<Row, Row> {
246
247        private Map<String, double[][]> clusterCenters;
248        private String idCol, stdCol;
249        
250        public ClusterClassifier(Map<String, double[][]> centers,
251                                 String id, String std) {
252            clusterCenters = centers;
253            idCol = id;
254            stdCol = std;
255        }
256        
257        @Override
258        public Row call(Row datum) throws Exception {
259
260            EuclideanDistance ed = new EuclideanDistance();
261
262            // first, get array of cluster centers from the map
263            String name;
264            try {
265                // get entry name
266                name = datum.getString(datum.fieldIndex(idCol));
267                name = name.replaceAll("\\s+", "");  // clean up name string
268            } catch(IllegalArgumentException e) {
269                System.err.println("ERROR: Can't find entry column: " + idCol);
270                throw e;
271            }
272
273            // get trained cluster data for current station
274            int numClusters, vectorDim;
275            double[][] centers;
276            try {
277                centers = clusterCenters.get(name);
278                // get number of clusters per entry, and cluster dimension 
279                numClusters = centers.length;
280                vectorDim = centers[0].length;
281            } catch(NullPointerException e) {
282                System.err.println("ERROR: No cluster center info for " +
283                                   "station: " + name);
284                throw e;
285            }
286            
287            // get standardized vector as double[] from input row
288            DenseVector stdVector;
289            try{
290                stdVector = (DenseVector)datum.get(datum.fieldIndex(stdCol));
291            } catch(IllegalArgumentException e) {
292                System.err.println("ERROR: Can't find stdCol: " + stdCol);
293                throw e;
294            }
295            double[] d = stdVector.toArray();
296
297            // calculate ED to all centers, assemble distance-space vector
298            double[] distArray = new double[numClusters];
299            for (int i = 0; i < numClusters; i++) {
300                double[] c = centers[i];
301                distArray[i] = ed.compute(c, d);
302            }
303            Vector distVector = Vectors.dense(distArray);
304
305            // Assemble the output row by appending the dist vector field
306            Object[] outFields = new Object[datum.size()+1];
307            for(int i = 0; i < datum.size(); i++)
308                outFields[i] = datum.get(i);
309            outFields[datum.size()] = distVector; // dist vector
310            
311            return RowFactory.create(outFields);
312        }
313    }
314
315    /** Dataframe input */
316    public TypedIOPort inData;
317
318    /** Dataframe output w/kmeans distance space vector */ 
319    public TypedIOPort outData;
320
321    /** Path to model file */
322    public StringParameter inFilepath;
323
324    /** Name of entry name column in input dataframe */
325    public StringParameter entryNameCol;
326
327    /** Name of standard vector column */
328    public StringParameter stdCol;
329
330    /** Name to kmeans vector column (to append to dataframe) */ 
331    public StringParameter kmeansCol;
332
333    /** Debug output port */
334    public TypedIOPort debugOutput;
335
336    
337    /* broadcasted cluster centers map */
338    private Broadcast<Map<String, double[][]>> _bcastCenters;
339    
340}
341
342