forked from opensci/piflow
add GBT stops
This commit is contained in:
parent
fc83efec28
commit
9b621d4db1
@ -0,0 +1,45 @@
{
  "flow":{
    "name":"test",
    "uuid":"1234",
    "stops":[
      {
        "uuid":"0000",
        "name":"GBTTraining",
        "bundle":"cn.piflow.bundle.ml_classification.GBTTraining",
        "properties":{
          "training_data_path":"hdfs://10.0.86.89:9000/xx/watermellonDataset.txt",
          "model_save_path":"hdfs://10.0.86.89:9000/xx/naivebayes/gbt.model",
          "maxBins":"20",
          "maxDepth":"10",
          "minInfoGain":"0.1",
          "minInstancesPerNode":"2",
          "impurity":"entropy",
          "subSamplingRate":"0.6",
          "stepSize":"0.2",
          "lossType":"logistic"
        }
      },
      {
        "uuid":"1111",
        "name":"GBTPrediction",
        "bundle":"cn.piflow.bundle.ml_classification.GBTPrediction",
        "properties":{
          "test_data_path":"hdfs://10.0.86.89:9000/xx/watermellonDataset.txt",
          "model_path":"hdfs://10.0.86.89:9000/xx/naivebayes/gbt.model"
        }
      }
    ],
    "paths":[
      {
        "from":"GBTTraining",
        "outport":"",
        "inport":"",
        "to":"GBTPrediction"
      }
    ]
  }
}
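For reference, here is a minimal sketch (not part of this commit) of how a flow definition such as gbt.json can be parsed and inspected. It relies only on scala.util.parsing.json, the same parser the test class below uses via JSON.parseFull; the object name InspectGbtFlow and the use of scala.io.Source instead of piflow's FileUtil are illustrative assumptions.

// Illustrative sketch: read gbt.json and list its stops and paths.
import scala.io.Source
import scala.util.parsing.json.JSON

object InspectGbtFlow {
  def main(args: Array[String]): Unit = {
    val flowJsonStr = Source.fromFile("src/main/resources/gbt.json").mkString
    val map = JSON.parseFull(flowJsonStr).get.asInstanceOf[Map[String, Any]]
    val flow = map("flow").asInstanceOf[Map[String, Any]]

    // Each stop names the ConfigurableStop implementation to load ("bundle")
    // and the properties handed to its setProperties() method.
    val stops = flow("stops").asInstanceOf[List[Map[String, Any]]]
    stops.foreach { stop => println(s"${stop("name")} -> ${stop("bundle")}") }

    // Paths wire the stops together: here GBTTraining feeds GBTPrediction.
    val paths = flow("paths").asInstanceOf[List[Map[String, Any]]]
    paths.foreach { p => println(s"${p("from")} --> ${p("to")}") }
  }
}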
@ -0,0 +1,59 @@
package cn.piflow.bundle.ml_classification

import cn.piflow.{JobContext, JobInputStream, JobOutputStream, ProcessContext}
import cn.piflow.conf.{ConfigurableStop, PortEnum, StopGroupEnum}
import cn.piflow.conf.bean.PropertyDescriptor
import cn.piflow.conf.util.MapUtil
import org.apache.spark.ml.classification.GBTClassificationModel
import org.apache.spark.sql.SparkSession

class GBTPrediction extends ConfigurableStop {
  val authorEmail: String = "xiaoxiao@cnic.cn"
  val description: String = "Use an existing GBT model to make predictions."
  val inportList: List[String] = List(PortEnum.NonePort.toString)
  val outportList: List[String] = List(PortEnum.DefaultPort.toString)
  var test_data_path: String = _
  var model_path: String = _

  def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {
    val spark = pec.get[SparkSession]()

    //load the test data stored in libsvm format as a DataFrame
    val data = spark.read.format("libsvm").load(test_data_path)
    //data.show()

    //load the persisted GBT model
    val model = GBTClassificationModel.load(model_path)

    //apply the model and emit the predictions
    val predictions = model.transform(data)
    predictions.show()
    out.write(predictions)
  }

  def initialize(ctx: ProcessContext): Unit = {

  }

  def setProperties(map: Map[String, Any]): Unit = {
    test_data_path = MapUtil.get(map, key = "test_data_path").asInstanceOf[String]
    model_path = MapUtil.get(map, key = "model_path").asInstanceOf[String]
  }

  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
    var descriptor: List[PropertyDescriptor] = List()
    val test_data_path = new PropertyDescriptor().name("test_data_path").displayName("TEST_DATA_PATH").defaultValue("").required(true)
    val model_path = new PropertyDescriptor().name("model_path").displayName("MODEL_PATH").defaultValue("").required(true)
    descriptor = test_data_path :: descriptor
    descriptor = model_path :: descriptor
    descriptor
  }

  override def getIcon(): Array[Byte] = ???

  override def getGroup(): List[String] = {
    List(StopGroupEnum.MLGroup.toString)
  }

}
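A quick, hedged illustration of how this stop receives its configuration: the "properties" object of the GBTPrediction entry in gbt.json arrives in setProperties as a plain Map, and getPropertyDescriptor advertises the same keys to the flow editor. The snippet below only exercises the methods defined above (perform() is not called, so no Spark job runs); the object name is hypothetical.

// Illustrative only: exercise the property plumbing of GBTPrediction.
object GBTPredictionPropertiesDemo {
  def main(args: Array[String]): Unit = {
    val stop = new GBTPrediction()

    // Mirrors the "properties" block of the GBTPrediction stop in gbt.json.
    stop.setProperties(Map(
      "test_data_path" -> "hdfs://10.0.86.89:9000/xx/watermellonDataset.txt",
      "model_path" -> "hdfs://10.0.86.89:9000/xx/naivebayes/gbt.model"
    ))

    // Prints the two required property descriptors declared above.
    stop.getPropertyDescriptor().foreach(println)
  }
}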
@ -0,0 +1,148 @@
package cn.piflow.bundle.ml_classification

import cn.piflow.{JobContext, JobInputStream, JobOutputStream, ProcessContext}
import cn.piflow.conf.{ConfigurableStop, PortEnum, StopGroupEnum}
import cn.piflow.conf.bean.PropertyDescriptor
import cn.piflow.conf.util.MapUtil
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.sql.SparkSession

class GBTTraining extends ConfigurableStop {
  val authorEmail: String = "xiaoxiao@cnic.cn"
  val description: String = "Train a GBT model."
  val inportList: List[String] = List(PortEnum.NonePort.toString)
  val outportList: List[String] = List(PortEnum.DefaultPort.toString)
  var training_data_path: String = _
  var model_save_path: String = _
  var maxBins: String = _
  var maxDepth: String = _
  var minInfoGain: String = _
  var minInstancesPerNode: String = _
  var impurity: String = _
  var subSamplingRate: String = _
  var lossType: String = _
  var stepSize: String = _

  def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {
    val spark = pec.get[SparkSession]()

    //load the training data stored in libsvm format as a DataFrame
    val data = spark.read.format("libsvm").load(training_data_path)

    //Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
    //More bins give higher granularity. Must be >= 2 and >= the number of categories in any categorical feature.
    var maxBinsValue: Int = 40
    if (maxBins != "") {
      maxBinsValue = maxBins.toInt
    }

    //Maximum depth of the tree (>= 0). The maximum supported value is 30.
    var maxDepthValue: Int = 30
    if (maxDepth != "") {
      maxDepthValue = maxDepth.toInt
    }

    //Minimum information gain for a split to be considered at a tree node.
    var minInfoGainValue: Double = 0.2
    if (minInfoGain != "") {
      minInfoGainValue = minInfoGain.toDouble
    }

    //Minimum number of instances each child must have after a split.
    var minInstancesPerNodeValue: Int = 3
    if (minInstancesPerNode != "") {
      minInstancesPerNodeValue = minInstancesPerNode.toInt
    }

    //Criterion used for information gain calculation (case-insensitive). Supported: "entropy" and "gini".
    var impurityValue = "gini"
    if (impurity != "") {
      impurityValue = impurity
    }

    //Fraction of the training data used for learning each decision tree, in range (0, 1].
    var subSamplingRateValue: Double = 0.6
    if (subSamplingRate != "") {
      subSamplingRateValue = subSamplingRate.toDouble
    }

    //Loss function which GBT tries to minimize (case-insensitive). Supported: "logistic".
    var lossTypeValue = "logistic"
    if (lossType != "") {
      lossTypeValue = lossType
    }

    //Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.
    var stepSizeValue: Double = 0.1
    if (stepSize != "") {
      stepSizeValue = stepSize.toDouble
    }

    //train a GBT model
    val model = new GBTClassifier()
      .setMaxBins(maxBinsValue)
      .setMaxDepth(maxDepthValue)
      .setMinInfoGain(minInfoGainValue)
      .setMinInstancesPerNode(minInstancesPerNodeValue)
      .setImpurity(impurityValue)
      .setLossType(lossTypeValue)
      .setSubsamplingRate(subSamplingRateValue)
      .setStepSize(stepSizeValue)
      .fit(data)

    //model persistence
    model.save(model_save_path)

    import spark.implicits._
    val dfOut = Seq(model_save_path).toDF
    dfOut.show()
    out.write(dfOut)
  }

  def initialize(ctx: ProcessContext): Unit = {

  }

  def setProperties(map: Map[String, Any]): Unit = {
    training_data_path = MapUtil.get(map, key = "training_data_path").asInstanceOf[String]
    model_save_path = MapUtil.get(map, key = "model_save_path").asInstanceOf[String]
    maxBins = MapUtil.get(map, key = "maxBins").asInstanceOf[String]
    maxDepth = MapUtil.get(map, key = "maxDepth").asInstanceOf[String]
    minInfoGain = MapUtil.get(map, key = "minInfoGain").asInstanceOf[String]
    minInstancesPerNode = MapUtil.get(map, key = "minInstancesPerNode").asInstanceOf[String]
    impurity = MapUtil.get(map, key = "impurity").asInstanceOf[String]
    subSamplingRate = MapUtil.get(map, key = "subSamplingRate").asInstanceOf[String]
    lossType = MapUtil.get(map, key = "lossType").asInstanceOf[String]
    stepSize = MapUtil.get(map, key = "stepSize").asInstanceOf[String]
  }

  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
    var descriptor: List[PropertyDescriptor] = List()
    val training_data_path = new PropertyDescriptor().name("training_data_path").displayName("TRAINING_DATA_PATH").defaultValue("").required(true)
    val model_save_path = new PropertyDescriptor().name("model_save_path").displayName("MODEL_SAVE_PATH").description("").defaultValue("").required(true)
    val maxBins = new PropertyDescriptor().name("maxBins").displayName("MAX_BINS").description("Maximum number of bins used for discretizing continuous features.").defaultValue("").required(false)
    val maxDepth = new PropertyDescriptor().name("maxDepth").displayName("MAX_DEPTH").description("Maximum depth of the tree (>= 0).").defaultValue("").required(false)
    val minInfoGain = new PropertyDescriptor().name("minInfoGain").displayName("MIN_INFO_GAIN").description("Minimum information gain for a split to be considered at a tree node.").defaultValue("").required(false)
    val minInstancesPerNode = new PropertyDescriptor().name("minInstancesPerNode").displayName("MIN_INSTANCES_PER_NODE").description("Minimum number of instances each child must have after a split.").defaultValue("").required(false)
    val impurity = new PropertyDescriptor().name("impurity").displayName("IMPURITY").description("Criterion used for information gain calculation (case-insensitive). Supported: \"entropy\" and \"gini\". (default = gini)").defaultValue("").required(false)
    val subSamplingRate = new PropertyDescriptor().name("subSamplingRate").displayName("SUB_SAMPLING_RATE").description("Fraction of the training data used for learning each decision tree, in range (0, 1].").defaultValue("").required(false)
    val lossType = new PropertyDescriptor().name("lossType").displayName("LOSS_TYPE").description("Loss function which GBT tries to minimize (case-insensitive). Supported: \"logistic\".").defaultValue("").required(false)
    val stepSize = new PropertyDescriptor().name("stepSize").displayName("STEP_SIZE").description("Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.").defaultValue("").required(false)
    descriptor = training_data_path :: descriptor
    descriptor = model_save_path :: descriptor
    descriptor = maxBins :: descriptor
    descriptor = maxDepth :: descriptor
    descriptor = minInfoGain :: descriptor
    descriptor = minInstancesPerNode :: descriptor
    descriptor = impurity :: descriptor
    descriptor = subSamplingRate :: descriptor
    descriptor = lossType :: descriptor
    descriptor = stepSize :: descriptor
    descriptor
  }

  override def getIcon(): Array[Byte] = ???

  override def getGroup(): List[String] = {
    List(StopGroupEnum.MLGroup.toString)
  }

}
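Outside of piflow, the pair of stops above corresponds to plain Spark ML steps: fit a GBTClassifier on libsvm data, persist the model, reload it, and score a dataset. A minimal sketch follows, assuming a local SparkSession; the file paths and the train/test split are hypothetical and only the parameters already exposed through gbt.json are set.

// Illustrative sketch of the training + prediction logic the two stops wrap.
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.sql.SparkSession

object GbtTrainPredictSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("gbt-sketch").getOrCreate()

    // Hypothetical local paths; replace with real data and an output location.
    val data = spark.read.format("libsvm").load("data/watermellonDataset.txt")
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 42L)

    // Same parameters the GBTTraining stop reads from gbt.json.
    val model = new GBTClassifier()
      .setMaxBins(20)
      .setMaxDepth(10)
      .setMinInfoGain(0.1)
      .setMinInstancesPerNode(2)
      .setSubsamplingRate(0.6)
      .setStepSize(0.2)
      .setLossType("logistic")
      .fit(train)

    // What GBTTraining does at the end: persist the fitted model.
    model.write.overwrite().save("/tmp/gbt.model")

    // What GBTPrediction does: reload the persisted model and score new rows.
    val reloaded = GBTClassificationModel.load("/tmp/gbt.model")
    reloaded.transform(test).select("label", "prediction").show()

    spark.stop()
  }
}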
@ -14,7 +14,7 @@ class FlowTest_XX {
   def testFlow(): Unit ={

     //parse flow json
-    val file = "src/main/resources/randomforest.json"
+    val file = "src/main/resources/gbt.json"
     val flowJsonStr = FileUtil.fileReader(file)
     val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]
     println(map)

@ -49,7 +49,7 @@ class FlowTest_XX {
   def testFlow2json() = {

     //parse flow json
-    val file = "src/main/resources/randomforest.json"
+    val file = "src/main/resources/gbt.json"
     val flowJsonStr = FileUtil.fileReader(file)
     val map = OptionUtil.getAny(JSON.parseFull(flowJsonStr)).asInstanceOf[Map[String, Any]]