001/* Train a RandomForest model using Spark MLlib
002 * 
003 * Copyright (c) 2015 The Regents of the University of California.
004 * All rights reserved.
005 *
006 * '$Author: crawl $'
007 * '$Date: 2015-11-06 22:03:02 +0000 (Fri, 06 Nov 2015) $' 
008 * '$Revision: 34218 $'
009 * 
010 * Permission is hereby granted, without written agreement and without
011 * license or royalty fees, to use, copy, modify, and distribute this
012 * software and its documentation for any purpose, provided that the above
013 * copyright notice and the following two paragraphs appear in all copies
014 * of this software.
015 *
016 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
017 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
018 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
019 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
020 * SUCH DAMAGE.
021 *
022 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
023 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
024 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
025 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
026 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
027 * ENHANCEMENTS, OR MODIFICATIONS.
028 *
029 */
030package org.kepler.spark.mllib;
031
032import java.util.HashMap;
033
034import org.apache.spark.api.java.JavaPairRDD;
035import org.apache.spark.api.java.JavaRDD;
036import org.apache.spark.api.java.function.Function;
037import org.apache.spark.api.java.function.PairFunction;
038import org.apache.spark.mllib.regression.LabeledPoint;
039import org.apache.spark.mllib.tree.RandomForest;
040import org.kepler.spark.actor.SaveableModelActor;
041
042import ptolemy.actor.TypedIOPort;
043import ptolemy.actor.parameters.PortParameter;
044import ptolemy.data.DoubleToken;
045import ptolemy.data.IntToken;
046import ptolemy.data.ObjectToken;
047import ptolemy.data.StringToken;
048import ptolemy.data.type.BaseType;
049import ptolemy.data.type.ObjectType;
050import ptolemy.kernel.CompositeEntity;
051import ptolemy.kernel.util.IllegalActionException;
052import ptolemy.kernel.util.NameDuplicationException;
053import ptolemy.kernel.util.SingletonAttribute;
054import scala.Tuple2;
055
056/** Train a RandomForest model using Spark MLlib.
057 * 
058 *  @author Mai H. Nguyen, Ankush Agrawal and Elle Nguyen-Khoa
059 *  @version $Id: RandomForestModel.java 34218 2015-11-06 22:03:02Z crawl $
060 * 
061 */
062public class RandomForestModel extends SaveableModelActor {
063    
064    public RandomForestModel(CompositeEntity container, String name)
065            throws IllegalActionException, NameDuplicationException {
066        super(container, name);
067        
068        data = new TypedIOPort(this, "data", true, false);
069        data.setTypeEquals(new ObjectType(JavaRDD.class));
070        new SingletonAttribute(data, "_showName");
071        
072        numClasses = new PortParameter(this, "numClasses");
073        numClasses.setTypeEquals(BaseType.INT);
074        numClasses.getPort().setTypeEquals(BaseType.INT);
075        new SingletonAttribute(numClasses.getPort(), "_showName");
076        numClasses.setToken(new IntToken(15));
077        
078        numTrees = new PortParameter(this, "numTrees");
079        numTrees.setTypeEquals(BaseType.INT);
080        numTrees.getPort().setTypeEquals(BaseType.INT);
081        new SingletonAttribute(numTrees.getPort(), "_showName");
082        numTrees.setToken(new IntToken(100));
083
084        maxDepth = new PortParameter(this, "maxDepth");
085        maxDepth.setTypeEquals(BaseType.INT);
086        maxDepth.getPort().setTypeEquals(BaseType.INT);
087        new SingletonAttribute(maxDepth.getPort(), "_showName");
088        maxDepth.setToken(new IntToken(15));
089        
090        maxBins = new PortParameter(this, "maxBins");
091        maxBins.setTypeEquals(BaseType.INT);
092        maxBins.getPort().setTypeEquals(BaseType.INT);
093        new SingletonAttribute(maxBins.getPort(), "_showName");
094        maxBins.setToken(new IntToken(15));
095        
096        modelPath.setToken(new StringToken("RFmodel"));
097        
098        error = new TypedIOPort(this, "error", false, true);
099        error.setTypeEquals(BaseType.DOUBLE);
100        new SingletonAttribute(error, "_showName");
101    }
102    
103    @Override
104    public void fire() throws IllegalActionException {
105
106        super.fire();
107        
108        numClasses.update();
109        final int numClassesVal = ((IntToken)numClasses.getToken()).intValue();
110        
111        numTrees.update();
112        final int numTreesVal = ((IntToken)numTrees.getToken()).intValue();
113
114        maxDepth.update();
115        final int maxDepthVal = ((IntToken)maxDepth.getToken()).intValue();
116
117        maxBins.update();
118        final int maxBinsVal = ((IntToken)maxBins.getToken()).intValue();
119        
120        //@SuppressWarnings("unchecked")
121        HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
122        
123        // Read data
124                final JavaRDD<LabeledPoint> dataset = (JavaRDD<LabeledPoint>) ((ObjectToken)data.get(0)).getValue();
125                
126                // Train the model.
127        _model = RandomForest.trainClassifier( dataset, numClassesVal, 
128                                               categoricalFeaturesInfo, 
129                                               numTreesVal, "auto", "gini", maxDepthVal, maxBinsVal, 14);
130
131        // Evaluate model on data set and compute error 
132        JavaPairRDD<Double,Double> predictionAndLabel =
133            dataset.mapToPair(new PredictionAndLabel((org.apache.spark.mllib.tree.model.RandomForestModel)_model));             
134        Double datasetErr = 1.0 * predictionAndLabel.filter(new ComputeError()).count() / dataset.count();
135        error.broadcast(new DoubleToken(datasetErr));        
136    }
137    
138    /** Ports **/
139    
140        /** Input data: RDD of LabeledPoint */
141    public TypedIOPort   data;                           
142    
143    /** Number of classes to categorize */
144    public PortParameter numClasses;                     
145    
146    /** Number of trees in ensemble */
147    public PortParameter numTrees;                      
148    
149    /** Max depth of each tree */
150    public PortParameter maxDepth;                              
151    
152    /** Max number of bins to use in discretizing continuous-valued features */
153    public PortParameter maxBins;                               
154        
155    /** Classification error on data set */
156    public TypedIOPort   error;                                 
157    
158    /** Class to create RDD with predictions and labels **/
159    private static class PredictionAndLabel implements PairFunction<LabeledPoint,Double,Double> {
160                private static final long serialVersionUID = -4136732024436215900L;
161                public PredictionAndLabel (org.apache.spark.mllib.tree.model.RandomForestModel model) {
162                        _model = model;                 
163                }
164
165                @Override
166        public Tuple2<Double,Double> call(LabeledPoint p) {
167                return new Tuple2<Double,Double>(_model.predict(p.features()),p.label());
168        }
169        private final org.apache.spark.mllib.tree.model.RandomForestModel _model;
170    }
171   
172        private static class ComputeError implements Function<Tuple2<Double,Double>, Boolean> {
173                private static final long serialVersionUID = 6238046495060057530L;
174
175                @Override
176                public Boolean call(Tuple2<Double, Double> pl) {
177                    return !pl._1().equals(pl._2());
178        }
179        }    
180}