001/* Train a RandomForest model using Spark MLlib 002 * 003 * Copyright (c) 2015 The Regents of the University of California. 004 * All rights reserved. 005 * 006 * '$Author: crawl $' 007 * '$Date: 2015-11-06 22:03:02 +0000 (Fri, 06 Nov 2015) $' 008 * '$Revision: 34218 $' 009 * 010 * Permission is hereby granted, without written agreement and without 011 * license or royalty fees, to use, copy, modify, and distribute this 012 * software and its documentation for any purpose, provided that the above 013 * copyright notice and the following two paragraphs appear in all copies 014 * of this software. 015 * 016 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY 017 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES 018 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 019 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF 020 * SUCH DAMAGE. 021 * 022 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, 023 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 024 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE 025 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF 026 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, 027 * ENHANCEMENTS, OR MODIFICATIONS. 028 * 029 */ 030package org.kepler.spark.mllib; 031 032import java.util.HashMap; 033 034import org.apache.spark.api.java.JavaPairRDD; 035import org.apache.spark.api.java.JavaRDD; 036import org.apache.spark.api.java.function.Function; 037import org.apache.spark.api.java.function.PairFunction; 038import org.apache.spark.mllib.regression.LabeledPoint; 039import org.apache.spark.mllib.tree.RandomForest; 040import org.kepler.spark.actor.SaveableModelActor; 041 042import ptolemy.actor.TypedIOPort; 043import ptolemy.actor.parameters.PortParameter; 044import ptolemy.data.DoubleToken; 045import ptolemy.data.IntToken; 046import ptolemy.data.ObjectToken; 047import ptolemy.data.StringToken; 048import ptolemy.data.type.BaseType; 049import ptolemy.data.type.ObjectType; 050import ptolemy.kernel.CompositeEntity; 051import ptolemy.kernel.util.IllegalActionException; 052import ptolemy.kernel.util.NameDuplicationException; 053import ptolemy.kernel.util.SingletonAttribute; 054import scala.Tuple2; 055 056/** Train a RandomForest model using Spark MLlib. 057 * 058 * @author Mai H. Nguyen, Ankush Agrawal and Elle Nguyen-Khoa 059 * @version $Id: RandomForestModel.java 34218 2015-11-06 22:03:02Z crawl $ 060 * 061 */ 062public class RandomForestModel extends SaveableModelActor { 063 064 public RandomForestModel(CompositeEntity container, String name) 065 throws IllegalActionException, NameDuplicationException { 066 super(container, name); 067 068 data = new TypedIOPort(this, "data", true, false); 069 data.setTypeEquals(new ObjectType(JavaRDD.class)); 070 new SingletonAttribute(data, "_showName"); 071 072 numClasses = new PortParameter(this, "numClasses"); 073 numClasses.setTypeEquals(BaseType.INT); 074 numClasses.getPort().setTypeEquals(BaseType.INT); 075 new SingletonAttribute(numClasses.getPort(), "_showName"); 076 numClasses.setToken(new IntToken(15)); 077 078 numTrees = new PortParameter(this, "numTrees"); 079 numTrees.setTypeEquals(BaseType.INT); 080 numTrees.getPort().setTypeEquals(BaseType.INT); 081 new SingletonAttribute(numTrees.getPort(), "_showName"); 082 numTrees.setToken(new IntToken(100)); 083 084 maxDepth = new PortParameter(this, "maxDepth"); 085 maxDepth.setTypeEquals(BaseType.INT); 086 maxDepth.getPort().setTypeEquals(BaseType.INT); 087 new SingletonAttribute(maxDepth.getPort(), "_showName"); 088 maxDepth.setToken(new IntToken(15)); 089 090 maxBins = new PortParameter(this, "maxBins"); 091 maxBins.setTypeEquals(BaseType.INT); 092 maxBins.getPort().setTypeEquals(BaseType.INT); 093 new SingletonAttribute(maxBins.getPort(), "_showName"); 094 maxBins.setToken(new IntToken(15)); 095 096 modelPath.setToken(new StringToken("RFmodel")); 097 098 error = new TypedIOPort(this, "error", false, true); 099 error.setTypeEquals(BaseType.DOUBLE); 100 new SingletonAttribute(error, "_showName"); 101 } 102 103 @Override 104 public void fire() throws IllegalActionException { 105 106 super.fire(); 107 108 numClasses.update(); 109 final int numClassesVal = ((IntToken)numClasses.getToken()).intValue(); 110 111 numTrees.update(); 112 final int numTreesVal = ((IntToken)numTrees.getToken()).intValue(); 113 114 maxDepth.update(); 115 final int maxDepthVal = ((IntToken)maxDepth.getToken()).intValue(); 116 117 maxBins.update(); 118 final int maxBinsVal = ((IntToken)maxBins.getToken()).intValue(); 119 120 //@SuppressWarnings("unchecked") 121 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); 122 123 // Read data 124 final JavaRDD<LabeledPoint> dataset = (JavaRDD<LabeledPoint>) ((ObjectToken)data.get(0)).getValue(); 125 126 // Train the model. 127 _model = RandomForest.trainClassifier( dataset, numClassesVal, 128 categoricalFeaturesInfo, 129 numTreesVal, "auto", "gini", maxDepthVal, maxBinsVal, 14); 130 131 // Evaluate model on data set and compute error 132 JavaPairRDD<Double,Double> predictionAndLabel = 133 dataset.mapToPair(new PredictionAndLabel((org.apache.spark.mllib.tree.model.RandomForestModel)_model)); 134 Double datasetErr = 1.0 * predictionAndLabel.filter(new ComputeError()).count() / dataset.count(); 135 error.broadcast(new DoubleToken(datasetErr)); 136 } 137 138 /** Ports **/ 139 140 /** Input data: RDD of LabeledPoint */ 141 public TypedIOPort data; 142 143 /** Number of classes to categorize */ 144 public PortParameter numClasses; 145 146 /** Number of trees in ensemble */ 147 public PortParameter numTrees; 148 149 /** Max depth of each tree */ 150 public PortParameter maxDepth; 151 152 /** Max number of bins to use in discretizing continuous-valued features */ 153 public PortParameter maxBins; 154 155 /** Classification error on data set */ 156 public TypedIOPort error; 157 158 /** Class to create RDD with predictions and labels **/ 159 private static class PredictionAndLabel implements PairFunction<LabeledPoint,Double,Double> { 160 private static final long serialVersionUID = -4136732024436215900L; 161 public PredictionAndLabel (org.apache.spark.mllib.tree.model.RandomForestModel model) { 162 _model = model; 163 } 164 165 @Override 166 public Tuple2<Double,Double> call(LabeledPoint p) { 167 return new Tuple2<Double,Double>(_model.predict(p.features()),p.label()); 168 } 169 private final org.apache.spark.mllib.tree.model.RandomForestModel _model; 170 } 171 172 private static class ComputeError implements Function<Tuple2<Double,Double>, Boolean> { 173 private static final long serialVersionUID = 6238046495060057530L; 174 175 @Override 176 public Boolean call(Tuple2<Double, Double> pl) { 177 return !pl._1().equals(pl._2()); 178 } 179 } 180}