001/* Apply a trained support vector machine model to a data set.
002 * 
003 * Copyright (c) 2015 The Regents of the University of California.
004 * All rights reserved.
005 *
006 * '$Author: crawl $'
007 * '$Date: 2015-08-31 18:04:42 +0000 (Mon, 31 Aug 2015) $' 
008 * '$Revision: 33837 $'
009 * 
010 * Permission is hereby granted, without written agreement and without
011 * license or royalty fees, to use, copy, modify, and distribute this
012 * software and its documentation for any purpose, provided that the above
013 * copyright notice and the following two paragraphs appear in all copies
014 * of this software.
015 *
016 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
017 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
018 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
019 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
020 * SUCH DAMAGE.
021 *
022 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
023 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
024 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
025 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
026 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
027 * ENHANCEMENTS, OR MODIFICATIONS.
028 *
029 */
030
031package org.kepler.spark.mllib;
032
033import java.util.ArrayList;
034import java.util.List;
035
036import org.apache.spark.api.java.JavaRDD;
037import org.apache.spark.mllib.classification.SVMModel;
038import org.apache.spark.mllib.evaluation.MulticlassMetrics;
039import org.apache.spark.mllib.regression.LabeledPoint;
040import org.kepler.spark.actor.SparkBaseActor;
041
042import ptolemy.actor.TypedIOPort;
043import ptolemy.actor.parameters.PortParameter;
044import ptolemy.data.DoubleToken;
045import ptolemy.data.ObjectToken;
046import ptolemy.data.StringToken;
047import ptolemy.data.type.BaseType;
048import ptolemy.data.type.ObjectType;
049import ptolemy.kernel.CompositeEntity;
050import ptolemy.kernel.util.IllegalActionException;
051import ptolemy.kernel.util.NameDuplicationException;
052import ptolemy.kernel.util.SingletonAttribute;
053import scala.Tuple2;
054
055/** Apply a trained support vector machine model to a data set.
056 * 
057 *  @author Sherman Cheung and Mai H. Nguyen
058 *  @version $Id: SVMApply.java 33837 2015-08-31 18:04:42Z crawl $
059 * 
060 */
061public class SVMApply extends SparkBaseActor {
062
063        public SVMApply(CompositeEntity container, String name)
064                        throws IllegalActionException, NameDuplicationException {
065                super(container, name);
066                
067        modelPath = new PortParameter(this, "modelPath");        
068        modelPath.setTypeEquals(BaseType.STRING);
069        modelPath.getPort().setTypeEquals(BaseType.STRING);
070        new SingletonAttribute(modelPath.getPort(), "_showName");
071        modelPath.setStringMode(true);
072                
073                dataIn = new TypedIOPort(this, "dataIn", true, false);
074                dataIn.setTypeEquals(new ObjectType(JavaRDD.class));
075                new SingletonAttribute(dataIn, "_showName");
076                                
077                error = new TypedIOPort(this, "error", false, true);
078                error.setTypeEquals(BaseType.DOUBLE);
079                new SingletonAttribute(error, "_showname");
080
081        }
082        
083        @Override
084        public void fire() throws IllegalActionException {
085                super.fire();
086        
087        modelPath.update();
088        final String modelPathVal = ((StringToken)modelPath.getToken()).stringValue();
089        final SVMModel model = SVMModel.load(_context.sc(), modelPathVal);
090
091                // final SVMModel model = (SVMModel) ((ObjectToken)modelPath.get(0)).getValue();
092                final JavaRDD<LabeledPoint> dataset = (JavaRDD<LabeledPoint>) ((ObjectToken)dataIn.get(0)).getValue();
093                
094                List<LabeledPoint> points = dataset.collect();
095                List<Tuple2<Object, Object>> tuples = new ArrayList<Tuple2<Object, Object>>();
096                
097                for (LabeledPoint p: points) {
098                        Double score = model.predict(p.features());
099                        tuples.add(new Tuple2<Object, Object>(score, p.label()));
100                }
101                
102                JavaRDD<Tuple2<Object, Object>> predictionAndLabel = _context.parallelize(tuples);
103
104            // Compute classification error and send to port
105            MulticlassMetrics mcMetrics = 
106                        new MulticlassMetrics(JavaRDD.toRDD(predictionAndLabel));
107            double errVal = 1.0 - mcMetrics.precision();
108            error.broadcast(new DoubleToken(errVal));
109            
110        }
111
112        /** The SVM Model taken as input */
113        public PortParameter modelPath;
114        
115        /** The JavaRDD of test data taken as input */
116        public TypedIOPort dataIn;
117                
118        /** The area under the receiver operating characteristic(ROC) */
119        public TypedIOPort error;
120}