001/* Train an SVM model in Spark MLlib.
002 * 
003 * Copyright (c) 2015 The Regents of the University of California.
004 * All rights reserved.
005 *
006 * '$Author: crawl $'
007 * '$Date: 2015-09-03 18:47:06 +0000 (Thu, 03 Sep 2015) $' 
008 * '$Revision: 33860 $'
009 * 
010 * Permission is hereby granted, without written agreement and without
011 * license or royalty fees, to use, copy, modify, and distribute this
012 * software and its documentation for any purpose, provided that the above
013 * copyright notice and the following two paragraphs appear in all copies
014 * of this software.
015 *
016 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
017 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
018 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
019 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
020 * SUCH DAMAGE.
021 *
022 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
023 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
024 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
025 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
026 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
027 * ENHANCEMENTS, OR MODIFICATIONS.
028 *
029 */
030package org.kepler.spark.mllib;
031
032import org.apache.spark.api.java.JavaRDD;
033import org.apache.spark.api.java.function.Function;
034import org.apache.spark.mllib.classification.SVMModel;
035import org.apache.spark.mllib.classification.SVMWithSGD;
036import org.apache.spark.mllib.evaluation.MulticlassMetrics;
037import org.apache.spark.mllib.regression.LabeledPoint;
038import org.kepler.spark.actor.SaveableModelActor;
039
040import ptolemy.actor.TypedIOPort;
041import ptolemy.actor.parameters.PortParameter;
042import ptolemy.data.DoubleToken;
043import ptolemy.data.IntToken;
044import ptolemy.data.ObjectToken;
045import ptolemy.data.StringToken;
046import ptolemy.data.type.BaseType;
047import ptolemy.data.type.ObjectType;
048import ptolemy.kernel.CompositeEntity;
049import ptolemy.kernel.util.IllegalActionException;
050import ptolemy.kernel.util.NameDuplicationException;
051import ptolemy.kernel.util.SingletonAttribute;
052import scala.Tuple2;
053
054/** Train an SVM model in Spark MLlib.
055 * 
056 *  @author Sherman Cheung and Mai H. Nguyen
057 *  @version $Id: SVMKModel.java 33860 2015-09-03 18:47:06Z crawl $
058 * 
059 */
060public class SVMKModel extends SaveableModelActor {
061
062        public SVMKModel(CompositeEntity container, String name) 
063                        throws IllegalActionException, NameDuplicationException {
064                super(container, name);
065                
066                inData = new TypedIOPort(this, "inData", true, false);
067                inData.setTypeEquals(new ObjectType(JavaRDD.class));
068                new SingletonAttribute(inData, "_showName");
069                
070        modelPath.setToken(new StringToken("SVMmodel"));
071
072                numIterations = new PortParameter(this, "numIterations");
073                numIterations.setTypeEquals(BaseType.INT);
074                numIterations.getPort().setTypeEquals(BaseType.INT);
075                new SingletonAttribute(numIterations.getPort(), "_showName");
076                numIterations.setToken(new IntToken(100));
077                                
078                outSVM = new TypedIOPort(this, "outSVM", false, true);
079                outSVM.setTypeEquals(new ObjectType(SVMModel.class));
080                new SingletonAttribute(outSVM, "_showName");
081                
082                error = new TypedIOPort(this, "error", false, true);
083                error.setTypeEquals(BaseType.DOUBLE);
084                new SingletonAttribute(error, "_showname");
085                
086        }
087        
088        @Override
089        public void fire() throws IllegalActionException {
090                super.fire();
091                
092                final JavaRDD<LabeledPoint> dataset = (JavaRDD<LabeledPoint>) ((ObjectToken)inData.get(0)).getValue();
093                
094                numIterations.update();
095                final int iterations = ((IntToken)numIterations.getToken()).intValue();
096                                
097                // Train model and send to port.
098                _model = SVMWithSGD.train(dataset.rdd().cache(), iterations);
099                outSVM.broadcast(new ObjectToken(_model, SVMModel.class));
100
101       // Evaluate trained model on data set 
102        JavaRDD<Tuple2<Object,Object>> predictionAndLabel = dataset.map(new PredictionAndLabel((SVMModel)_model));             
103
104            // Compute classification error and send to port
105            MulticlassMetrics mcMetrics = 
106                        new MulticlassMetrics(JavaRDD.toRDD(predictionAndLabel));
107            double errVal = 1.0 - mcMetrics.precision();
108            error.broadcast(new DoubleToken(errVal));                   
109        }
110        
111        /** Input data: RDD of LabeledPoint. */
112        public TypedIOPort inData;
113        
114        /** The number of iterations to run the training algorithm to build the model */
115        public PortParameter numIterations;
116
117        /** SVMModel output. */
118        public TypedIOPort outSVM;
119        
120        /** The classification error */
121        public TypedIOPort error;
122        
123    /** Class to create RDD with predictions and labels **/
124    private static class PredictionAndLabel implements Function<LabeledPoint,Tuple2<Object,Object>> {
125                /**
126                 * 
127                 */
128                private static final long serialVersionUID = -2627242327996821646L;
129                public PredictionAndLabel (SVMModel mdl) {
130                _model = mdl;                   
131                }
132
133                @Override
134        public Tuple2<Object,Object> call(LabeledPoint p) {
135                return new Tuple2<Object,Object>(_model.predict(p.features()),p.label());
136        }
137        private final SVMModel _model;
138    }    
139}