forked from opensci/piflow
script function accepts 1+ arguments
This commit is contained in:
parent
877b45c81b
commit
d734a6ec08
|
@ -52,7 +52,7 @@ class DoMap(func: FunctionLogic, targetSchema: StructType = null) extends Proces
|
|||
}
|
||||
};
|
||||
|
||||
val output = input.map(func.perform(_).asInstanceOf[Row])(encoder);
|
||||
val output = input.map(x => func.perform(Seq(x)).asInstanceOf[Row])(encoder);
|
||||
out.write(output);
|
||||
}
|
||||
}
|
||||
|
@ -74,7 +74,7 @@ class DoFlatMap(func: FunctionLogic, targetSchema: StructType = null) extends Pr
|
|||
};
|
||||
|
||||
val output = data.flatMap(x =>
|
||||
JavaConversions.iterableAsScalaIterable(func.perform(x).asInstanceOf[java.util.ArrayList[Row]]))(encoder);
|
||||
JavaConversions.iterableAsScalaIterable(func.perform(Seq(x)).asInstanceOf[java.util.ArrayList[Row]]))(encoder);
|
||||
out.write(output);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,20 +3,21 @@ package cn.piflow.util
|
|||
import java.util.{Map => JMap}
|
||||
import javax.script.{Compilable, ScriptEngineManager}
|
||||
|
||||
import scala.collection.JavaConversions
|
||||
import scala.collection.JavaConversions._
|
||||
import scala.collection.immutable.StringOps
|
||||
import scala.collection.mutable.{ArrayBuffer, Map => MMap}
|
||||
|
||||
trait ScriptEngine {
|
||||
def compile(funcText: String): CompiledFunction;
|
||||
def compile(funcText: String, nums: Int = 1): CompiledFunction;
|
||||
}
|
||||
|
||||
trait CompiledFunction {
|
||||
def invoke(args: Map[String, Any] = Map[String, Any]()): Any;
|
||||
def invoke(args: Seq[Any]): Any;
|
||||
}
|
||||
|
||||
trait FunctionLogic {
|
||||
def perform(value: Any): Any;
|
||||
def perform(value: Seq[Any]): Any;
|
||||
}
|
||||
|
||||
object ScriptEngine {
|
||||
|
@ -27,7 +28,7 @@ object ScriptEngine {
|
|||
def logic(script: String, lang: String = ScriptEngine.JAVASCRIPT): FunctionLogic = new FunctionLogic with Serializable {
|
||||
val cached = ArrayBuffer[CompiledFunction]();
|
||||
|
||||
override def perform(value: Any): Any = {
|
||||
override def perform(args: Seq[Any]): Any = {
|
||||
if (cached.isEmpty) {
|
||||
try {
|
||||
val engine = ScriptEngine.get(lang);
|
||||
|
@ -35,16 +36,16 @@ object ScriptEngine {
|
|||
}
|
||||
catch {
|
||||
case e: Throwable =>
|
||||
throw new ScriptExecutionException(e, script, value);
|
||||
throw new ScriptExecutionException(e, script, args);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
cached(0).invoke(Map("value" -> value));
|
||||
cached(0).invoke(args);
|
||||
}
|
||||
catch {
|
||||
case e: Throwable =>
|
||||
throw new ScriptExecutionException(e, script, value);
|
||||
throw new ScriptExecutionException(e, script, args);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -61,15 +62,16 @@ class JavaScriptEngine extends ScriptEngine {
|
|||
map.toMap;
|
||||
}
|
||||
|
||||
def compile(funcText: String): CompiledFunction = {
|
||||
val wrapped = s"($funcText)(value)";
|
||||
def compile(funcText: String, nums: Int): CompiledFunction = {
|
||||
val args = (1 to nums).map("arg" + _).mkString(",");
|
||||
val wrapped = s"($funcText)($args)";
|
||||
new CompiledFunction() {
|
||||
val compiled = engine.asInstanceOf[Compilable].compile(wrapped);
|
||||
|
||||
def invoke(args: Map[String, Any] = Map[String, Any]()): Any = {
|
||||
def invoke(args: Seq[Any]): Any = {
|
||||
val bindings = engine.createBindings();
|
||||
bindings.asInstanceOf[JMap[String, Any]].putAll(tools);
|
||||
bindings.asInstanceOf[JMap[String, Any]].putAll(args);
|
||||
bindings.asInstanceOf[JMap[String, Any]].putAll(JavaConversions.mapAsJavaMap(args.zip(1 to args.length).map(x => ("arg" + x._2, x._1)).toMap));
|
||||
|
||||
val value = compiled.eval(bindings);
|
||||
value;
|
||||
|
|
|
@ -8,7 +8,6 @@ import cn.piflow.util.ScriptEngine
|
|||
import org.apache.commons.io.{FileUtils, IOUtils}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.junit.Test
|
||||
import Path._
|
||||
|
||||
class FlowTest {
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue