/* Apply a trained support vector machine model to a data set.
 *
 * Copyright (c) 2015 The Regents of the University of California.
 * All rights reserved.
 *
 * '$Author: crawl $'
 * '$Date: 2015-08-31 18:04:42 +0000 (Mon, 31 Aug 2015) $'
 * '$Revision: 33837 $'
 *
 * Permission is hereby granted, without written agreement and without
 * license or royalty fees, to use, copy, modify, and distribute this
 * software and its documentation for any purpose, provided that the above
 * copyright notice and the following two paragraphs appear in all copies
 * of this software.
 *
 * IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY
 * FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
 * ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
 * THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE
 * PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF
 * CALIFORNIA HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES,
 * ENHANCEMENTS, OR MODIFICATIONS.
 *
 */

package org.kepler.spark.mllib;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.kepler.spark.actor.SparkBaseActor;

import ptolemy.actor.TypedIOPort;
import ptolemy.actor.parameters.PortParameter;
import ptolemy.data.DoubleToken;
import ptolemy.data.ObjectToken;
import ptolemy.data.StringToken;
import ptolemy.data.type.BaseType;
import ptolemy.data.type.ObjectType;
import ptolemy.kernel.CompositeEntity;
import ptolemy.kernel.util.IllegalActionException;
import ptolemy.kernel.util.NameDuplicationException;
import ptolemy.kernel.util.SingletonAttribute;
import scala.Tuple2;

/** Apply a trained support vector machine model to a data set.
 *
 *  <p>On each firing this actor loads a saved {@link SVMModel} from the
 *  location given by <i>modelPath</i>, reads a {@code JavaRDD} of
 *  {@link LabeledPoint} from <i>dataIn</i>, predicts a label for every
 *  point, and writes the resulting classification error (1.0 minus the
 *  precision reported by {@link MulticlassMetrics}) to <i>error</i>.</p>
 *
 *  @author Sherman Cheung and Mai H. Nguyen
 *  @version $Id: SVMApply.java 33837 2015-08-31 18:04:42Z crawl $
 *
 */
public class SVMApply extends SparkBaseActor {

    /** Construct the actor and create its ports and parameters.
     *
     *  @param container The composite entity that will contain this actor.
     *  @param name The name of this actor within the container.
     *  @exception IllegalActionException If the superclass constructor or
     *   the creation of a port, parameter or attribute throws it.
     *  @exception NameDuplicationException If the superclass constructor or
     *   the creation of a port, parameter or attribute throws it.
     */
    public SVMApply(CompositeEntity container, String name)
            throws IllegalActionException, NameDuplicationException {
        super(container, name);

        modelPath = new PortParameter(this, "modelPath");
        modelPath.setTypeEquals(BaseType.STRING);
        modelPath.getPort().setTypeEquals(BaseType.STRING);
        new SingletonAttribute(modelPath.getPort(), "_showName");
        modelPath.setStringMode(true);

        dataIn = new TypedIOPort(this, "dataIn", true, false);
        dataIn.setTypeEquals(new ObjectType(JavaRDD.class));
        new SingletonAttribute(dataIn, "_showName");

        error = new TypedIOPort(this, "error", false, true);
        error.setTypeEquals(BaseType.DOUBLE);
        // Attribute names are case-sensitive; this must be "_showName"
        // (as on the other ports), not "_showname".
        new SingletonAttribute(error, "_showName");
    }

    /** Load the model, score every input point, and output the
     *  classification error.
     *
     *  @exception IllegalActionException If the superclass throws it, if
     *   reading a port or parameter fails, or if the object received on
     *   <i>dataIn</i> is not a {@code JavaRDD}.
     */
    @Override
    public void fire() throws IllegalActionException {
        super.fire();

        // Refresh the parameter before reading the model location.
        modelPath.update();
        final String modelPathVal =
                ((StringToken) modelPath.getToken()).stringValue();
        final SVMModel model = SVMModel.load(_context.sc(), modelPathVal);

        // Validate the payload before the unchecked cast so that a wrong
        // upstream connection is reported as a workflow error.
        final Object value = ((ObjectToken) dataIn.get(0)).getValue();
        if (!(value instanceof JavaRDD)) {
            throw new IllegalActionException(this,
                    "dataIn must carry a JavaRDD of LabeledPoint but received: "
                    + (value == null ? "null" : value.getClass().getName()));
        }
        @SuppressWarnings("unchecked")
        final JavaRDD<LabeledPoint> dataset = (JavaRDD<LabeledPoint>) value;

        // NOTE: the whole data set is collected and scored locally, then
        // turned back into an RDD for MulticlassMetrics.
        final List<LabeledPoint> points = dataset.collect();
        final List<Tuple2<Object, Object>> tuples =
                new ArrayList<Tuple2<Object, Object>>(points.size());

        for (LabeledPoint p : points) {
            final double score = model.predict(p.features());
            // Pair order is (prediction, label).
            tuples.add(new Tuple2<Object, Object>(score, p.label()));
        }

        final JavaRDD<Tuple2<Object, Object>> predictionAndLabel =
                _context.parallelize(tuples);

        // Compute classification error and send to port
        final MulticlassMetrics mcMetrics =
                new MulticlassMetrics(JavaRDD.toRDD(predictionAndLabel));
        final double errVal = 1.0 - mcMetrics.precision();
        error.broadcast(new DoubleToken(errVal));
    }

    /** The location from which the trained SVM model is loaded.
     *  This is a string, settable either as a parameter or via the port.
     */
    public PortParameter modelPath;

    /** The JavaRDD of test data taken as input */
    public TypedIOPort dataIn;

    /** The classification error of the model on the input data, computed
     *  as 1.0 minus the precision of its predictions. The type is double.
     */
    public TypedIOPort error;
}