diff --git a/src/main/scala/io/common.scala b/src/main/scala/io/common.scala
index 46b48d9..36f7ccb 100644
--- a/src/main/scala/io/common.scala
+++ b/src/main/scala/io/common.scala
@@ -1,15 +1,15 @@
 package cn.piflow.io

-import cn.piflow.{DataSink, ProcessExecutionContext, _}
+import cn.piflow.{ProcessExecutionContext, Sink, _}
 import org.apache.spark.sql._

-case class Console(nlimit: Int = 20) extends DataSink {
+case class Console(nlimit: Int = 20) extends Sink {
   override def save(data: DataFrame, ctx: ProcessExecutionContext): Unit = {
     data.show(nlimit);
   }
 }

-case class TextFile(path: String, format: String = FileFormat.TEXT) extends DataSource with DataSink {
+case class TextFile(path: String, format: String = FileFormat.TEXT) extends Source with Sink {
   override def load(ctx: ProcessExecutionContext): DataFrame = {
     ctx.get[SparkSession].read.format(format).load(path).asInstanceOf[DataFrame];
   }
@@ -22,4 +22,5 @@ case class TextFile(path: String, format: String = FileFormat.TEXT) extends Data
 object FileFormat {
   val TEXT = "text";
   val JSON = "json";
+  val PARQUET = "parquet";
 }
\ No newline at end of file
diff --git a/src/main/scala/process.scala b/src/main/scala/process.scala
index 8df2c68..ad1f640 100644
--- a/src/main/scala/process.scala
+++ b/src/main/scala/process.scala
@@ -11,8 +11,6 @@ trait Process {
   def onCommit(pec: ProcessExecutionContext): Unit;

   def onRollback(pec: ProcessExecutionContext): Unit;
-
-  def onFail(errorStage: ProcessStage, cause: Throwable, pec: ProcessExecutionContext): Unit;
 }

 abstract class LazyProcess extends Process with Logging {
@@ -25,8 +23,6 @@ abstract class LazyProcess extends Process with Logging {
   def onRollback(pec: ProcessExecutionContext): Unit = {
     logger.warn(s"onRollback={}, process: $this");
   }
-
-  def onFail(errorStage: ProcessStage, cause: Throwable, pec: ProcessExecutionContext): Unit = {}
 }

 //TODO: one ProcessExecution with multiple RUNs
@@ -38,10 +34,6 @@ trait ProcessExecution {
   def getProcessName(): String;

   def getProcess(): Process;
-
-  def getStage(): ProcessStage;
-
-  def handleError(jee: JobExecutionException): Unit;
 }

 trait ProcessExecutionContext extends Context {
@@ -49,17 +41,21 @@ trait ProcessExecutionContext extends Context {

   def setStage(stage: ProcessStage): Unit;

-  def getStage(): ProcessStage;
+  def sendError(stage: ProcessStage, cause: Throwable): Unit;

-  def setErrorHandler(handler: ErrorHandler): Unit;
+  def getStage(): ProcessStage;
 }

 class ProcessExecutionContextImpl(processExecution: ProcessExecution, executionContext: FlowExecutionContext)
-  extends ProcessExecutionContext {
+  extends ProcessExecutionContext with Logging {
   val stages = ArrayBuffer[ProcessStage]();
   var errorHandler: ErrorHandler = Noop();

-  def setStage(stage: ProcessStage) = stages += stage;
+  def setStage(stage: ProcessStage) = {
+    val processName = processExecution.getProcessName();
+    logger.debug(s"stage changed: $stage, process: $processName");
+    stages += stage
+  };

   val context = MMap[String, Any]();
@@ -67,6 +63,16 @@ class ProcessExecutionContextImpl(processExecution: ProcessExecution, executionC

   def getStage(): ProcessStage = stages.last;

+  def sendError(stage: ProcessStage, cause: Throwable) {
+    val processName = processExecution.getProcessName();
+    val jee = new JobExecutionException(s"failed to execute process: $processName", cause);
+    logger.error {
+      s"failed to execute process: $processName, stage: $stage, cause: $cause"
+    };
+    errorHandler.handle(jee);
+    throw jee;
+  }
+
   override def get(key: String): Any = {
     if (context.contains(key)) context(key);
@@ -78,8 +84,6 @@ class ProcessExecutionContextImpl(processExecution: ProcessExecution, executionC
     context(key) = value;
     this;
   };
-
-  override def setErrorHandler(handler: ErrorHandler): Unit = errorHandler = handler;
 }

 class ProcessAsQuartzJob extends Job with Logging {
@@ -89,21 +93,9 @@ class ProcessAsQuartzJob extends Job with Logging {

     val executionContext = context.getScheduler.getContext.get("executionContext").asInstanceOf[FlowExecutionContext];
     val pe = executionContext.runProcess(processName);
-    try {
-      pe.start();
-      context.setResult(true);
-    }
-    catch {
-      case e => {
-        val jee = new JobExecutionException(s"failed to execute process: $processName", e);
-        logger.error {
-          val stage = pe.getStage();
-          s"failed to execute process: $processName, stage: $stage, cause: $e"
-        };
-        pe.handleError(jee);
-        throw jee;
-      };
-    }
+
+    pe.start();
+    context.setResult(true);
   }
 }
@@ -116,6 +108,7 @@ class ProcessExecutionImpl(processName: String, process: Process, executionConte

   override def start(): Unit = {
     try {
+      //prepare()
       processExecutionContext.setStage(PrepareStart());
       process.onPrepare(processExecutionContext);
       processExecutionContext.setStage(PrepareComplete());
@@ -123,22 +116,23 @@ class ProcessExecutionImpl(processName: String, process: Process, executionConte
     catch {
       case e =>
         try {
+          //rollback()
           logger.warn(s"onPrepare() failed: $e");
           processExecutionContext.setStage(RollbackStart());
           process.onRollback(processExecutionContext);
           processExecutionContext.setStage(RollbackComplete());
-
-          throw e;
         }
         catch {
           case e =>
             logger.warn(s"onRollback() failed: $e");
-            process.onFail(RollbackStart(), e, processExecutionContext);
+            processExecutionContext.sendError(RollbackStart(), e);
+            e.printStackTrace();
             throw e;
         }
     }

     try {
+      //commit()
       processExecutionContext.setStage(CommitStart());
       process.onCommit(processExecutionContext);
       processExecutionContext.setStage(CommitComplete());
@@ -146,7 +140,8 @@ class ProcessExecutionImpl(processName: String, process: Process, executionConte
     catch {
       case e =>
         logger.warn(s"onCommit() failed: $e");
-        process.onFail(CommitStart(), e, processExecutionContext);
+        processExecutionContext.sendError(CommitStart(), e);
+        e.printStackTrace();
         throw e;
     }
   }
@@ -157,10 +152,6 @@ class ProcessExecutionImpl(processName: String, process: Process, executionConte
   override def getProcessName(): String = processName;

   override def getProcess(): Process = process;
-
-  override def handleError(jee: JobExecutionException): Unit = processExecutionContext.errorHandler.handle(jee);
-
-  override def getStage(): ProcessStage = processExecutionContext.getStage();
 }

 trait ErrorHandler {
diff --git a/src/main/scala/sparkprocess.scala b/src/main/scala/sparkprocess.scala
index 783a707..adce8ea 100644
--- a/src/main/scala/sparkprocess.scala
+++ b/src/main/scala/sparkprocess.scala
@@ -3,6 +3,9 @@
  */
 package cn.piflow

+import java.io.File
+
+import cn.piflow.io.{Console, FileFormat, TextFile}
 import cn.piflow.util.{IdGenerator, Logging}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
@@ -12,22 +15,72 @@ import scala.collection.JavaConversions
 import scala.collection.mutable.{ArrayBuffer, Map => MMap}

 class SparkProcess extends Process with Logging {
-  val ends = ArrayBuffer[(ProcessExecutionContext) => Unit]();
+
+  trait Ops {
+    def perform(ctx: ProcessExecutionContext): Unit;
+  }
+
+  val ends = ArrayBuffer[Ops]();
+
+  trait Backup {
+    def replica(): Sink;
+
+    def restore(): Unit;
+
+    def clean(): Unit;
+  }
+
+  def createBackup(originalSink: Sink, ctx: ProcessExecutionContext): Backup = {
+    if (originalSink.isInstanceOf[Console]) {
+      new Backup() {
+        override def replica(): Sink = originalSink;
+
+        override def clean(): Unit = {}
+
+        override def restore(): Unit = {}
+      }
+    }
+    else {
+      new Backup() {
+        val backupFile = File.createTempFile(classOf[SparkProcess].getName, ".bak",
+          new File(ctx.get("localBackupDir").asInstanceOf[String]));
+        backupFile.delete();
+
+        def replica(): Sink = {
+          //TODO: hdfs
+          TextFile(backupFile.getAbsolutePath, FileFormat.PARQUET)
+        }
+
+        def restore(): Unit = {
+          originalSink.save(TextFile(backupFile.getAbsolutePath, FileFormat.PARQUET).load(ctx), ctx);
+        }
+
+        def clean(): Unit = {
+          backupFile.delete();
+        }
+      }
+    }
+  }

   def onPrepare(pec: ProcessExecutionContext) = {
-    ends.foreach(_.apply(pec));
+    val backup = ArrayBuffer[Backup]();
+    val ne = ends.map { x =>
+      val so = x.asInstanceOf[SaveOps];
+      val bu = createBackup(so.streamSink, pec);
+      backup += bu;
+      SaveOps(bu.replica(), so.stream);
+    }
+
+    pec.put("backup", backup);
+    ne.foreach(_.perform(pec));
   }

   override def onCommit(pec: ProcessExecutionContext): Unit = {
-
+    pec.get("backup").asInstanceOf[ArrayBuffer[Backup]].foreach(_.restore());
   }

   override def onRollback(pec: ProcessExecutionContext): Unit = {
-
-  }
-
-  override def onFail(errorStage: ProcessStage, cause: Throwable, pec: ProcessExecutionContext): Unit = {
-
+    pec.get("backup").asInstanceOf[ArrayBuffer[Backup]].foreach(_.clean());
   }

   abstract class CachedStream extends Stream {
@@ -51,7 +104,7 @@ class SparkProcess extends Process with Logging {
     }
   }

-  def loadStream(streamSource: DataSource): Stream = {
+  def loadStream(streamSource: Source): Stream = {
     return new CachedStream() {
       override def produce(ctx: ProcessExecutionContext): DataFrame = {
         logger.debug {
@@ -64,26 +117,30 @@ class SparkProcess extends Process with Logging {
     }
   }

-  def writeStream(streamSink: DataSink, stream: Stream): Unit = {
-    ends += {
-      (ctx: ProcessExecutionContext) => {
-        val input = stream.feed(ctx);
-        logger.debug {
-          val schema = input.schema;
-          val iid = stream.getId();
-          s"saving stream[$iid->_], schema: $schema, sink: $streamSink";
-        };
-        streamSink.save(input, ctx);
-      }
-    };
+  case class SaveOps(streamSink: Sink, stream: Stream)
+    extends Ops {
+    def perform(ctx: ProcessExecutionContext): Unit = {
+      val input = stream.feed(ctx);
+      logger.debug {
+        val schema = input.schema;
+        val iid = stream.getId();
+        s"saving stream[$iid->_], schema: $schema, sink: $streamSink";
+      };
+
+      streamSink.save(input, ctx);
+    }
   }

-  def transform(transformer: DataTransformer, streams: Stream*): Stream = {
+  def writeStream(streamSink: Sink, stream: Stream): Unit = {
+    ends += SaveOps(streamSink, stream);
+  }
+
+  def transform(transformer: Transformer, streams: Stream*): Stream = {
     transform(transformer, streams.zipWithIndex.map(x => ("" + x._2, x._1)).toMap);
   }

-  def transform(transformer: DataTransformer, streams: Map[String, Stream]): Stream = {
+  def transform(transformer: Transformer, streams: Map[String, Stream]): Stream = {
     return new CachedStream() {
       override def produce(ctx: ProcessExecutionContext): DataFrame = {
         val inputs = streams.map(x => (x._1, x._2.feed(ctx)));
@@ -109,15 +166,15 @@ trait Stream {

 }

-trait DataSource {
+trait Source {
   def load(ctx: ProcessExecutionContext): DataFrame;
 }

-trait DataTransformer {
+trait Transformer {
   def transform(data: Map[String, DataFrame], ctx: ProcessExecutionContext): DataFrame;
 }

-trait DataTransformer1N1 extends DataTransformer {
+trait Transformer1N1 extends Transformer {
   def transform(data: DataFrame, ctx: ProcessExecutionContext): DataFrame;

   def transform(dataset: Map[String, DataFrame], ctx: ProcessExecutionContext): DataFrame = {
@@ -126,7 +183,7 @@ trait DataTransformer1N1 extends DataTransformer {
   }
 }

-trait DataSink {
+trait Sink {
   def save(data: DataFrame, ctx: ProcessExecutionContext): Unit;
 }

@@ -134,7 +191,7 @@ trait FunctionLogic {
   def call(value: Any): Any;
 }

-case class DoMap(func: FunctionLogic, targetSchema: StructType = null) extends DataTransformer1N1 {
+case class DoMap(func: FunctionLogic, targetSchema: StructType = null) extends Transformer1N1 {
   def transform(data: DataFrame, ctx: ProcessExecutionContext): DataFrame = {
     val encoder = RowEncoder {
       if (targetSchema == null) {
@@ -149,7 +206,7 @@ case class DoMap(func: FunctionLogic, targetSchema: StructType = null) extends D
   }
 }

-case class DoFlatMap(func: FunctionLogic, targetSchema: StructType = null) extends DataTransformer1N1 {
+case class DoFlatMap(func: FunctionLogic, targetSchema: StructType = null) extends Transformer1N1 {
   def transform(data: DataFrame, ctx: ProcessExecutionContext): DataFrame = {
     val encoder = RowEncoder {
       if (targetSchema == null) {
@@ -165,7 +222,7 @@ case class DoFlatMap(func: FunctionLogic, targetSchema: StructType = null) exten
   }
 }

-case class ExecuteSQL(sql: String) extends DataTransformer with Logging {
+case class ExecuteSQL(sql: String) extends Transformer with Logging {
   def transform(dataset: Map[String, DataFrame], ctx: ProcessExecutionContext): DataFrame = {

     dataset.foreach { x =>
diff --git a/src/test/scala/FlowTest.scala b/src/test/scala/FlowTest.scala
index 6bcedf0..6fd5939 100644
--- a/src/test/scala/FlowTest.scala
+++ b/src/test/scala/FlowTest.scala
@@ -13,7 +13,6 @@ class FlowTest {
     processes.foreach(en => flow.addProcess(en._1, en._2));
     flow.addProcess("PrintMessage", new PrintMessage());

-    flow.addTrigger("CopyTextFile", new DependencyTrigger("CleanHouse"));
     flow.addTrigger("CountWords", new DependencyTrigger("CopyTextFile"));
     flow.addTrigger("PrintCount", new DependencyTrigger("CountWords"));

@@ -21,10 +20,14 @@ class FlowTest {
     val spark = SparkSession.builder.master("local[4]")
       .getOrCreate();
-    val exe = Runner.run(flow, Map(classOf[SparkSession].getName -> spark));
+
+    val exe = Runner.run(flow, Map(
+      "localBackupDir" -> "/tmp/",
+      classOf[SparkSession].getName -> spark
+    ));
     exe.start("CleanHouse");
-    Thread.sleep(20000);
+    Thread.sleep(30000);
     exe.stop();
   }

@@ -47,8 +50,6 @@ class FlowTest {

         override def onRollback(pec: ProcessExecutionContext): Unit = ???

-        override def onFail(errorStage: ProcessStage, cause: Throwable, pec: ProcessExecutionContext): Unit = ???
-
         override def onCommit(pec: ProcessExecutionContext): Unit = ???
       },
       "CountWords" -> new CountWords(),
@@ -117,7 +118,8 @@ class CountWords extends LazyProcess {
       .flatMap(s => s.zip(s.drop(1)).map(t => "" + t._1 + t._2))
       .groupBy("value").count.sort($"count".desc);

-    val tmpfile = File.createTempFile(this.getClass.getSimpleName, "");
+    val tmpfile = File.createTempFile(this.getClass.getName + "-", "");
+    tmpfile.delete();
     pec.put("tmpfile", tmpfile);
     count.write.json(tmpfile.getAbsolutePath);
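
For reviewers trying out the renamed traits, a minimal sketch of a user-defined sink against the new `Sink` API is shown below. The `JsonFile` name and the JSON output format are illustrative only and not part of this patch; the sketch assumes only the `Sink` and `ProcessExecutionContext` signatures introduced above plus standard Spark `DataFrameWriter` calls.

```scala
package cn.piflow.io

import cn.piflow.{ProcessExecutionContext, Sink}
import org.apache.spark.sql._

// Hypothetical example sink (not in this patch): writes the incoming
// DataFrame as JSON files under `path`, overwriting any previous output.
case class JsonFile(path: String) extends Sink {
  override def save(data: DataFrame, ctx: ProcessExecutionContext): Unit = {
    data.write.mode(SaveMode.Overwrite).json(path);
  }
}
```

Note that flows driven by `SparkProcess` now need a `localBackupDir` entry in the runner context (as added to `FlowTest`), because `createBackup` stages every non-`Console` sink into a temporary Parquet `TextFile` during `onPrepare`, replays it into the original sink only in `onCommit`, and `onRollback` simply deletes the staged file.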