From b878944846d761cf39e1826bb11105f4ea982d94 Mon Sep 17 00:00:00 2001
From: xiaoxiao
Date: Wed, 17 Oct 2018 15:47:25 +0800
Subject: [PATCH] add decision tree classification stops

---
 .../src/main/resources/decisiontree.json      |  42 +++++++
 .../DecisionTreePrediction.scala              |  58 +++++++++
 .../DecisionTreeTraining.scala                | 119 ++++++++++++++++++
 .../scala/cn/piflow/bundle/FlowTest_XX.scala  |   6 +-
 4 files changed, 222 insertions(+), 3 deletions(-)
 create mode 100644 piflow-bundle/src/main/resources/decisiontree.json
 create mode 100644 piflow-bundle/src/main/scala/cn/piflow/bundle/ml_classification/DecisionTreePrediction.scala
 create mode 100644 piflow-bundle/src/main/scala/cn/piflow/bundle/ml_classification/DecisionTreeTraining.scala

diff --git a/piflow-bundle/src/main/resources/decisiontree.json b/piflow-bundle/src/main/resources/decisiontree.json
new file mode 100644
index 0000000..9a0ea05
--- /dev/null
+++ b/piflow-bundle/src/main/resources/decisiontree.json
@@ -0,0 +1,42 @@
+{
+  "flow":{
+    "name":"test",
+    "uuid":"1234",
+    "stops":[
+      {
+        "uuid":"0000",
+        "name":"DecisionTreeTraining",
+        "bundle":"cn.piflow.bundle.ml_classification.DecisionTreeTraining",
+        "properties":{
+          "training_data_path":"hdfs://10.0.86.89:9000/xx/watermellonDataset.txt",
+          "model_save_path":"hdfs://10.0.86.89:9000/xx/naivebayes/dt.model",
+          "maxBins":"20",
+          "maxDepth":"10",
+          "minInfoGain":"0.1",
+          "minInstancesPerNode":"2",
+          "impurity":"entropy"
+        }
+
+      },
+      {
+        "uuid":"1111",
+        "name":"DecisionTreePrediction",
+        "bundle":"cn.piflow.bundle.ml_classification.DecisionTreePrediction",
+        "properties":{
+          "test_data_path":"hdfs://10.0.86.89:9000/xx/watermellonDataset.txt",
+          "model_path":"hdfs://10.0.86.89:9000/xx/naivebayes/dt.model"
+        }
+
+      }
+
+    ],
+    "paths":[
+      {
+        "from":"DecisionTreeTraining",
+        "outport":"",
+        "inport":"",
+        "to":"DecisionTreePrediction"
+      }
+    ]
+  }
+}
\ No newline at end of file
diff --git a/piflow-bundle/src/main/scala/cn/piflow/bundle/ml_classification/DecisionTreePrediction.scala b/piflow-bundle/src/main/scala/cn/piflow/bundle/ml_classification/DecisionTreePrediction.scala
new file mode 100644
index 0000000..72d9ef1
--- /dev/null
+++ b/piflow-bundle/src/main/scala/cn/piflow/bundle/ml_classification/DecisionTreePrediction.scala
@@ -0,0 +1,58 @@
+package cn.piflow.bundle.ml_classification
+
+import cn.piflow.conf.bean.PropertyDescriptor
+import cn.piflow.conf.util.MapUtil
+import cn.piflow.conf.{ConfigurableStop, StopGroupEnum}
+import cn.piflow.{JobContext, JobInputStream, JobOutputStream, ProcessContext}
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.sql.SparkSession
+
+class DecisionTreePrediction extends ConfigurableStop {
+  val authorEmail: String = "xiaoxiao@cnic.cn"
+  val description: String = "Use an existing DecisionTreeModel to make predictions."
+  val inportCount: Int = 1
+  val outportCount: Int = 0
+  var test_data_path: String = _
+  var model_path: String = _
+
+
+  def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {
+    val spark = pec.get[SparkSession]()
+    //load data stored in libsvm format as a dataframe
+    val data = spark.read.format("libsvm").load(test_data_path)
+    //data.show()
+
+    //load the saved model
+    val model = DecisionTreeClassificationModel.load(model_path)
+
+    val predictions = model.transform(data)
+    predictions.show()
+    out.write(predictions)
+
+  }
+
+  def initialize(ctx: ProcessContext): Unit = {
+
+  }
+
+
+  def setProperties(map: Map[String, Any]): Unit = {
+    test_data_path = MapUtil.get(map, key = "test_data_path").asInstanceOf[String]
+    model_path = MapUtil.get(map, key = "model_path").asInstanceOf[String]
+  }
+
+  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
+    var descriptor: List[PropertyDescriptor] = List()
+    val test_data_path = new PropertyDescriptor().name("test_data_path").displayName("TEST_DATA_PATH").defaultValue("").required(true)
+    val model_path = new PropertyDescriptor().name("model_path").displayName("MODEL_PATH").defaultValue("").required(true)
+    descriptor = test_data_path :: descriptor
+    descriptor = model_path :: descriptor
+    descriptor
+  }
+
+  override def getIcon(): Array[Byte] = ???
+
+  override def getGroup(): List[String] = {
+    List(StopGroupEnum.MLGroup.toString)
+  }
+}
diff --git a/piflow-bundle/src/main/scala/cn/piflow/bundle/ml_classification/DecisionTreeTraining.scala b/piflow-bundle/src/main/scala/cn/piflow/bundle/ml_classification/DecisionTreeTraining.scala
new file mode 100644
index 0000000..3273f27
--- /dev/null
+++ b/piflow-bundle/src/main/scala/cn/piflow/bundle/ml_classification/DecisionTreeTraining.scala
@@ -0,0 +1,119 @@
+package cn.piflow.bundle.ml_classification
+
+import cn.piflow.conf.bean.PropertyDescriptor
+import cn.piflow.conf.util.MapUtil
+import cn.piflow.conf.{ConfigurableStop, StopGroupEnum}
+import cn.piflow.{JobContext, JobInputStream, JobOutputStream, ProcessContext}
+import org.apache.spark.ml.classification.DecisionTreeClassifier
+import org.apache.spark.sql.SparkSession
+
+class DecisionTreeTraining extends ConfigurableStop {
+  val authorEmail: String = "xiaoxiao@cnic.cn"
+  val description: String = "Train a DecisionTreeModel."
+  val inportCount: Int = 1
+  val outportCount: Int = 0
+  var training_data_path: String = _
+  var model_save_path: String = _
+  var maxBins: String = _
+  var maxDepth: String = _
+  var minInfoGain: String = _
+  var minInstancesPerNode: String = _
+  var impurity: String = _
+
+  def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {
+    val spark = pec.get[SparkSession]()
+
+    //load data stored in libsvm format as a dataframe
+    val data = spark.read.format("libsvm").load(training_data_path)
+
+    //Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity. Must be >= 2 and >= number of categories in any categorical feature.
+    var maxBinsValue: Int = 40
+    if (maxBins != "") {
+      maxBinsValue = maxBins.toInt
+    }
+
+    //Maximum depth of the tree (>= 0). The maximum is 30.
+    var maxDepthValue: Int = 30
+    if (maxDepth != "") {
+      maxDepthValue = maxDepth.toInt
+    }
+
+    //Minimum information gain for a split to be considered at a tree node.
+    var minInfoGainValue: Double = 0.2
+    if (minInfoGain != "") {
+      minInfoGainValue = minInfoGain.toDouble
+    }
+
+    //Minimum number of instances each child must have after split.
+    var minInstancesPerNodeValue: Int = 3
+    if (minInstancesPerNode != "") {
+      minInstancesPerNodeValue = minInstancesPerNode.toInt
+    }
+
+    //Criterion used for information gain calculation (case-insensitive). Supported: "entropy" and "gini".
+    var impurityValue = "gini"
+    if (impurity != "") {
+      impurityValue = impurity
+    }
+
+    //train a decision tree classification model
+    val model = new DecisionTreeClassifier()
+      .setMaxBins(maxBinsValue)
+      .setMaxDepth(maxDepthValue)
+      .setMinInfoGain(minInfoGainValue)
+      .setMinInstancesPerNode(minInstancesPerNodeValue)
+      .setImpurity(impurityValue)
+      .fit(data)
+
+    //model persistence
+    model.save(model_save_path)
+
+    import spark.implicits._
+    val dfOut = Seq(model_save_path).toDF
+    dfOut.show()
+    out.write(dfOut)
+
+  }
+
+  def initialize(ctx: ProcessContext): Unit = {
+
+  }
+
+
+  def setProperties(map: Map[String, Any]): Unit = {
+    training_data_path = MapUtil.get(map, key = "training_data_path").asInstanceOf[String]
+    model_save_path = MapUtil.get(map, key = "model_save_path").asInstanceOf[String]
+    maxBins = MapUtil.get(map, key = "maxBins").asInstanceOf[String]
+    maxDepth = MapUtil.get(map, key = "maxDepth").asInstanceOf[String]
+    minInfoGain = MapUtil.get(map, key = "minInfoGain").asInstanceOf[String]
+    minInstancesPerNode = MapUtil.get(map, key = "minInstancesPerNode").asInstanceOf[String]
+    impurity = MapUtil.get(map, key = "impurity").asInstanceOf[String]
+
+  }
+
+  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
+    var descriptor: List[PropertyDescriptor] = List()
+    val training_data_path = new PropertyDescriptor().name("training_data_path").displayName("TRAINING_DATA_PATH").defaultValue("").required(true)
+    val model_save_path = new PropertyDescriptor().name("model_save_path").displayName("MODEL_SAVE_PATH").description("Path to save the trained model").defaultValue("").required(true)
+    val maxBins = new PropertyDescriptor().name("maxBins").displayName("MAX_BINS").description("Maximum number of bins used for discretizing continuous features. Must be >= 2 and >= number of categories in any categorical feature.").defaultValue("").required(true)
+    val maxDepth = new PropertyDescriptor().name("maxDepth").displayName("MAX_DEPTH").description("Maximum depth of the tree (>= 0). The maximum is 30.").defaultValue("").required(true)
+    val minInfoGain = new PropertyDescriptor().name("minInfoGain").displayName("MIN_INFO_GAIN").description("Minimum information gain for a split to be considered at a tree node.").defaultValue("").required(true)
+    val minInstancesPerNode = new PropertyDescriptor().name("minInstancesPerNode").displayName("MIN_INSTANCES_PER_NODE").description("Minimum number of instances each child must have after split.").defaultValue("").required(true)
+    val impurity = new PropertyDescriptor().name("impurity").displayName("IMPURITY").description("Criterion used for information gain calculation (case-insensitive). Supported: \"entropy\" and \"gini\". (default = gini)").defaultValue("").required(true)
+    descriptor = training_data_path :: descriptor
+    descriptor = model_save_path :: descriptor
+    descriptor = maxBins :: descriptor
+    descriptor = maxDepth :: descriptor
+    descriptor = minInfoGain :: descriptor
+    descriptor = minInstancesPerNode :: descriptor
+    descriptor = impurity :: descriptor
+    descriptor
+  }
+
+  override def getIcon(): Array[Byte] = ???
+
+  override def getGroup(): List[String] = {
+    List(StopGroupEnum.MLGroup.toString)
+  }
+
+}
diff --git a/piflow-bundle/src/test/scala/cn/piflow/bundle/FlowTest_XX.scala b/piflow-bundle/src/test/scala/cn/piflow/bundle/FlowTest_XX.scala
index e3899b2..6a819ec 100644
--- a/piflow-bundle/src/test/scala/cn/piflow/bundle/FlowTest_XX.scala
+++ b/piflow-bundle/src/test/scala/cn/piflow/bundle/FlowTest_XX.scala
@@ -14,7 +14,7 @@ class FlowTest_XX {
   def testFlow(): Unit ={
 
     //parse flow json
-    val file = "src/main/resources/logistic.json"
+    val file = "src/main/resources/decisiontree.json"
     val flowJsonStr = FileUtil.fileReader(file)
     val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
     println(map)
@@ -30,7 +30,7 @@
       .config("spark.driver.memory", "1g")
       .config("spark.executor.memory", "2g")
       .config("spark.cores.max", "2")
-      .config("spark.jars","/root/xx/piflow/out/artifacts/piflow_jar/piflow_jar.jar")
+      .config("spark.jars","/home/xx/piflow/out/artifacts/piflow_jar/piflow_jar.jar")
       .enableHiveSupport()
       .getOrCreate()
 
@@ -49,7 +49,7 @@
   def testFlow2json() = {
 
     //parse flow json
-    val file = "src/main/resources/logistic.json"
+    val file = "src/main/resources/decisiontree.json"
     val flowJsonStr = FileUtil.fileReader(file)
     val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
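
Note: the two stops above are thin wrappers over Spark ML's DecisionTreeClassifier and DecisionTreeClassificationModel. For reference, a minimal standalone sketch of the same train-save-load-predict sequence the flow wires together is below; the object name DecisionTreeRoundTrip is illustrative only, and the HDFS paths and parameter values are simply the ones from decisiontree.json, not API requirements.

import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.sql.SparkSession

object DecisionTreeRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("decision-tree-round-trip")
      .master("local[*]")
      .getOrCreate()

    // Train: libsvm input provides the "label"/"features" columns the classifier
    // expects, exactly as DecisionTreeTraining.perform() loads its training data.
    val train = spark.read.format("libsvm").load("hdfs://10.0.86.89:9000/xx/watermellonDataset.txt")
    val model = new DecisionTreeClassifier()
      .setMaxBins(20)                 // parameter values below are taken from decisiontree.json
      .setMaxDepth(10)
      .setMinInfoGain(0.1)
      .setMinInstancesPerNode(2)
      .setImpurity("entropy")
      .fit(train)

    // Persist, then reload and score, as DecisionTreePrediction.perform() does.
    model.save("hdfs://10.0.86.89:9000/xx/naivebayes/dt.model")
    val loaded = DecisionTreeClassificationModel.load("hdfs://10.0.86.89:9000/xx/naivebayes/dt.model")
    val test = spark.read.format("libsvm").load("hdfs://10.0.86.89:9000/xx/watermellonDataset.txt")
    loaded.transform(test).show()

    spark.stop()
  }
}

Design note: DecisionTreeTraining emits the model path rather than the model object, so the downstream stop reloads the persisted model from HDFS. Spark's default save() fails if the target path already exists; for reruns, model.write.overwrite().save(path) overwrites it.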