From d734a6ec08804845fa8ca0d171d111421503a8b0 Mon Sep 17 00:00:00 2001 From: "bluejoe2008@gmail.com" Date: Tue, 26 Jun 2018 15:40:24 +0800 Subject: [PATCH] script function accepts 1+ arguments --- piflow-core/src/main/scala/lib/etl.scala | 4 ++-- .../src/main/scala/util/scriptengine.scala | 24 ++++++++++--------- piflow-core/src/test/scala/FlowTest.scala | 1 - 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/piflow-core/src/main/scala/lib/etl.scala b/piflow-core/src/main/scala/lib/etl.scala index c1251b1..5f081fc 100644 --- a/piflow-core/src/main/scala/lib/etl.scala +++ b/piflow-core/src/main/scala/lib/etl.scala @@ -52,7 +52,7 @@ class DoMap(func: FunctionLogic, targetSchema: StructType = null) extends Proces } }; - val output = input.map(func.perform(_).asInstanceOf[Row])(encoder); + val output = input.map(x => func.perform(Seq(x)).asInstanceOf[Row])(encoder); out.write(output); } } @@ -74,7 +74,7 @@ class DoFlatMap(func: FunctionLogic, targetSchema: StructType = null) extends Pr }; val output = data.flatMap(x => - JavaConversions.iterableAsScalaIterable(func.perform(x).asInstanceOf[java.util.ArrayList[Row]]))(encoder); + JavaConversions.iterableAsScalaIterable(func.perform(Seq(x)).asInstanceOf[java.util.ArrayList[Row]]))(encoder); out.write(output); } } diff --git a/piflow-core/src/main/scala/util/scriptengine.scala b/piflow-core/src/main/scala/util/scriptengine.scala index 98f563b..53406b8 100644 --- a/piflow-core/src/main/scala/util/scriptengine.scala +++ b/piflow-core/src/main/scala/util/scriptengine.scala @@ -3,20 +3,21 @@ package cn.piflow.util import java.util.{Map => JMap} import javax.script.{Compilable, ScriptEngineManager} +import scala.collection.JavaConversions import scala.collection.JavaConversions._ import scala.collection.immutable.StringOps import scala.collection.mutable.{ArrayBuffer, Map => MMap} trait ScriptEngine { - def compile(funcText: String): CompiledFunction; + def compile(funcText: String, nums: Int = 1): CompiledFunction; } trait CompiledFunction { - def invoke(args: Map[String, Any] = Map[String, Any]()): Any; + def invoke(args: Seq[Any]): Any; } trait FunctionLogic { - def perform(value: Any): Any; + def perform(value: Seq[Any]): Any; } object ScriptEngine { @@ -27,7 +28,7 @@ object ScriptEngine { def logic(script: String, lang: String = ScriptEngine.JAVASCRIPT): FunctionLogic = new FunctionLogic with Serializable { val cached = ArrayBuffer[CompiledFunction](); - override def perform(value: Any): Any = { + override def perform(args: Seq[Any]): Any = { if (cached.isEmpty) { try { val engine = ScriptEngine.get(lang); @@ -35,16 +36,16 @@ object ScriptEngine { } catch { case e: Throwable => - throw new ScriptExecutionException(e, script, value); + throw new ScriptExecutionException(e, script, args); } } try { - cached(0).invoke(Map("value" -> value)); + cached(0).invoke(args); } catch { case e: Throwable => - throw new ScriptExecutionException(e, script, value); + throw new ScriptExecutionException(e, script, args); }; } } @@ -61,15 +62,16 @@ class JavaScriptEngine extends ScriptEngine { map.toMap; } - def compile(funcText: String): CompiledFunction = { - val wrapped = s"($funcText)(value)"; + def compile(funcText: String, nums: Int): CompiledFunction = { + val args = (1 to nums).map("arg" + _).mkString(","); + val wrapped = s"($funcText)($args)"; new CompiledFunction() { val compiled = engine.asInstanceOf[Compilable].compile(wrapped); - def invoke(args: Map[String, Any] = Map[String, Any]()): Any = { + def invoke(args: Seq[Any]): Any = { val bindings = engine.createBindings(); bindings.asInstanceOf[JMap[String, Any]].putAll(tools); - bindings.asInstanceOf[JMap[String, Any]].putAll(args); + bindings.asInstanceOf[JMap[String, Any]].putAll(JavaConversions.mapAsJavaMap(args.zip(1 to args.length).map(x => ("arg" + x._2, x._1)).toMap)); val value = compiled.eval(bindings); value; diff --git a/piflow-core/src/test/scala/FlowTest.scala b/piflow-core/src/test/scala/FlowTest.scala index c388f02..545ae9c 100644 --- a/piflow-core/src/test/scala/FlowTest.scala +++ b/piflow-core/src/test/scala/FlowTest.scala @@ -8,7 +8,6 @@ import cn.piflow.util.ScriptEngine import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.spark.sql.SparkSession import org.junit.Test -import Path._ class FlowTest { @Test