001/* Train an SVM model in Spark MLlib. 002 * 003 * Copyright (c) 2015 The Regents of the University of California. 004 * All rights reserved. 005 * 006 * '$Author: crawl $' 007 * '$Date: 2015-09-03 18:47:06 +0000 (Thu, 03 Sep 2015) $' 008 * '$Revision: 33860 $' 009 * 010 * Permission is hereby granted, without written agreement and without 011 * license or royalty fees, to use, copy, modify, and distribute this 012 * software and its documentation for any purpose, provided that the above 013 * copyright notice and the following two paragraphs appear in all copies 014 * of this software. 015 * 016 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY 017 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES 018 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 019 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF 020 * SUCH DAMAGE. 021 * 022 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, 023 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 024 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE 025 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF 026 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, 027 * ENHANCEMENTS, OR MODIFICATIONS. 028 * 029 */ 030package org.kepler.spark.mllib; 031 032import org.apache.spark.api.java.JavaRDD; 033import org.apache.spark.api.java.function.Function; 034import org.apache.spark.mllib.classification.SVMModel; 035import org.apache.spark.mllib.classification.SVMWithSGD; 036import org.apache.spark.mllib.evaluation.MulticlassMetrics; 037import org.apache.spark.mllib.regression.LabeledPoint; 038import org.kepler.spark.actor.SaveableModelActor; 039 040import ptolemy.actor.TypedIOPort; 041import ptolemy.actor.parameters.PortParameter; 042import ptolemy.data.DoubleToken; 043import ptolemy.data.IntToken; 044import ptolemy.data.ObjectToken; 045import ptolemy.data.StringToken; 046import ptolemy.data.type.BaseType; 047import ptolemy.data.type.ObjectType; 048import ptolemy.kernel.CompositeEntity; 049import ptolemy.kernel.util.IllegalActionException; 050import ptolemy.kernel.util.NameDuplicationException; 051import ptolemy.kernel.util.SingletonAttribute; 052import scala.Tuple2; 053 054/** Train an SVM model in Spark MLlib. 055 * 056 * @author Sherman Cheung and Mai H. Nguyen 057 * @version $Id: SVMKModel.java 33860 2015-09-03 18:47:06Z crawl $ 058 * 059 */ 060public class SVMKModel extends SaveableModelActor { 061 062 public SVMKModel(CompositeEntity container, String name) 063 throws IllegalActionException, NameDuplicationException { 064 super(container, name); 065 066 inData = new TypedIOPort(this, "inData", true, false); 067 inData.setTypeEquals(new ObjectType(JavaRDD.class)); 068 new SingletonAttribute(inData, "_showName"); 069 070 modelPath.setToken(new StringToken("SVMmodel")); 071 072 numIterations = new PortParameter(this, "numIterations"); 073 numIterations.setTypeEquals(BaseType.INT); 074 numIterations.getPort().setTypeEquals(BaseType.INT); 075 new SingletonAttribute(numIterations.getPort(), "_showName"); 076 numIterations.setToken(new IntToken(100)); 077 078 outSVM = new TypedIOPort(this, "outSVM", false, true); 079 outSVM.setTypeEquals(new ObjectType(SVMModel.class)); 080 new SingletonAttribute(outSVM, "_showName"); 081 082 error = new TypedIOPort(this, "error", false, true); 083 error.setTypeEquals(BaseType.DOUBLE); 084 new SingletonAttribute(error, "_showname"); 085 086 } 087 088 @Override 089 public void fire() throws IllegalActionException { 090 super.fire(); 091 092 final JavaRDD<LabeledPoint> dataset = (JavaRDD<LabeledPoint>) ((ObjectToken)inData.get(0)).getValue(); 093 094 numIterations.update(); 095 final int iterations = ((IntToken)numIterations.getToken()).intValue(); 096 097 // Train model and send to port. 098 _model = SVMWithSGD.train(dataset.rdd().cache(), iterations); 099 outSVM.broadcast(new ObjectToken(_model, SVMModel.class)); 100 101 // Evaluate trained model on data set 102 JavaRDD<Tuple2<Object,Object>> predictionAndLabel = dataset.map(new PredictionAndLabel((SVMModel)_model)); 103 104 // Compute classification error and send to port 105 MulticlassMetrics mcMetrics = 106 new MulticlassMetrics(JavaRDD.toRDD(predictionAndLabel)); 107 double errVal = 1.0 - mcMetrics.precision(); 108 error.broadcast(new DoubleToken(errVal)); 109 } 110 111 /** Input data: RDD of LabeledPoint. */ 112 public TypedIOPort inData; 113 114 /** The number of iterations to run the training algorithm to build the model */ 115 public PortParameter numIterations; 116 117 /** SVMModel output. */ 118 public TypedIOPort outSVM; 119 120 /** The classification error */ 121 public TypedIOPort error; 122 123 /** Class to create RDD with predictions and labels **/ 124 private static class PredictionAndLabel implements Function<LabeledPoint,Tuple2<Object,Object>> { 125 /** 126 * 127 */ 128 private static final long serialVersionUID = -2627242327996821646L; 129 public PredictionAndLabel (SVMModel mdl) { 130 _model = mdl; 131 } 132 133 @Override 134 public Tuple2<Object,Object> call(LabeledPoint p) { 135 return new Tuple2<Object,Object>(_model.predict(p.features()),p.label()); 136 } 137 private final SVMModel _model; 138 } 139}