diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..fa1f1ca --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,41 @@ +version: 2 + +references: + environment: &environment + docker: + - image: circleci/openjdk:8-jdk + working_directory: ~/repo + environment: + # Customize the JVM maximum heap limit + JVM_OPTS: -Xmx3200m + TERM: dumb + +jobs: + spotless: + <<: *environment + steps: + - checkout + - run: ./gradlew spotlessCheck + + build: + <<: *environment + steps: + - checkout + - run: ./gradlew publish + + ## Require an AWS account, and take some money to launch EMR, so it doesn't work now. + run_example: + <<: *environment + steps: + - checkout + - run: ./example/run.sh + +workflows: + version: 2 + + merge-before: + jobs: + - build + - spotless + + diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 0000000..14757b9 --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,9 @@ +# https://scalameta.org/scalafmt/#Configuration + +style = IntelliJ +maxColumn = 160 +align = none +newlines.penalizeSingleSelectMultiArgList = false +newlines.alwaysBeforeElseAfterCurlyIf = true +newlines.alwaysBeforeTopLevelStatements = true + diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..19e55b1 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,10 @@ +0.0.2 (2018-10-29) +================== + +* [Experimental] Implement ecs_task.py> operator. (No document yet) +* [Fix] Stop correctly after task run to shutdown TransferManager after processing. + +0.0.1 (2018-10-23) +================== + +* First Release diff --git a/LICENSE b/LICENSE.txt similarity index 100% rename from LICENSE rename to LICENSE.txt diff --git a/README.md b/README.md index 1036c21..0bfb174 100644 --- a/README.md +++ b/README.md @@ -1,70 +1,169 @@ -# digdag-plugin-example -[![Jitpack](https://jitpack.io/v/myui/digdag-plugin-example.svg)](https://jitpack.io/#myui/digdag-plugin-example) [![Digdag](https://img.shields.io/badge/digdag-v0.9.12-brightgreen.svg)](https://github.com/treasure-data/digdag/releases/tag/v0.9.12) +# digdag-operator-ecs_task +[![Jitpack](https://jitpack.io/v/pro.civitaspo/digdag-operator-ecs_task.svg)](https://jitpack.io/#pro.civitaspo/digdag-operator-ecs_task) [![CircleCI](https://circleci.com/gh/civitaspo/digdag-operator-ecs_task.svg?style=shield)](https://circleci.com/gh/civitaspo/digdag-operator-ecs_task) [![Digdag](https://img.shields.io/badge/digdag-v0.9.31-brightgreen.svg)](https://github.com/treasure-data/digdag/releases/tag/v0.9.31) -# 1) build +digdag plugin for AWS ECS Task. -```sh -./gradlew publish -``` +# Overview -Artifacts are build on local repos: `./build/repo`. +- Plugin type: operator -# 2) run an example +# Usage -```sh -digdag selfupdate +```yaml +_export: + plugin: + repositories: + - https://jitpack.io + dependencies: + - pro.civitaspo:digdag-operator-ecs_task:0.0.2 + ecs_task: + auth_method: profile + ++step0: + sh>: echo '{"store_params":{"civi":"taspo"}}' | aws s3 cp - ${output} + ++step1: + ecs_task.run>: + def: + network_mode: Host + container_definitions: + - name: uploader + image: amazonlinux:2 + command: [yum, install, '-y', awscli] + essential: true + memory: 500 + cpu: 10 + family: hello_world + cluster: ${cluster} + count: 1 + result_s3_uri: ${output} + ++step2: + echo>: ${civi} -digdag run --project sample plugin.dig -p repos=`pwd`/build/repo ``` -You'll find the result of the task in `./sample/example.out`. 
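In the usage example above, `+step0` uploads a JSON document whose `store_params` entry is what makes `${civi}` available to `+step2`: after the ECS task finishes, the `ecs_task.result>` step reads the document at `result_s3_uri` and turns it into digdag params. As a rough, illustrative sketch of that conversion (the actual logic lives in the Scala sources further down this diff, e.g. `EcsTaskCommandResultInternalOperator`; the object and method names below are made up for this sketch):

```scala
// Illustrative sketch only: converts a result document such as
//   {"store_params": {"civi": "taspo"}}
// into a digdag TaskResult. The real operators also download the document
// from S3 and report errors; this only shows the parameter mapping.
import io.digdag.client.config.{Config, ConfigFactory}
import io.digdag.spi.TaskResult

object ResultSketch {
  def toTaskResult(cf: ConfigFactory, resultJson: String): TaskResult = {
    val out: Config = cf.fromJsonString(resultJson) // parse the uploaded JSON document
    TaskResult
      .defaultBuilder(cf)
      .subtaskConfig(out.getNestedOrGetEmpty("subtask_config")) // dynamically generated subtasks
      .exportParams(out.getNestedOrGetEmpty("export_params"))   // params exported to following tasks
      .storeParams(out.getNestedOrGetEmpty("store_params"))     // params stored in the session, e.g. ${civi}
      .build()
  }
}
```

In the example only `store_params` is set, which is why the following step can simply `echo>: ${civi}`.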
+# Configuration

----
-# Writing your own plugin

+## Remarks

-1. You need to implement [a Plugin class](https://github.com/myui/digdag-plugin-example/blob/master/src/main/java/io/digdag/plugin/example/ExamplePlugin.java) that implements `io.digdag.spi.Plugin`.
-2. Then, list it on [io.digdag.spi.Plugin](https://github.com/myui/digdag-plugin-example/blob/master/src/main/resources/META-INF/services/io.digdag.spi.Plugin). The listed plugins are loaded by Digdag.

+- The type `DurationParam` is a string matching `\s*(?:(?<days>\d+)\s*d)?\s*(?:(?<hours>\d+)\s*h)?\s*(?:(?<minutes>\d+)\s*m)?\s*(?:(?<seconds>\d+)\s*s)?\s*`.
+  - The string is used as a `java.time.Duration`.
+
+## Common Configuration

-You can optionally create Eclipse/Idea project files as follows:
-```sh
-gradle eclipse
-gradle idea
-```

+### System Options
+
+Define the below options in the properties file (which is indicated by `-c`, `--config`).
+
+- **ecs_task.allow_auth_method_env**: Indicates whether users can use **auth_method** `"env"` (boolean, default: `false`)
+- **ecs_task.allow_auth_method_instance**: Indicates whether users can use **auth_method** `"instance"` (boolean, default: `false`)
+- **ecs_task.allow_auth_method_profile**: Indicates whether users can use **auth_method** `"profile"` (boolean, default: `false`)
+- **ecs_task.allow_auth_method_properties**: Indicates whether users can use **auth_method** `"properties"` (boolean, default: `false`)
+- **ecs_task.assume_role_timeout_duration**: The maximum duration the server administrator allows when users assume **role_arn**. (`DurationParam`, default: `1h`)
+
+### Secrets
+
+- **ecs_task.access_key_id**: The AWS Access Key ID (optional)
+- **ecs_task.secret_access_key**: The AWS Secret Access Key (optional)
+- **ecs_task.session_token**: The AWS session token. This is used only when **auth_method** is `"session"` (optional)
+- **ecs_task.role_arn**: The AWS Role to assume. (optional)
+- **ecs_task.role_session_name**: The AWS Role Session Name when assuming the role. (default: `digdag-ecs_task-${session_uuid}`)
+- **ecs_task.http_proxy.host**: proxy host (required if **use_http_proxy** is `true`)
+- **ecs_task.http_proxy.port**: proxy port (optional)
+- **ecs_task.http_proxy.scheme**: `"https"` or `"http"` (default: `"https"`)
+- **ecs_task.http_proxy.user**: proxy user (optional)
+- **ecs_task.http_proxy.password**: proxy password (optional)
+
+### Options
+
+- **auth_method**: name of the mechanism used to authenticate requests (`"basic"`, `"env"`, `"instance"`, `"profile"`, `"properties"`, `"anonymous"`, or `"session"`. default: `"basic"`)
+  - `"basic"`: uses **access_key_id** and **secret_access_key** to authenticate.
+  - `"env"`: uses the `AWS_ACCESS_KEY_ID` (or `AWS_ACCESS_KEY`) and `AWS_SECRET_KEY` (or `AWS_SECRET_ACCESS_KEY`) environment variables.
+  - `"instance"`: uses the EC2 instance profile.
+  - `"profile"`: uses credentials written in a file. The format of the file is as follows, where `[...]` is the name of a profile.
+    - **profile_file**: path to the profiles file. (string, default: given by the `AWS_CREDENTIAL_PROFILES_FILE` environment variable, or `~/.aws/credentials`)
+    - **profile_name**: name of a profile. (string, default: `"default"`)
+  - `"properties"`: uses the `aws.accessKeyId` and `aws.secretKey` Java system properties.
+  - `"anonymous"`: uses anonymous access. This auth method can access only public files.
+  - `"session"`: uses temporarily generated **access_key_id**, **secret_access_key** and **session_token**.
+- **use_http_proxy**: Indicates whether to use an HTTP proxy when accessing AWS. (boolean, default: `false`)
+- **region**: The AWS region. (string, optional)
+- **endpoint**: The AWS Service endpoint. (string, optional)
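The `auth_method` values above map onto the standard AWS SDK credentials providers. Below is a condensed, illustrative sketch of that mapping, assuming the AWS SDK classes already on this plugin's classpath; the full version, including the `ecs_task.allow_auth_method_*` checks and the `role_arn` assume-role handling, is in `Aws.scala` further down this diff:

```scala
// Condensed, illustrative mapping from auth_method to an AWSCredentialsProvider.
// See pro.civitaspo.digdag.plugin.ecs_task.aws.Aws below for the complete logic.
import com.amazonaws.auth._
import com.amazonaws.auth.profile.ProfileCredentialsProvider

object AuthMethodSketch {
  def credentialsProviderFor(
      authMethod: String,
      accessKeyId: String,
      secretAccessKey: String,
      sessionToken: String,
      profileName: String
  ): AWSCredentialsProvider = authMethod match {
    case "basic"      => new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKeyId, secretAccessKey))
    case "env"        => new EnvironmentVariableCredentialsProvider
    case "instance"   => new EC2ContainerCredentialsProviderWrapper // instance profile or ECS container credentials
    case "profile"    => new ProfileCredentialsProvider(profileName)
    case "properties" => new SystemPropertiesCredentialsProvider
    case "anonymous"  => new AWSStaticCredentialsProvider(new AnonymousAWSCredentials)
    case "session"    => new AWSStaticCredentialsProvider(new BasicSessionCredentials(accessKeyId, secretAccessKey, sessionToken))
    case other        => throw new IllegalArgumentException(s"auth_method: $other is not supported")
  }
}
```

When `role_arn` is also set, the provider chosen here is only used to call STS `AssumeRole`; the temporary session credentials returned by STS are what the ECS and S3 clients actually use.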
+## Configuration for `ecs_task.register>` operator
+
+- **ecs_task.register>**: The configuration is the same as the snake-cased [RegisterTaskDefinition API](https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RegisterTaskDefinition.html). (map, required)
+
+## Configuration for `ecs_task.run>` operator
+
+The configuration is the same as the snake-cased [RunTask API](https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html).
+
+In addition, the below configurations exist.

-*Note:* _It's better to change the dependencies from `provided` to `compile` in [build.gradle](https://github.com/myui/digdag-plugin-example/blob/master/build.gradle) for creating idea/eclipse project config._

+- **def**: The definition for the task. The configuration is the same as that of `ecs_task.register>`. (map, optional)
+  - **NOTE**: **task_definition** is required by the [RunTask API](https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html), but it is not required here if **def** is defined.
+- **result_s3_uri**: The S3 URI for the task result. (string, optional)
+  - **NOTE**: This configuration is used by the `ecs_task.result>` operator, so the result content must follow the rule described for `ecs_task.result>` below.
+- **timeout**: Timeout duration for the task. (`DurationParam`, default: `15m`)

-# Plugin Loading

+## Configuration for `ecs_task.wait>` operator

-Digdag loads pluigins from Maven repositories by configuring [plugin options](https://github.com/myui/digdag-plugin-example/blob/master/sample/plugin.dig).

+- **cluster**: The short name or full ARN of the cluster that hosts the tasks. (string, required)
+- **tasks**: A list of up to 100 task IDs or full ARN entries. (array of string, required)
+- **timeout**: Timeout duration for the tasks. (`DurationParam`, default: `15m`)
+- **condition**: The condition of the tasks to wait for. Available values are `"all"` or `"any"`. (string, default: `"all"`)
+- **status**: The status of the tasks to wait for. Available values are `"PENDING"`, `"RUNNING"`, or `"STOPPED"`. (string, default: `"STOPPED"`)
+- **ignore_failure**: Ignore failures even if some tasks exit with a non-zero code. (boolean, default: `false`)

-You can use a local Maven repository (local FS, Amazon S3) or any public Maven repository ([Maven Central](http://search.maven.org/), [Sonatype](https://www.sonatype.com/), [Bintary](https://bintray.com/), [Jitpack](https://jitpack.io/)) for the plugin artifact repository.

+## Configuration for `ecs_task.result>` operator

-# Publishing your plugin using Github and Jitpack

+- **ecs_task.result>**: The S3 URI where the result is stored. (string, required)
+  - **NOTE**: The result content must follow the below rule.
+    - The format is JSON.
+    - The keys are `"subtask_config"`, `"export_params"`, `"store_params"`.
+    - The values are string-to-object maps.
+    - The usage follows the [Digdag Python API](https://docs.digdag.io/python_api.html) and [Digdag Ruby API](https://docs.digdag.io/ruby_api.html).

-[Jitpack](https://jitpack.io/) is useful for publishing your github repository as a maven repository.

+# Development
+
+## Run an Example
+
+### 1) build

 ```sh
-git tag v0.1.3
-git push origin v0.1.3
+./gradlew publish
 ```

-https://jitpack.io/#myui/digdag-plugin-example/v0.1.3

+Artifacts are built on the local repo: `./build/repo`. The example workflow loads the plugin from this repo via the `repos` parameter (see `example/run.sh` and `example/example.dig` below).
-Now, you can load the artifact from a github repository in [a dig file](https://github.com/myui/digdag-plugin-example/blob/master/sample/plugin.dig) as follows: +### 2) get your aws profile +```sh +aws configure ``` -_export: - plugin: - repositories: - # - file://${repos} - - https://jitpack.io - dependencies: - # - io.digdag.plugin:digdag-plugin-example:0.1.3 - - com.github.myui:digdag-plugin-example:v0.1.3 + +### 3) run an example + +```sh +./example/run.sh +``` + +## (TODO) Run Tests + +```sh +./gradlew test ``` -# Further reading +# ChangeLog + +[CHANGELOG.md](./CHANGELOG.md) + +# License + +[Apache License 2.0](./LICENSE.txt) + +# Author + +@civitaspo -- [Operators](http://docs.digdag.io/operators.html) and [their implementations](https://github.com/treasure-data/digdag/tree/master/digdag-standards/src/main/java/io/digdag/standards/operator) diff --git a/build.gradle b/build.gradle index 2dfcab6..015849a 100644 --- a/build.gradle +++ b/build.gradle @@ -1,14 +1,16 @@ -apply plugin: 'java' -apply plugin: 'maven' -apply plugin: 'maven-publish' - -apply plugin: 'eclipse' -apply plugin: 'idea' +plugins { + id 'scala' + id 'maven-publish' + id 'com.github.johnrengelman.shadow' version '2.0.2' + id "com.diffplug.gradle.spotless" version "3.13.0" +} -group = 'io.digdag.plugin' -version = '0.1.3' +group = 'pro.civitaspo' +version = '0.0.2' -def digdagVersion = '0.9.12' +def digdagVersion = '0.9.31' +def scalaSemanticVersion = "2.12.6" +def depScalaVersion = "2.12" repositories { mavenCentral() @@ -18,29 +20,30 @@ repositories { } } -configurations { - provided -} - dependencies { - provided 'io.digdag:digdag-spi:' + digdagVersion - // provided 'io.digdag:digdag-standards:' + digdagVersion - provided 'io.digdag:digdag-plugin-utils:' + digdagVersion // this should be 'compile' once digdag 0.8.2 is released to a built-in repository + compile group: 'io.digdag', name: 'digdag-spi', version: digdagVersion + compile group: 'io.digdag', name: 'digdag-plugin-utils', version: digdagVersion + + // https://mvnrepository.com/artifact/org.scala-lang/scala-library + compile group: 'org.scala-lang', name: 'scala-library', version: scalaSemanticVersion + + ['ecs', 's3', 'sts'].each { svc -> + // https://mvnrepository.com/artifact/com.amazonaws/ + compile group: 'com.amazonaws', name: "aws-java-sdk-${svc}", version: '1.11.433' + } } -sourceSets { - main { - compileClasspath += configurations.provided - test.compileClasspath += configurations.provided - test.runtimeClasspath += configurations.provided +shadowJar { + classifier = null + dependencies { + exclude(dependency('io.digdag:.*')) } } publishing { publications { - mavenJava(MavenPublication) { - // artifactId 'project1-sample' - from components.java + shadow(MavenPublication) { publication -> + project.shadow.component(publication) } } repositories { @@ -50,20 +53,16 @@ publishing { } } +spotless { + scala { + scalafmt('1.5.1').configFile('.scalafmt.conf') + } +} + sourceCompatibility = 1.8 targetCompatibility = 1.8 -compileJava.options.encoding = 'UTF-8' -compileTestJava.options.encoding = 'UTF-8' - -tasks.withType(JavaCompile) { - options.compilerArgs << "-Xlint:unchecked" << "-Xlint:deprecation" -} - -javadoc { - options { - locale = 'en_US' - encoding = 'UTF-8' - } -} +compileScala.options.encoding = 'UTF-8' +compileTestScala.options.encoding = 'UTF-8' +compileScala.options.compilerArgs << "-Xlint:unchecked" << "-Xlint:deprecation" diff --git a/sample/.gitignore b/example/.gitignore similarity index 100% rename from sample/.gitignore rename to 
example/.gitignore index 10476a6..62d39cb 100644 --- a/sample/.gitignore +++ b/example/.gitignore @@ -1,5 +1,5 @@ /.digdag-wrapper .digdag *.pyc - example.out + diff --git a/example/digdag.properties b/example/digdag.properties new file mode 100644 index 0000000..52f7c95 --- /dev/null +++ b/example/digdag.properties @@ -0,0 +1,4 @@ +ecs_task.allow_auth_method_env=true +ecs_task.allow_auth_method_instance=true +ecs_task.allow_auth_method_profile=true +ecs_task.allow_auth_method_properties=true diff --git a/example/example.dig b/example/example.dig new file mode 100644 index 0000000..2bfd7b8 --- /dev/null +++ b/example/example.dig @@ -0,0 +1,29 @@ + _export: + plugin: + repositories: + - file://${repos} + # - https://jitpack.io + dependencies: + - pro.civitaspo:digdag-operator-ecs_task:0.0.2 + ecs_task: + auth_method: profile + ++step0: + sh>: echo '{"store_params":{"civi":"taspo"}}' | aws s3 cp - ${output} ++step1: + ecs_task.run>: + def: + network_mode: Host + container_definitions: + - name: uploader + image: amazonlinux:2 + command: [yum, install, '-y', awscli] + essential: true + memory: 500 + cpu: 10 + family: hello_world + cluster: ${cluster} + count: 1 + result_s3_uri: ${output} ++step2: + echo>: ${civi} diff --git a/example/run.sh b/example/run.sh new file mode 100755 index 0000000..fa1a646 --- /dev/null +++ b/example/run.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +ROOT=$(cd $(dirname $0)/..; pwd) +EXAMPLE_ROOT=$ROOT/example +LOCAL_MAVEN_REPO=$ROOT/build/repo + +CLUSTER="$1" +OUTPUT="$2" + +if [ -z "$CLUSTER" ]; then + echo "[ERROR] Set cluster as the first argument." + exit 1 +fi +if [ -z "$OUTPUT" ]; then + echo "[ERROR] Set output s3 URI as the second argument." + exit 1 +fi + +( + cd $EXAMPLE_ROOT + + ## to remove cache + rm -rfv .digdag + + ## run + digdag run example.dig -c digdag.properties -p repos=${LOCAL_MAVEN_REPO} -p output=${OUTPUT} -p cluster=${CLUSTER} --no-save +) diff --git a/sample/template.txt b/example/template.txt similarity index 100% rename from sample/template.txt rename to example/template.txt diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 30d399d..29953ea 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 01a2f81..e0b3fb8 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,5 @@ -#Wed Mar 02 12:59:56 PST 2016 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.2-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-2.11-bin.zip diff --git a/gradlew b/gradlew index 91a7e26..cccdd3d 100755 --- a/gradlew +++ b/gradlew @@ -1,4 +1,4 @@ -#!/usr/bin/env bash +#!/usr/bin/env sh ############################################################################## ## @@ -6,20 +6,38 @@ ## ############################################################################## -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS="" +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. 
+while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null APP_NAME="Gradle" APP_BASE_NAME=`basename "$0"` +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD="maximum" -warn ( ) { +warn () { echo "$*" } -die ( ) { +die () { echo echo "$*" echo @@ -30,6 +48,7 @@ die ( ) { cygwin=false msys=false darwin=false +nonstop=false case "`uname`" in CYGWIN* ) cygwin=true @@ -40,31 +59,11 @@ case "`uname`" in MINGW* ) msys=true ;; + NONSTOP* ) + nonstop=true + ;; esac -# For Cygwin, ensure paths are in UNIX format before anything is touched. -if $cygwin ; then - [ -n "$JAVA_HOME" ] && JAVA_HOME=`cygpath --unix "$JAVA_HOME"` -fi - -# Attempt to set APP_HOME -# Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi -done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >&- -APP_HOME="`pwd -P`" -cd "$SAVED" >&- - CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar # Determine the Java command to use to start the JVM. @@ -90,7 +89,7 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then MAX_FD_LIMIT=`ulimit -H -n` if [ $? -eq 0 ] ; then if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then @@ -114,6 +113,7 @@ fi if $cygwin ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` # We build the pattern for arguments to be converted via cygpath ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` @@ -154,11 +154,19 @@ if $cygwin ; then esac fi -# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules -function splitJvmOpts() { - JVM_OPTS=("$@") +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " } -eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS -JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" +APP_ARGS=$(save "$@") + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong +if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then + cd "$(dirname "$0")" +fi -exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" +exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index aec9973..e95643d 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -8,14 +8,14 @@ @rem Set local scope for the variables with windows NT shell if "%OS%"=="Windows_NT" setlocal -@rem Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= - set DIRNAME=%~dp0 if "%DIRNAME%" == "" set DIRNAME=. set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + @rem Find java.exe if defined JAVA_HOME goto findJavaFromJavaHome @@ -46,10 +46,9 @@ echo location of your Java installation. goto fail :init -@rem Get command-line arguments, handling Windowz variants +@rem Get command-line arguments, handling Windows variants if not "%OS%" == "Windows_NT" goto win9xME_args -if "%@eval[2+2]" == "4" goto 4NT_args :win9xME_args @rem Slurp the command line arguments. @@ -60,11 +59,6 @@ set _SKIP=2 if "x%~1" == "x" goto execute set CMD_LINE_ARGS=%* -goto execute - -:4NT_args -@rem Get arguments from the 4NT Shell from JP Software -set CMD_LINE_ARGS=%$ :execute @rem Setup the command line diff --git a/sample/plugin.dig b/sample/plugin.dig deleted file mode 100644 index a988068..0000000 --- a/sample/plugin.dig +++ /dev/null @@ -1,17 +0,0 @@ -_export: - plugin: - repositories: - - file://${repos} - # - https://jitpack.io - dependencies: - - io.digdag.plugin:digdag-plugin-example:0.1.3 - # - com.github.myui:digdag-plugin-example:v0.1.3 - -+step1: - example>: template.txt - message: yes - path: example.out - -+step2: - hello>: "hello " - message: world diff --git a/settings.gradle b/settings.gradle new file mode 100644 index 0000000..555c054 --- /dev/null +++ b/settings.gradle @@ -0,0 +1,2 @@ +rootProject.name = 'digdag-operator-ecs_task' + diff --git a/src/main/java/io/digdag/plugin/example/ExampleOperatorFactory.java b/src/main/java/io/digdag/plugin/example/ExampleOperatorFactory.java deleted file mode 100644 index 54b7dba..0000000 --- a/src/main/java/io/digdag/plugin/example/ExampleOperatorFactory.java +++ /dev/null @@ -1,56 +0,0 @@ -package io.digdag.plugin.example; - -import static java.nio.charset.StandardCharsets.UTF_8; -import io.digdag.client.config.Config; -import io.digdag.spi.Operator; -import io.digdag.spi.OperatorContext; -import io.digdag.spi.OperatorFactory; -import io.digdag.spi.TaskResult; -import io.digdag.spi.TemplateEngine; -import io.digdag.util.BaseOperator; - -import java.io.IOException; -import java.nio.file.Files; - -import com.google.common.base.Throwables; - -public class ExampleOperatorFactory implements OperatorFactory { - private final TemplateEngine templateEngine; - - public ExampleOperatorFactory(TemplateEngine templateEngine) { - this.templateEngine = templateEngine; - } - - public String getType() { - return "example"; - } - - @Override - public Operator newOperator(OperatorContext context) { - return new ExampleOperator(context); - } - - private class ExampleOperator extends BaseOperator { - public ExampleOperator(OperatorContext context) { - super(context); - } - - @Override - public TaskResult runTask() { - Config params = request.getConfig().mergeDefault( - request.getConfig().getNestedOrGetEmpty("example")); - - String message = workspace.templateCommand(templateEngine, params, "message", UTF_8); - String path = params.get("path", String.class); - - try { - Files.write(workspace.getPath(path), message.getBytes(UTF_8)); - } catch (IOException ex) { - throw Throwables.propagate(ex); - } - - return TaskResult.empty(request); - } - } - -} diff --git a/src/main/java/io/digdag/plugin/example/ExamplePlugin.java b/src/main/java/io/digdag/plugin/example/ExamplePlugin.java deleted file mode 100644 
index ba6dd15..0000000 --- a/src/main/java/io/digdag/plugin/example/ExamplePlugin.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.digdag.plugin.example; - -import io.digdag.spi.OperatorFactory; -import io.digdag.spi.OperatorProvider; -import io.digdag.spi.Plugin; -import io.digdag.spi.TemplateEngine; - -import java.util.Arrays; -import java.util.List; - -import javax.inject.Inject; - -public class ExamplePlugin implements Plugin { - @Override - public Class getServiceProvider(Class type) { - if (type == OperatorProvider.class) { - return ExampleOperatorProvider.class.asSubclass(type); - } else { - return null; - } - } - - public static class ExampleOperatorProvider implements OperatorProvider { - @Inject - protected TemplateEngine templateEngine; - - @Override - public List get() { - return Arrays.asList(new ExampleOperatorFactory(templateEngine), - new HelloOperatorFactory(templateEngine)); - } - } -} diff --git a/src/main/java/io/digdag/plugin/example/HelloOperatorFactory.java b/src/main/java/io/digdag/plugin/example/HelloOperatorFactory.java deleted file mode 100644 index fd195a9..0000000 --- a/src/main/java/io/digdag/plugin/example/HelloOperatorFactory.java +++ /dev/null @@ -1,50 +0,0 @@ -package io.digdag.plugin.example; - -import io.digdag.client.config.Config; -import io.digdag.spi.Operator; -import io.digdag.spi.OperatorContext; -import io.digdag.spi.OperatorFactory; -import io.digdag.spi.TaskResult; -import io.digdag.spi.TemplateEngine; -import io.digdag.util.BaseOperator; - -public class HelloOperatorFactory implements OperatorFactory { - @SuppressWarnings("unused") - private final TemplateEngine templateEngine; - - public HelloOperatorFactory(TemplateEngine templateEngine) { - this.templateEngine = templateEngine; - } - - @Override - public String getType() { - return "hello"; - } - - @Override - public Operator newOperator(OperatorContext context) { - return new HelloOperator(context); - } - - private class HelloOperator extends BaseOperator { - - HelloOperator(OperatorContext context) { - super(context); - } - - @Override - public TaskResult runTask() { - //Config params = request.getConfig(); - Config params = request.getConfig().mergeDefault( - request.getConfig().getNestedOrGetEmpty("hello")); - - String message = params.get("_command", String.class); - message += params.get("message", String.class); - - System.out.println(message); - - return TaskResult.empty(request); - } - - } -} diff --git a/src/main/resources/META-INF/services/io.digdag.spi.Plugin b/src/main/resources/META-INF/services/io.digdag.spi.Plugin index 622a36f..eb7bef2 100644 --- a/src/main/resources/META-INF/services/io.digdag.spi.Plugin +++ b/src/main/resources/META-INF/services/io.digdag.spi.Plugin @@ -1,2 +1,2 @@ -io.digdag.plugin.example.ExamplePlugin +pro.civitaspo.digdag.plugin.ecs_task.EcsTaskPlugin diff --git a/src/main/resources/pro/civitaspo/digdag/plugin/ecs_task/py/run.sh b/src/main/resources/pro/civitaspo/digdag/plugin/ecs_task/py/run.sh new file mode 100644 index 0000000..9a36ac8 --- /dev/null +++ b/src/main/resources/pro/civitaspo/digdag/plugin/ecs_task/py/run.sh @@ -0,0 +1,80 @@ +#!/bin/sh + +## s3 path structure +# . +# ├── workspace +# │ ├── hoge.dig +# │   └── py +# │   └── hoge.py +# ├── in_file.json +# ├── out_file.json +# ├── run.sh +# ├── runner.py +# ├── stdout.log +# └── stderr.log + +## local path structure +# . 
+# ├── run.sh +# └── digdag-operator-ecs_task +# ├── workspace +# │ ├── hoge.dig +# │   └── py +# │   └── hoge.py +# ├── in.json +# ├── out.json +# ├── runner.py +# ├── stdout.log +# └── stderr.log + +set -ex +set -o pipefail + +mkdir -p ./digdag-operator-ecs_task +cd digdag-operator-ecs_task + +# Create output files +touch out.json stdout.log stderr.log + +# Download requirements +aws s3 cp s3://${ECS_TASK_PY_BUCKET}/${ECS_TASK_PY_PREFIX}/ ./ --recursive + +# Move workspace +cd workspace + +# Unset e option for returning python results to digdag +set +e + +# Run setup command +${ECS_TASK_PY_SETUP_COMMAND} \ + 2>> ../stderr.log \ + | tee -a ../stdout.log + +# Run +cat ../runner.py \ + | python - "${ECS_TASK_PY_COMMAND}" \ + ../in.json \ + ../out.json \ + 2>> ../stderr.log \ + | tee -a ../stdout.log + +# Capture exit code +EXIT_CODE=$? + +# Set e option +set -e + +# Move out workspace +cd .. + +# For logging driver +cat stderr.log 1>&2 + +# Upload results +aws s3 cp ./out.json s3://${ECS_TASK_PY_BUCKET}/${ECS_TASK_PY_PREFIX}/ +aws s3 cp ./stdout.log s3://${ECS_TASK_PY_BUCKET}/${ECS_TASK_PY_PREFIX}/ +aws s3 cp ./stderr.log s3://${ECS_TASK_PY_BUCKET}/${ECS_TASK_PY_PREFIX}/ + +# Exit with the python exit code +exit $EXIT_CODE + diff --git a/src/main/resources/pro/civitaspo/digdag/plugin/ecs_task/py/runner.py b/src/main/resources/pro/civitaspo/digdag/plugin/ecs_task/py/runner.py new file mode 100644 index 0000000..901df66 --- /dev/null +++ b/src/main/resources/pro/civitaspo/digdag/plugin/ecs_task/py/runner.py @@ -0,0 +1,183 @@ +######### +# Copy from https://raw.githubusercontent.com/treasure-data/digdag/52ff276bcc0aed23bf5a0df6c7a7c7b155c22d53/digdag-standards/src/main/resources/digdag/standards/py/runner.py +# Then, customize a bit about error handling +######### + +import collections +import imp +import inspect +import json +import os +import sys +import traceback + +command = sys.argv[1] +in_file = sys.argv[2] +out_file = sys.argv[3] + +with open(in_file) as f: + in_data = json.load(f) + params = in_data['params'] + +# fake digdag_env module already imported +digdag_env_mod = sys.modules['digdag_env'] = imp.new_module('digdag_env') +digdag_env_mod.params = params +digdag_env_mod.subtask_config = collections.OrderedDict() +digdag_env_mod.export_params = {} +digdag_env_mod.store_params = {} +digdag_env_mod.state_params = {} +import digdag_env + +# fake digdag module already imported +digdag_mod = sys.modules['digdag'] = imp.new_module('digdag') + + +class Env(object): + def __init__(self, digdag_env_mod): + self.params = digdag_env_mod.params + self.subtask_config = digdag_env_mod.subtask_config + self.export_params = digdag_env_mod.export_params + self.store_params = digdag_env_mod.store_params + self.state_params = digdag_env_mod.state_params + self.subtask_index = 0 + + def set_state(self, params={}, **kwds): + self.state_params.update(params) + self.state_params.update(kwds) + + def export(self, params={}, **kwds): + self.export_params.update(params) + self.export_params.update(kwds) + + def store(self, params={}, **kwds): + self.store_params.update(params) + self.store_params.update(kwds) + + def add_subtask(self, function=None, **params): + if function is not None and not isinstance(function, dict): + if hasattr(function, "im_class"): + # Python 2 + command = ".".join([function.im_class.__module__, function.im_class.__name__, function.__name__]) + else: + # Python 3 + command = ".".join([function.__module__, function.__name__]) + config = params + config["py>"] = command + else: + if 
isinstance(function, dict): + config = function.copy() + config.update(params) + else: + config = params + try: + json.dumps(config) + except Exception as error: + raise TypeError("Parameters must be serializable using JSON: %s" % str(error)) + self.subtask_config["+subtask" + str(self.subtask_index)] = config + self.subtask_index += 1 + + +digdag_mod.env = Env(digdag_env_mod) + +# add the archive path to improt path +sys.path.append(os.path.abspath(os.getcwd())) + + +def digdag_inspect_command(command): + # package.name.Class.method + fragments = command.split(".") + method_name = fragments.pop() + class_type = None + callable_type = None + try: + mod = __import__(".".join(fragments), fromlist=[method_name]) + try: + callable_type = getattr(mod, method_name) + except AttributeError as error: + raise AttributeError("Module '%s' has no attribute '%s'" % (".".join(fragments), method_name)) + except ImportError as error: + class_name = fragments.pop() + mod = __import__(".".join(fragments), fromlist=[class_name]) + try: + class_type = getattr(mod, class_name) + except AttributeError as error: + raise AttributeError("Module '%s' has no attribute '%s'" % (".".join(fragments), method_name)) + + if type(callable_type) == type: + class_type = callable_type + method_name = "run" + + if class_type is not None: + return (class_type, method_name) + else: + return (callable_type, None) + + +def digdag_inspect_arguments(callable_type, exclude_self, params): + if callable_type == object.__init__: + # object.__init__ accepts *varargs and **keywords but it throws exception + return {} + spec = inspect.getargspec(callable_type) + args = {} + for idx, key in enumerate(spec.args): + if exclude_self and idx == 0: + continue + if key in params: + args[key] = params[key] + else: + if spec.defaults is None or len(spec.defaults) < idx: + # this keyword is required but not in params. raising an error. 
+ if hasattr(callable_type, '__qualname__'): + # Python 3 + name = callable_type.__qualname__ + elif hasattr(callable_type, 'im_class'): + # Python 2 + name = "%s.%s" % (callable_type.im_class.__name__, callable_type.__name__) + else: + name = callable_type.__name__ + raise TypeError("Method '%s' requires parameter '%s' but not set" % (name, key)) + if spec.keywords: + # above code was only for validation + return params + else: + return args + + +status_params = {} +def with_error_handler(func, **func_args): + try: + results = func(**func_args) + status_params['exit_code'] = 0 + return results + except Exception as e: + status_params['exit_code'] = 1 + status_params['error_message'] = str(e) + status_params['error_stacktrace'] = traceback.format_exc() + print('message: {}, stacktrace: {}', str(e), traceback.format_exc()) + +callable_type, method_name = digdag_inspect_command(command) + +if method_name: + init_args = digdag_inspect_arguments(callable_type.__init__, True, params) + instance = callable_type(**init_args) + + method = getattr(instance, method_name) + method_args = digdag_inspect_arguments(method, True, params) + # result = method(**method_args) + result = with_error_handler(method, **method_args) + +else: + args = digdag_inspect_arguments(callable_type, False, params) + # result = callable_type(**args) + result = with_error_handler(callable_type, **args) + +out = { + 'subtask_config': digdag_env.subtask_config, + 'export_params': digdag_env.export_params, + 'store_params': digdag_env.store_params, + 'status_params': status_params, # only for ecs_task.command_result_internal + # 'state_params': digdag_env.state_params, # only for retrying +} + +with open(out_file, 'w') as f: + json.dump(out, f) diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/AbstractEcsTaskOperator.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/AbstractEcsTaskOperator.scala new file mode 100644 index 0000000..23e2e6c --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/AbstractEcsTaskOperator.scala @@ -0,0 +1,46 @@ +package pro.civitaspo.digdag.plugin.ecs_task +import io.digdag.client.config.{Config, ConfigFactory} +import io.digdag.spi.{OperatorContext, SecretProvider, TemplateEngine} +import io.digdag.util.{BaseOperator, DurationParam} +import org.slf4j.{Logger, LoggerFactory} +import pro.civitaspo.digdag.plugin.ecs_task.aws.{Aws, AwsConf} + +abstract class AbstractEcsTaskOperator(operatorName: String, context: OperatorContext, systemConfig: Config, templateEngine: TemplateEngine) + extends BaseOperator(context) { + + protected val logger: Logger = LoggerFactory.getLogger(operatorName) + protected val cf: ConfigFactory = request.getConfig.getFactory + protected val params: Config = { + val elems: Seq[String] = operatorName.split("\\.") + elems.indices.foldLeft(request.getConfig) { (p: Config, idx: Int) => + p.mergeDefault((0 to idx).foldLeft(request.getConfig) { (nestedParam: Config, keyIdx: Int) => + nestedParam.getNestedOrGetEmpty(elems(keyIdx)) + }) + } + } + protected val secrets: SecretProvider = context.getSecrets.getSecrets("ecs_task") + protected val sessionUuid: String = params.get("session_uuid", classOf[String]) + + protected val aws: Aws = Aws( + AwsConf( + isAllowedAuthMethodEnv = systemConfig.get("ecs_task.allow_auth_method_env", classOf[Boolean], false), + isAllowedAuthMethodInstance = systemConfig.get("ecs_task.allow_auth_method_instance", classOf[Boolean], false), + isAllowedAuthMethodProfile = 
systemConfig.get("ecs_task.allow_auth_method_profile", classOf[Boolean], false), + isAllowedAuthMethodProperties = systemConfig.get("ecs_task.allow_auth_method_properties", classOf[Boolean], false), + assumeRoleTimeoutDuration = systemConfig.get("ecs_task.assume_role_timeout_duration", classOf[DurationParam], DurationParam.parse("1h")), + accessKeyId = secrets.getSecretOptional("access_key_id"), + secretAccessKey = secrets.getSecretOptional("secret_access_key"), + sessionToken = secrets.getSecretOptional("session_token"), + roleArn = secrets.getSecretOptional("role_arn"), + roleSessionName = secrets.getSecretOptional("role_session_name").or(s"digdag-ecs_task-$sessionUuid"), + httpProxy = secrets.getSecrets("http_proxy"), + authMethod = params.get("auth_method", classOf[String], "basic"), + profileName = params.get("profile_name", classOf[String], "default"), + profileFile = params.getOptional("profile_file", classOf[String]), + useHttpProxy = params.get("use_http_proxy", classOf[Boolean], false), + region = params.getOptional("region", classOf[String]), + endpoint = params.getOptional("endpoint", classOf[String]) + ) + ) + +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/EcsTaskPlugin.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/EcsTaskPlugin.scala new file mode 100644 index 0000000..3b712ce --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/EcsTaskPlugin.scala @@ -0,0 +1,52 @@ +package pro.civitaspo.digdag.plugin.ecs_task + +import java.lang.reflect.Constructor +import java.util.{Arrays => JArrays, List => JList} + +import io.digdag.client.config.Config +import io.digdag.spi.{Operator, OperatorContext, OperatorFactory, OperatorProvider, Plugin, TemplateEngine} +import javax.inject.Inject +import pro.civitaspo.digdag.plugin.ecs_task.command.EcsTaskCommandResultInternalOperator +import pro.civitaspo.digdag.plugin.ecs_task.py.EcsTaskPyOperator +import pro.civitaspo.digdag.plugin.ecs_task.register.EcsTaskRegisterOperator +import pro.civitaspo.digdag.plugin.ecs_task.result.EcsTaskResultOperator +import pro.civitaspo.digdag.plugin.ecs_task.run.{EcsTaskRunInternalOperator, EcsTaskRunOperator} +import pro.civitaspo.digdag.plugin.ecs_task.wait.EcsTaskWaitOperator + +object EcsTaskPlugin { + + class EcsTaskOperatorProvider extends OperatorProvider { + + @Inject protected var systemConfig: Config = null + @Inject protected var templateEngine: TemplateEngine = null + + override def get(): JList[OperatorFactory] = { + JArrays.asList( + operatorFactory("ecs_task.py", classOf[EcsTaskPyOperator]), + operatorFactory("ecs_task.command_result_internal", classOf[EcsTaskCommandResultInternalOperator]), + operatorFactory("ecs_task.register", classOf[EcsTaskRegisterOperator]), + operatorFactory("ecs_task.result", classOf[EcsTaskResultOperator]), + operatorFactory("ecs_task.run", classOf[EcsTaskRunOperator]), + operatorFactory("ecs_task.run_internal", classOf[EcsTaskRunInternalOperator]), + operatorFactory("ecs_task.wait", classOf[EcsTaskWaitOperator]) + ) + } + + private def operatorFactory[T <: AbstractEcsTaskOperator](operatorName: String, klass: Class[T]): OperatorFactory = { + new OperatorFactory { + override def getType: String = operatorName + override def newOperator(context: OperatorContext): Operator = { + val constructor: Constructor[T] = klass.getConstructor(classOf[String], classOf[OperatorContext], classOf[Config], classOf[TemplateEngine]) + constructor.newInstance(operatorName, context, systemConfig, templateEngine) + } + } + } + } +} + +class 
EcsTaskPlugin extends Plugin { + override def getServiceProvider[T](`type`: Class[T]): Class[_ <: T] = { + if (`type` ne classOf[OperatorProvider]) return null + classOf[EcsTaskPlugin.EcsTaskOperatorProvider].asSubclass(`type`) + } +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/aws/AmazonS3UriWrapper.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/aws/AmazonS3UriWrapper.scala new file mode 100644 index 0000000..64fe16e --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/aws/AmazonS3UriWrapper.scala @@ -0,0 +1,6 @@ +package pro.civitaspo.digdag.plugin.ecs_task.aws +import com.amazonaws.services.s3.AmazonS3URI + +object AmazonS3UriWrapper { + def apply(path: String): AmazonS3URI = new AmazonS3URI(path, false) +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/aws/Aws.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/aws/Aws.scala new file mode 100644 index 0000000..6dbf7b1 --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/aws/Aws.scala @@ -0,0 +1,180 @@ +package pro.civitaspo.digdag.plugin.ecs_task.aws +import com.amazonaws.{ClientConfiguration, Protocol} +import com.amazonaws.auth.{ + AnonymousAWSCredentials, + AWSCredentials, + AWSCredentialsProvider, + AWSStaticCredentialsProvider, + BasicAWSCredentials, + BasicSessionCredentials, + EC2ContainerCredentialsProviderWrapper, + EnvironmentVariableCredentialsProvider, + SystemPropertiesCredentialsProvider +} +import com.amazonaws.auth.profile.{ProfileCredentialsProvider, ProfilesConfigFile} +import com.amazonaws.client.builder.AwsClientBuilder +import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration +import com.amazonaws.regions.{DefaultAwsRegionProviderChain, Regions} +import com.amazonaws.services.ecs.{AmazonECS, AmazonECSClientBuilder} +import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder} +import com.amazonaws.services.s3.transfer.{TransferManager, TransferManagerBuilder} +import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder +import com.amazonaws.services.securitytoken.model.AssumeRoleRequest +import com.google.common.base.Optional +import io.digdag.client.config.ConfigException + +import scala.util.Try + +case class Aws(conf: AwsConf) { + + def withS3[R](f: AmazonS3 => R): R = { + val s3: AmazonS3 = buildService(AmazonS3ClientBuilder.standard()) + try f(s3) + finally s3.shutdown() + } + + def withTransferManager[R](f: TransferManager => R): R = { + withS3 { s3 => + val xfer: TransferManager = TransferManagerBuilder.standard().withS3Client(s3).build() + try f(xfer) + finally xfer.shutdownNow(false) + } + } + + def withEcs[R](f: AmazonECS => R): R = { + val ecs: AmazonECS = buildService(AmazonECSClientBuilder.standard()) + try f(ecs) + finally ecs.shutdown() + } + + private def buildService[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): T = { + configureBuilderEndpointConfiguration(builder) + .withClientConfiguration(clientConfiguration) + .withCredentials(credentialsProvider) + .build() + } + + private def configureBuilderEndpointConfiguration[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): AwsClientBuilder[S, T] = { + if (conf.region.isPresent && conf.endpoint.isPresent) { + val ec = new EndpointConfiguration(conf.endpoint.get(), conf.region.get()) + builder.setEndpointConfiguration(ec) + } + else if (conf.region.isPresent && !conf.endpoint.isPresent) { + builder.setRegion(conf.region.get()) + } + else if (!conf.region.isPresent && 
conf.endpoint.isPresent) { + val r = Try(new DefaultAwsRegionProviderChain().getRegion).getOrElse(Regions.DEFAULT_REGION.getName) + val ec = new EndpointConfiguration(conf.endpoint.get(), r) + builder.setEndpointConfiguration(ec) + } + builder + } + + private def credentialsProvider: AWSCredentialsProvider = { + if (!conf.roleArn.isPresent) return standardCredentialsProvider + assumeRoleCredentialsProvider(standardCredentialsProvider) + } + + private def standardCredentialsProvider: AWSCredentialsProvider = { + conf.authMethod match { + case "basic" => basicAuthMethodAWSCredentialsProvider + case "env" => envAuthMethodAWSCredentialsProvider + case "instance" => instanceAuthMethodAWSCredentialsProvider + case "profile" => profileAuthMethodAWSCredentialsProvider + case "properties" => propertiesAuthMethodAWSCredentialsProvider + case "anonymous" => anonymousAuthMethodAWSCredentialsProvider + case "session" => sessionAuthMethodAWSCredentialsProvider + case _ => + throw new ConfigException( + s"""auth_method: "$conf.authMethod" is not supported. available `auth_method`s are "basic", "env", "instance", "profile", "properties", "anonymous", or "session".""" + ) + } + } + + private def assumeRoleCredentialsProvider(credentialsProviderToAssumeRole: AWSCredentialsProvider): AWSCredentialsProvider = { + // TODO: require EndpointConfiguration so on ? + val sts = AWSSecurityTokenServiceClientBuilder + .standard() + .withClientConfiguration(clientConfiguration) + .withCredentials(credentialsProviderToAssumeRole) + .build() + + val role = sts.assumeRole( + new AssumeRoleRequest() + .withRoleArn(conf.roleArn.get()) + .withDurationSeconds(conf.assumeRoleTimeoutDuration.getDuration.getSeconds.toInt) + .withRoleSessionName(conf.roleSessionName) + ) + val credentials = + new BasicSessionCredentials(role.getCredentials.getAccessKeyId, role.getCredentials.getSecretAccessKey, role.getCredentials.getSessionToken) + new AWSStaticCredentialsProvider(credentials) + } + + private def basicAuthMethodAWSCredentialsProvider: AWSCredentialsProvider = { + if (!conf.accessKeyId.isPresent) throw new ConfigException(s"""`access_key_id` must be set when `auth_method` is "$conf.authMethod".""") + if (!conf.secretAccessKey.isPresent) throw new ConfigException(s"""`secret_access_key` must be set when `auth_method` is "$conf.authMethod".""") + val credentials: AWSCredentials = new BasicAWSCredentials(conf.accessKeyId.get(), conf.secretAccessKey.get()) + new AWSStaticCredentialsProvider(credentials) + } + + private def envAuthMethodAWSCredentialsProvider: AWSCredentialsProvider = { + if (!conf.isAllowedAuthMethodEnv) throw new ConfigException(s"""auth_method: "$conf.authMethod" is not allowed.""") + new EnvironmentVariableCredentialsProvider + } + + private def instanceAuthMethodAWSCredentialsProvider: AWSCredentialsProvider = { + if (!conf.isAllowedAuthMethodInstance) throw new ConfigException(s"""auth_method: "$conf.authMethod" is not allowed.""") + // NOTE: combination of InstanceProfileCredentialsProvider and ContainerCredentialsProvider + new EC2ContainerCredentialsProviderWrapper + } + + private def profileAuthMethodAWSCredentialsProvider: AWSCredentialsProvider = { + if (!conf.isAllowedAuthMethodProfile) throw new ConfigException(s"""auth_method: "$conf.authMethod" is not allowed.""") + if (!conf.profileFile.isPresent) return new ProfileCredentialsProvider(conf.profileName) + val pf: ProfilesConfigFile = new ProfilesConfigFile(conf.profileFile.get()) + new ProfileCredentialsProvider(pf, conf.profileName) + } + + private 
def propertiesAuthMethodAWSCredentialsProvider: AWSCredentialsProvider = { + if (!conf.isAllowedAuthMethodProperties) throw new ConfigException(s"""auth_method: "$conf.authMethod" is not allowed.""") + new SystemPropertiesCredentialsProvider() + } + + private def anonymousAuthMethodAWSCredentialsProvider: AWSCredentialsProvider = { + val credentials: AWSCredentials = new AnonymousAWSCredentials + new AWSStaticCredentialsProvider(credentials) + } + + private def sessionAuthMethodAWSCredentialsProvider: AWSCredentialsProvider = { + if (!conf.accessKeyId.isPresent) throw new ConfigException(s"""`access_key_id` must be set when `auth_method` is "$conf.authMethod".""") + if (!conf.secretAccessKey.isPresent) throw new ConfigException(s"""`secret_access_key` must be set when `auth_method` is "$conf.authMethod".""") + if (!conf.sessionToken.isPresent) throw new ConfigException(s"""`session_token` must be set when `auth_method` is "$conf.authMethod".""") + val credentials: AWSCredentials = new BasicSessionCredentials(conf.accessKeyId.get(), conf.secretAccessKey.get(), conf.sessionToken.get()) + new AWSStaticCredentialsProvider(credentials) + } + + private def clientConfiguration: ClientConfiguration = { + if (!conf.useHttpProxy) return new ClientConfiguration() + + val host: String = conf.httpProxy.getSecret("host") + val port: Optional[String] = conf.httpProxy.getSecretOptional("port") + val protocol: Protocol = conf.httpProxy.getSecretOptional("scheme").or("https") match { + case "http" => Protocol.HTTP + case "https" => Protocol.HTTPS + case _ => throw new ConfigException(s"""`athena.http_proxy.scheme` must be "http" or "https".""") + } + val user: Optional[String] = conf.httpProxy.getSecretOptional("user") + val password: Optional[String] = conf.httpProxy.getSecretOptional("password") + + val cc = new ClientConfiguration() + .withProxyHost(host) + .withProtocol(protocol) + + if (port.isPresent) cc.setProxyPort(port.get().toInt) + if (user.isPresent) cc.setProxyUsername(user.get()) + if (password.isPresent) cc.setProxyPassword(password.get()) + + cc + } + +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/aws/AwsConf.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/aws/AwsConf.scala new file mode 100644 index 0000000..d6b67cf --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/aws/AwsConf.scala @@ -0,0 +1,24 @@ +package pro.civitaspo.digdag.plugin.ecs_task.aws +import com.google.common.base.Optional +import io.digdag.spi.SecretProvider +import io.digdag.util.DurationParam + +case class AwsConf( + isAllowedAuthMethodEnv: Boolean, + isAllowedAuthMethodInstance: Boolean, + isAllowedAuthMethodProfile: Boolean, + isAllowedAuthMethodProperties: Boolean, + assumeRoleTimeoutDuration: DurationParam, + accessKeyId: Optional[String], + secretAccessKey: Optional[String], + sessionToken: Optional[String], + roleArn: Optional[String], + roleSessionName: String, + httpProxy: SecretProvider, + authMethod: String, + profileName: String, + profileFile: Optional[String], + useHttpProxy: Boolean, + region: Optional[String], + endpoint: Optional[String] +) diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/command/EcsTaskCommandOperator.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/command/EcsTaskCommandOperator.scala new file mode 100644 index 0000000..1c4ef4d --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/command/EcsTaskCommandOperator.scala @@ -0,0 +1,17 @@ +package pro.civitaspo.digdag.plugin.ecs_task.command 
+import com.amazonaws.services.s3.AmazonS3URI +import io.digdag.spi.TaskResult + +trait EcsTaskCommandOperator { + + val runner: EcsTaskCommandRunner + + def additionalEnvironments(): Map[String, String] + + def uploadScript(): AmazonS3URI + + def runTask(): TaskResult = { + runner.run(scriptsLocationPrefix = uploadScript()) + } + +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/command/EcsTaskCommandResultInternalOperator.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/command/EcsTaskCommandResultInternalOperator.scala new file mode 100644 index 0000000..fbe59bb --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/command/EcsTaskCommandResultInternalOperator.scala @@ -0,0 +1,81 @@ +package pro.civitaspo.digdag.plugin.ecs_task.command +import java.io.File + +import com.amazonaws.services.s3.AmazonS3URI +import com.amazonaws.services.s3.transfer.Download +import io.digdag.client.config.Config +import io.digdag.spi.{OperatorContext, TaskResult, TemplateEngine} +import pro.civitaspo.digdag.plugin.ecs_task.AbstractEcsTaskOperator +import pro.civitaspo.digdag.plugin.ecs_task.aws.AmazonS3UriWrapper + +import scala.io.Source +import scala.util.{Failure, Try} + +class EcsTaskCommandResultInternalOperator(operatorName: String, context: OperatorContext, systemConfig: Config, templateEngine: TemplateEngine) + extends AbstractEcsTaskOperator(operatorName, context, systemConfig, templateEngine) { + + protected val locationPrefix: AmazonS3URI = AmazonS3UriWrapper(params.get("_command", classOf[String])) + + override def runTask(): TaskResult = { + logStdoutStderr() + + val out: Config = loadOutJsonContent() + val statusParams: Config = out.getNested("status_params") + val exitCode: Int = statusParams.get("exit_code", classOf[Int]) + + if (exitCode != 0) { + val errorMessage: String = statusParams.get("error_message", classOf[String]) + val errorStackTrace: String = statusParams.get("error_stacktrace", classOf[String]) + throw new RuntimeException(s"message: $errorMessage, stacktrace: $errorStackTrace") + } + + TaskResult + .defaultBuilder(cf) + .subtaskConfig(out.getNestedOrGetEmpty("subtask_config")) + .exportParams(out.getNestedOrGetEmpty("export_params")) + .storeParams( + out + .getNestedOrGetEmpty("store_params") + .setNested("last_ecs_task_py", statusParams) + ) + .build() + } + + protected def loadOutJsonContent(): Config = { + val targetUri: AmazonS3URI = AmazonS3UriWrapper(s"$locationPrefix/out.json") + val content: String = loadS3ObjectContent(targetUri) + cf.fromJsonString(content) + } + + protected def logStdoutStderr(): Unit = { + val t: Try[Unit] = Try { // do nothing if failed + logger.info(s"stdout: ${loadStdoutLogContent()}") + logger.info(s"stderr: ${loadStderrLogContent()}") + } + t match { + case Failure(exception) => logger.error(exception.getMessage, exception) + case _ => // do nothing + } + } + + protected def loadStdoutLogContent(): String = { + val targetUri: AmazonS3URI = AmazonS3UriWrapper(s"$locationPrefix/stdout.log") + loadS3ObjectContent(targetUri) + } + + protected def loadStderrLogContent(): String = { + val targetUri: AmazonS3URI = AmazonS3UriWrapper(s"$locationPrefix/stderr.log") + loadS3ObjectContent(targetUri) + } + + protected def loadS3ObjectContent(uri: AmazonS3URI): String = { + val f: String = workspace.createTempFile("ecs_task.command_result_internal", ".txt") + logger.info(s"Download: $uri -> $f") + aws.withTransferManager { xfer => + val download: Download = xfer.download(uri.getBucket, uri.getKey, new 
File(f)) + download.waitForCompletion() + } + Source.fromFile(f).getLines.mkString("\n") + } + +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/command/EcsTaskCommandRunner.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/command/EcsTaskCommandRunner.scala new file mode 100644 index 0000000..ec889c6 --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/command/EcsTaskCommandRunner.scala @@ -0,0 +1,209 @@ +package pro.civitaspo.digdag.plugin.ecs_task.command +import com.amazonaws.services.s3.AmazonS3URI +import com.google.common.base.Optional +import io.digdag.client.config.{Config, ConfigFactory} +import io.digdag.spi.TaskResult +import io.digdag.util.DurationParam +import org.slf4j.Logger +import pro.civitaspo.digdag.plugin.ecs_task.aws.AwsConf + +import scala.collection.JavaConverters._ + +case class EcsTaskCommandRunner(params: Config, environments: Map[String, String], awsConf: AwsConf, logger: Logger) { + + val cf: ConfigFactory = params.getFactory + + // For ecs_task.register> operator (TaskDefinition) + // NOTE: Use only 1 container + // val containerDefinitions: Seq[ContainerDefinition] = params.getList("container_definitions", classOf[Config]).asScala.map(configureContainerDefinition).map(_.get) + val additionalContainers: Seq[Config] = params.getListOrEmpty("additional_containers", classOf[Config]).asScala + val cpu: Optional[String] = params.getOptional("cpu", classOf[String]) + val executionRoleArn: Optional[String] = params.getOptional("execution_role_arn", classOf[String]) + val family: String = params.get("family", classOf[String], params.get("task_name", classOf[String]).replaceAll("\\+", "_")) + val memory: Optional[String] = params.getOptional("memory", classOf[String]) + val networkMode: Optional[String] = params.getOptional("network_mode", classOf[String]) + // NOTE: Use `ecs_task.run>`'s one. + // val placementConstraints: Seq[TaskDefinitionPlacementConstraint] = params.getListOrEmpty("placement_constraints", classOf[Config]).asScala.map(configureTaskDefinitionPlacementConstraint).map(_.get) + val requiresCompatibilities: Seq[String] = params.getListOrEmpty("requires_compatibilities", classOf[String]).asScala // Valid Values: EC2 | FARGATE + val taskRoleArn: Optional[String] = params.getOptional("task_role_arn", classOf[String]) + val volumes: Seq[Config] = params.getListOrEmpty("volumes", classOf[Config]).asScala + + // For `ecs_task.register>` operator (ContainerDefinition) + // NOTE: Set by this plugin + // val command: Seq[String] = params.getListOrEmpty("command", classOf[String]).asScala + // NOTE: Set in `ecs_task.register>` TaskDefinition Context. If you set it by container level, use the `overrides` option. 
+ // val cpu: Optional[Int] = params.getOptional("cpu", classOf[Int]) + val disableNetworking: Optional[Boolean] = params.getOptional("disable_networking", classOf[Boolean]) + val dnsSearchDomains: Seq[String] = params.getListOrEmpty("dns_search_domains", classOf[String]).asScala + val dnsServers: Seq[String] = params.getListOrEmpty("dns_servers", classOf[String]).asScala + // NOTE: Add some labels by this plugin + val dockerLabels: Map[String, String] = params.getMapOrEmpty("docker_labels", classOf[String], classOf[String]).asScala.toMap + val dockerSecurityOptions: Seq[String] = params.getListOrEmpty("docker_security_options", classOf[String]).asScala + val entryPoint: Seq[String] = params.getListOrEmpty("entry_point", classOf[String]).asScala + // NOTE: Add some envs by this plugin + val configEnvironment: Map[String, String] = params.getMapOrEmpty("environment", classOf[String], classOf[String]).asScala.toMap + // NOTE: This plugin uses only 1 container so `essential` is always true. + // val essential: Optional[Boolean] = params.getOptional("essential", classOf[Boolean]) + val extraHosts: Map[String, String] = params.getMapOrEmpty("extra_hosts", classOf[String], classOf[String]).asScala.toMap + val healthCheck: Optional[Config] = params.getOptionalNested("health_check") + val hostname: Optional[String] = params.getOptional("hostname", classOf[String]) + val image: Optional[String] = params.getOptional("image", classOf[String]) + val interactive: Optional[Boolean] = params.getOptional("interactive", classOf[Boolean]) + val links: Seq[String] = params.getListOrEmpty("links", classOf[String]).asScala + val linuxParameters: Optional[Config] = params.getOptionalNested("linux_parameters") + val logConfiguration: Optional[Config] = params.getOptionalNested("log_configuration") + // NOTE: Set in `ecs_task.register>` TaskDefinition Context. If you set it by container level, use the `overrides` option. + // val memory: Optional[Int] = params.getOptional("memory", classOf[Int]) + // NOTE: If you set it by container level, use the `overrides` option. 
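+ // Illustrative only: to size a single container rather than the whole task definition, pass the
+ // values through the `overrides` option, which ecs_task.run_internal> maps via configureContainerOverride,
+ // e.g. overrides: {container_overrides: [{name: <container_name>, cpu: 256, memory: 512}]}
+ // (the container name and the numbers 256/512 here are placeholders, not defaults).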
+ // val memoryReservation: Optional[Int] = params.getOptional("memory_reservation", classOf[Int]) + val mountPoints: Seq[Config] = params.getListOrEmpty("mount_points", classOf[Config]).asScala + val containerName: String = params.get("container_name", classOf[String], family) + val portMappings: Seq[Config] = params.getListOrEmpty("port_mappings", classOf[Config]).asScala + val privileged: Optional[Boolean] = params.getOptional("privileged", classOf[Boolean]) + val pseudoTerminal: Optional[Boolean] = params.getOptional("pseudo_terminal", classOf[Boolean]) + val readonlyRootFilesystem: Optional[Boolean] = params.getOptional("readonly_root_filesystem", classOf[Boolean]) + val repositoryCredentials: Optional[Config] = params.getOptionalNested("repository_credentials") + val systemControls: Seq[Config] = params.getListOrEmpty("system_controls", classOf[Config]).asScala + val ulimits: Seq[Config] = params.getListOrEmpty("ulimits", classOf[Config]).asScala + val user: Optional[String] = params.getOptional("user", classOf[String]) + val volumesFrom: Seq[Config] = params.getListOrEmpty("volumes_from", classOf[Config]).asScala + val workingDirectory: Optional[String] = params.getOptional("working_directory", classOf[String]) + + // For ecs_task.run operator + val cluster: String = params.get("cluster", classOf[String]) + val count: Optional[Int] = params.getOptional("count", classOf[Int]) + val group: Optional[String] = params.getOptional("group", classOf[String]) + val launchType: Optional[String] = params.getOptional("launch_type", classOf[String]) + val networkConfiguration: Optional[Config] = params.getOptionalNested("network_configuration") + val overrides: Optional[Config] = params.getOptionalNested("overrides") + val placementConstraints: Seq[Config] = params.getListOrEmpty("placement_constraints", classOf[Config]).asScala + val placementStrategy: Seq[Config] = params.getListOrEmpty("placement_strategy", classOf[Config]).asScala + val platformVersion: Optional[String] = params.getOptional("platform_version", classOf[String]) + val startedBy: Optional[String] = params.getOptional("started_by", classOf[String]) + // NOTE: Generated by ecs_task.register operator + // val taskDefinition: String = params.get("task_definition", classOf[String]) + + // For ecs_task.wait operator + val timeout: DurationParam = params.get("timeout", classOf[DurationParam], DurationParam.parse("15m")) + + def run(scriptsLocationPrefix: AmazonS3URI): TaskResult = { + val subTasks: Config = cf.create() + subTasks.setNested("+register", ecsTaskRegisterSubTask(scriptsLocationPrefix)) + subTasks.setNested("+run", ecsTaskRunInternalSubTask()) + subTasks.setNested("+wait", ecsTaskWaitSubTask()) + subTasks.setNested("+result", ecsTaskResultSubTask(scriptsLocationPrefix)) + + val builder = TaskResult.defaultBuilder(cf) + builder.subtaskConfig(subTasks) + builder.build() + } + + protected def ecsTaskRegisterSubTask(scriptsLocationPrefix: AmazonS3URI): Config = { + withDefaultSubTask { subTask => + subTask.set("_type", "ecs_task.register") + subTask.set("_command", taskDefinitionConfig(scriptsLocationPrefix)) + } + } + + protected def ecsTaskRunInternalSubTask(): Config = { + withDefaultSubTask { subTask => + subTask.set("_type", "ecs_task.run_internal") + subTask.set("cluster", cluster) + subTask.setOptional("count", count) + subTask.setOptional("group", group) + subTask.setOptional("launch_type", launchType) + subTask.setOptional("network_configuration", networkConfiguration) + subTask.setOptional("overrides", overrides) + 
subTask.set("placement_constraints", placementConstraints.asJava) + subTask.set("placement_strategy", placementStrategy.asJava) + subTask.setOptional("platform_version", platformVersion) + subTask.setOptional("started_by", startedBy) + subTask.set("task_definition", "${last_ecs_task_register.task_definition_arn}") + } + } + + protected def ecsTaskWaitSubTask(): Config = { + withDefaultSubTask { subTask => + subTask.set("_type", "ecs_task.wait") + subTask.set("cluster", cluster) + subTask.set("tasks", "${last_ecs_task_run.task_arns}") + subTask.set("timeout", timeout.toString) + subTask.set("ignore_failure", true) + } + } + + protected def ecsTaskResultSubTask(resultLocationPrefix: AmazonS3URI): Config = { + withDefaultSubTask { subTask => + subTask.set("_type", "ecs_task.command_result_internal") + subTask.set("_command", resultLocationPrefix.toString) + } + } + + protected def withDefaultSubTask(f: Config => Config): Config = { + val subTask: Config = cf.create() + + subTask.set("auth_method", awsConf.authMethod) + subTask.set("profile_name", awsConf.profileName) + if (awsConf.profileFile.isPresent) subTask.set("profile_file", awsConf.profileFile.get()) + subTask.set("use_http_proxy", awsConf.useHttpProxy) + if (awsConf.region.isPresent) subTask.set("region", awsConf.region.get()) + if (awsConf.endpoint.isPresent) subTask.set("endpoint", awsConf.endpoint.get()) + + f(subTask) + subTask + } + + protected def taskDefinitionConfig(scriptsLocationPrefix: AmazonS3URI): Config = { + val c: Config = cf.create() + + c.set("container_definitions", (Seq(containerDefinitionConfig(scriptsLocationPrefix)) ++ additionalContainers).asJava) + c.setOptional("cpu", cpu) + c.setOptional("execution_role_arn", executionRoleArn) + c.set("family", family) + c.setOptional("memory", memory) + c.setOptional("network_mode", networkMode) + c.set("requires_compatibilities", requiresCompatibilities.asJava) + c.setOptional("task_role_arn", taskRoleArn) + c.set("volumes", volumes.asJava) + + c + } + + protected def containerDefinitionConfig(scriptsLocationPrefix: AmazonS3URI): Config = { + val c: Config = cf.create() + + val command: Seq[String] = Seq("sh", "-c", s"aws s3 cp ${scriptsLocationPrefix.toString}/run.sh ./ && sh run.sh") + logger.info(s"Run in the container: ${command.mkString(" ")}") + c.set("command", command.asJava) + c.setOptional("disable_networking", disableNetworking) + c.set("dns_search_domains", dnsSearchDomains.asJava) + c.set("dns_servers", dnsServers.asJava) + val additionalLabels: Map[String, String] = Map("pro.civitaspo.digdag.plugin.ecs_task.version" -> "0.0.2") + c.set("docker_labels", (dockerLabels ++ additionalLabels).asJava) + c.set("entry_point", entryPoint.asJava) + c.set("environment", (configEnvironment ++ environments).asJava) + c.set("essential", true) + c.set("extra_hosts", extraHosts.asJava) + c.setOptional("health_check", healthCheck) + c.setOptional("image", image) + c.setOptional("interactive", interactive) + c.set("links", links.asJava) + c.setOptional("linux_parameters", linuxParameters) + c.setOptional("log_configuration", logConfiguration) + c.set("mount_points", mountPoints.asJava) + c.set("name", containerName) + c.set("port_mappings", portMappings.asJava) + c.setOptional("privileged", privileged) + c.setOptional("pseudo_terminal", pseudoTerminal) + c.setOptional("readonly_root_filesystem", readonlyRootFilesystem) + c.setOptional("repository_credentials", repositoryCredentials) + c.set("system_controls", systemControls.asJava) + c.set("ulimits", ulimits.asJava) + 
c.setOptional("user", user) + c.set("volumes_from", volumesFrom.asJava) + c.setOptional("working_directory", workingDirectory) + + c + } + +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/py/EcsTaskPyOperator.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/py/EcsTaskPyOperator.scala new file mode 100644 index 0000000..cbe2e62 --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/py/EcsTaskPyOperator.scala @@ -0,0 +1,134 @@ +package pro.civitaspo.digdag.plugin.ecs_task.py +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.{Files, Path} + +import com.amazonaws.services.s3.AmazonS3URI +import io.digdag.client.config.Config +import io.digdag.spi.{OperatorContext, TemplateEngine} +import org.apache.commons.io.FileUtils +import pro.civitaspo.digdag.plugin.ecs_task.AbstractEcsTaskOperator +import pro.civitaspo.digdag.plugin.ecs_task.aws.AmazonS3UriWrapper +import pro.civitaspo.digdag.plugin.ecs_task.command.{EcsTaskCommandOperator, EcsTaskCommandRunner} + +import scala.collection.JavaConverters._ +import scala.io.Source +import scala.language.reflectiveCalls + +class EcsTaskPyOperator(operatorName: String, context: OperatorContext, systemConfig: Config, templateEngine: TemplateEngine) + extends AbstractEcsTaskOperator(operatorName, context, systemConfig, templateEngine) + with EcsTaskCommandOperator { + + private val runnerPyResourcePath: String = "/pro/civitaspo/digdag/plugin/ecs_task/py/runner.py" + private val runShResourcePath: String = "/pro/civitaspo/digdag/plugin/ecs_task/py/run.sh" + + protected val command: String = params.get("_command", classOf[String]) + protected val workspaceS3UriPrefix: AmazonS3URI = { + val parent: String = params.get("workspace_s3_uri_prefix", classOf[String]) + if (parent.endsWith("/")) AmazonS3UriWrapper(s"${parent}ecs_task.py.$sessionUuid") + else AmazonS3UriWrapper(s"$parent/ecs_task.py.$sessionUuid") + } + protected val pipInstall: Seq[String] = params.getListOrEmpty("pip_install", classOf[String]).asScala + + override val runner: EcsTaskCommandRunner = + EcsTaskCommandRunner(params = params, environments = additionalEnvironments(), awsConf = aws.conf, logger = logger) + + override def additionalEnvironments(): Map[String, String] = { + val vars = context.getPrivilegedVariables + val builder = Map.newBuilder[String, String] + vars.getKeys.asScala.foreach { k => + builder += (k -> vars.get(k)) + } + builder.result() + } + + override def uploadScript(): AmazonS3URI = { + withTempDir(operatorName) { tempDir: Path => + createInFile(tempDir) + createRunnerPyFile(tempDir) + createRunShFile(tempDir) + createWorkspaceDir(tempDir) + uploadOnS3(tempDir) + } + workspaceS3UriPrefix + } + + protected def createInFile(parent: Path): Unit = { + val inContent: String = templateEngine.template(cf.create.set("params", params).toString, params) + val inFile: Path = Files.createFile(parent.resolve("in.json")) + writeFile(file = inFile, content = inContent) + } + + protected def createRunnerPyFile(parent: Path): Unit = { + using(classOf[EcsTaskPyOperator].getResourceAsStream(runnerPyResourcePath)) { is => + val runnerPyContent: String = Source.fromInputStream(is).mkString + val runnerPyFile: Path = Files.createFile(parent.resolve("runner.py")) + writeFile(file = runnerPyFile, content = runnerPyContent) + } + } + + protected def createRunShFile(parent: Path): Unit = { + val dup: Config = params.deepCopy() + dup.set("ECS_TASK_PY_BUCKET", workspaceS3UriPrefix.getBucket) + dup.set("ECS_TASK_PY_PREFIX", 
workspaceS3UriPrefix.getKey) + dup.set("ECS_TASK_PY_COMMAND", command) + + dup.set("ECS_TASK_PY_SETUP_COMMAND", "echo 'no setup command'") // set a default value + if (pipInstall.nonEmpty) { + logger.warn("`pip_install` option is experimental, so please be careful in the plugin update.") + val cmd: String = (Seq("pip", "install") ++ pipInstall).mkString(" ") + dup.set("ECS_TASK_PY_SETUP_COMMAND", cmd) + } + + using(classOf[EcsTaskPyOperator].getResourceAsStream(runShResourcePath)) { is => + val runShContentTemplate: String = Source.fromInputStream(is).mkString + val runShContent: String = templateEngine.template(runShContentTemplate, dup) + val runShFile: Path = Files.createFile(parent.resolve("run.sh")) + writeFile(file = runShFile, content = runShContent) + } + } + + protected def createWorkspaceDir(parent: Path): Unit = { + val targets: Iterator[Path] = Files.list(workspace.getPath).iterator().asScala.filterNot(_.endsWith(".digdag")) + val workspacePath: Path = Files.createDirectory(parent.resolve("workspace")) + targets.foreach { path => + logger.info(s"Copy: $path -> $workspacePath") + if (Files.isDirectory(path)) FileUtils.copyDirectoryToDirectory(path.toFile, workspacePath.toFile) + else FileUtils.copyFileToDirectory(path.toFile, workspacePath.toFile) + } + } + + protected def uploadOnS3(path: Path): Unit = { + logger.info(s"Recursive Upload: $path -> ${workspaceS3UriPrefix.getURI}") + aws.withTransferManager { xfer => + val upload = xfer.uploadDirectory( + workspaceS3UriPrefix.getBucket, + workspaceS3UriPrefix.getKey, + path.toFile, + true // includeSubdirectories + ) + upload.waitForCompletion() + } + } + + protected def writeFile(file: Path, content: String): Unit = { + logger.info(s"Write into ${file.toString}") + using(workspace.newBufferedWriter(file.toString, UTF_8)) { writer => + writer.write(content) + } + } + + protected def using[A <: { def close() }, B](resource: A)(f: A => B): B = { + try f(resource) + finally resource.close() + } + + // ref. 
https://github.com/muga/digdag/blob/aff3dfab0b91aa6787d7921ce34d5b3b21947c20/digdag-plugin-utils/src/main/java/io/digdag/util/Workspace.java#L84-L95 + protected def withTempDir[T](prefix: String)(f: Path => T): T = { + val dir = workspace.getProjectPath.resolve(".digdag/tmp") + Files.createDirectories(dir) + val tempDir: Path = Files.createTempDirectory(dir, prefix) + try f(tempDir) + finally FileUtils.deleteDirectory(tempDir.toFile) + } + +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/register/EcsTaskRegisterOperator.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/register/EcsTaskRegisterOperator.scala new file mode 100644 index 0000000..475f352 --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/register/EcsTaskRegisterOperator.scala @@ -0,0 +1,409 @@ +package pro.civitaspo.digdag.plugin.ecs_task.register +import com.amazonaws.services.ecs.model.{ + ContainerDefinition, + Device, + DockerVolumeConfiguration, + HealthCheck, + HostEntry, + HostVolumeProperties, + KernelCapabilities, + KeyValuePair, + LinuxParameters, + LogConfiguration, + LogDriver, + MountPoint, + NetworkMode, + PortMapping, + RegisterTaskDefinitionRequest, + RegisterTaskDefinitionResult, + RepositoryCredentials, + SystemControl, + TaskDefinitionPlacementConstraint, + Tmpfs, + TransportProtocol, + Ulimit, + UlimitName, + Volume, + VolumeFrom +} +import com.google.common.base.Optional +import com.google.common.collect.ImmutableList +import io.digdag.client.config.{Config, ConfigKey} +import io.digdag.spi.{ImmutableTaskResult, OperatorContext, TaskResult, TemplateEngine} +import pro.civitaspo.digdag.plugin.ecs_task.AbstractEcsTaskOperator + +import scala.collection.JavaConverters._ + +class EcsTaskRegisterOperator(operatorName: String, context: OperatorContext, systemConfig: Config, templateEngine: TemplateEngine) + extends AbstractEcsTaskOperator(operatorName, context, systemConfig, templateEngine) { + + protected val config: Config = params.getNested("_command") + + protected def buildRegisterTaskDefinitionRequest(c: Config): RegisterTaskDefinitionRequest = { + val req: RegisterTaskDefinitionRequest = new RegisterTaskDefinitionRequest() + + val containerDefinitions: Seq[ContainerDefinition] = + c.getList("container_definitions", classOf[Config]).asScala.map(configureContainerDefinition).map(_.get) + val cpu: Optional[String] = c.getOptional("cpu", classOf[String]) + val executionRoleArn: Optional[String] = c.getOptional("execution_role_arn", classOf[String]) + val family: String = c.get("family", classOf[String]) + val memory: Optional[String] = c.getOptional("memory", classOf[String]) + val networkMode: Optional[NetworkMode] = c.getOptional("network_mode", classOf[NetworkMode]) + + val placementConstraints: Seq[TaskDefinitionPlacementConstraint] = + c.getListOrEmpty("placement_constraints", classOf[Config]).asScala.map(configureTaskDefinitionPlacementConstraint).map(_.get) + val requiresCompatibilities: Seq[String] = c.getListOrEmpty("requires_compatibilities", classOf[String]).asScala // Valid Values: EC2 | FARGATE + val taskRoleArn: Optional[String] = c.getOptional("task_role_arn", classOf[String]) + val volumes: Seq[Volume] = c.getListOrEmpty("volumes", classOf[Config]).asScala.map(configureVolume).map(_.get) + + req.setContainerDefinitions(containerDefinitions.asJava) + if (cpu.isPresent) req.setCpu(cpu.get) + if (executionRoleArn.isPresent) req.setExecutionRoleArn(executionRoleArn.get) + req.setFamily(family) + if (memory.isPresent) req.setMemory(memory.get) + if 
(networkMode.isPresent) req.setNetworkMode(networkMode.get) + if (placementConstraints.nonEmpty) req.setPlacementConstraints(placementConstraints.asJava) + if (requiresCompatibilities.nonEmpty) req.setRequiresCompatibilities(requiresCompatibilities.asJava) + if (taskRoleArn.isPresent) req.setTaskRoleArn(taskRoleArn.get) + if (volumes.nonEmpty) req.setVolumes(volumes.asJava) + + req + } + + protected def configureContainerDefinition(c: Config): Optional[ContainerDefinition] = { + if (c.isEmpty) return Optional.absent() + + val command: Seq[String] = c.getListOrEmpty("command", classOf[String]).asScala + val cpu: Optional[Int] = c.getOptional("cpu", classOf[Int]) + val disableNetworking: Optional[Boolean] = c.getOptional("disable_networking", classOf[Boolean]) + val dnsSearchDomains: Seq[String] = c.getListOrEmpty("dns_search_domains", classOf[String]).asScala + val dnsServers: Seq[String] = c.getListOrEmpty("dns_servers", classOf[String]).asScala + val dockerLabels: Map[String, String] = c.getMapOrEmpty("docker_labels", classOf[String], classOf[String]).asScala.toMap + val dockerSecurityOptions: Seq[String] = c.getListOrEmpty("docker_security_options", classOf[String]).asScala + val entryPoint: Seq[String] = c.getListOrEmpty("entry_point", classOf[String]).asScala + val environment: Seq[KeyValuePair] = c + .getMapOrEmpty("environment", classOf[String], classOf[String]) + .asScala + .map { case (k: String, v: String) => new KeyValuePair().withName(k).withValue(v) } + .toSeq // TODO: doc + val essential: Optional[Boolean] = c.getOptional("essential", classOf[Boolean]) + val extraHosts: Seq[HostEntry] = c + .getMapOrEmpty("extra_hosts", classOf[String], classOf[String]) + .asScala + .map { case (host: String, ip: String) => new HostEntry().withHostname(host).withIpAddress(ip) } + .toSeq // TODO: doc + val healthCheck: Optional[HealthCheck] = configureHealthCheck(c.getNestedOrGetEmpty("health_check")) + val hostname: Optional[String] = c.getOptional("hostname", classOf[String]) + val image: Optional[String] = c.getOptional("image", classOf[String]) + val interactive: Optional[Boolean] = c.getOptional("interactive", classOf[Boolean]) + val links: Seq[String] = c.getListOrEmpty("links", classOf[String]).asScala + val linuxParameters: Optional[LinuxParameters] = configureLinuxParameters(c.getNestedOrGetEmpty("linux_parameters")) + val logConfiguration: Optional[LogConfiguration] = configureLogConfiguration(c.getNestedOrGetEmpty("log_configuration")) + val memory: Optional[Int] = c.getOptional("memory", classOf[Int]) + val memoryReservation: Optional[Int] = c.getOptional("memory_reservation", classOf[Int]) + val mountPoints: Seq[MountPoint] = c.getListOrEmpty("mount_points", classOf[Config]).asScala.map(configureMountPoint).map(_.get) + val name: Optional[String] = c.getOptional("name", classOf[String]) + val portMappings: Seq[PortMapping] = c.getListOrEmpty("port_mappings", classOf[Config]).asScala.map(configurePortMapping).map(_.get) + val privileged: Optional[Boolean] = c.getOptional("privileged", classOf[Boolean]) + val pseudoTerminal: Optional[Boolean] = c.getOptional("pseudo_terminal", classOf[Boolean]) + val readonlyRootFilesystem: Optional[Boolean] = c.getOptional("readonly_root_filesystem", classOf[Boolean]) + val repositoryCredentials: Optional[RepositoryCredentials] = configureRepositoryCredentials(c.getNestedOrGetEmpty("repository_credentials")) + val systemControls: Seq[SystemControl] = c.getListOrEmpty("system_controls", classOf[Config]).asScala.map(configureSystemControl).map(_.get) 
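+ // Shape note for `environment` and `extra_hosts` above: both are flat string maps, e.g.
+ // environment: {FOO: bar} becomes [{name: FOO, value: bar}] and extra_hosts: {myhost: 127.0.0.1}
+ // becomes [{hostname: myhost, ipAddress: 127.0.0.1}]; the keys and values here are examples only.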
+ val ulimits: Seq[Ulimit] = c.getListOrEmpty("ulimits", classOf[Config]).asScala.map(configureUlimit).map(_.get) + val user: Optional[String] = c.getOptional("user", classOf[String]) + val volumesFrom: Seq[VolumeFrom] = c.getListOrEmpty("volumes_from", classOf[Config]).asScala.map(configureVolumeFrom).map(_.get) + val workingDirectory: Optional[String] = c.getOptional("working_directory", classOf[String]) + + val cd: ContainerDefinition = new ContainerDefinition() + cd.setCommand(command.asJava) + if (cpu.isPresent) cd.setCpu(cpu.get) + if (disableNetworking.isPresent) cd.setDisableNetworking(disableNetworking.get) + if (dnsSearchDomains.nonEmpty) cd.setDnsSearchDomains(dnsSearchDomains.asJava) + if (dnsServers.nonEmpty) cd.setDnsServers(dnsServers.asJava) + if (dockerLabels.nonEmpty) cd.setDockerLabels(dockerLabels.asJava) + if (dockerSecurityOptions.nonEmpty) cd.setDockerSecurityOptions(dockerSecurityOptions.asJava) + if (entryPoint.nonEmpty) cd.setEntryPoint(entryPoint.asJava) + if (environment.nonEmpty) cd.setEnvironment(environment.asJava) // TODO: merge params? + if (essential.isPresent) cd.setEssential(essential.get) + if (extraHosts.nonEmpty) cd.setExtraHosts(extraHosts.asJava) + if (healthCheck.isPresent) cd.setHealthCheck(healthCheck.get) + if (hostname.isPresent) cd.setHostname(hostname.get) + if (image.isPresent) cd.setImage(image.get) + if (interactive.isPresent) cd.setInteractive(interactive.get) + if (links.nonEmpty) cd.setLinks(links.asJava) + if (linuxParameters.isPresent) cd.setLinuxParameters(linuxParameters.get) + if (logConfiguration.isPresent) cd.setLogConfiguration(logConfiguration.get) + if (memory.isPresent) cd.setMemory(memory.get) + if (memoryReservation.isPresent) cd.setMemoryReservation(memoryReservation.get) + if (mountPoints.nonEmpty) cd.setMountPoints(mountPoints.asJava) + if (name.isPresent) cd.setName(name.get) + if (portMappings.nonEmpty) cd.setPortMappings(portMappings.asJava) + if (privileged.isPresent) cd.setPrivileged(privileged.get) + if (pseudoTerminal.isPresent) cd.setPseudoTerminal(pseudoTerminal.get) + if (readonlyRootFilesystem.isPresent) cd.setReadonlyRootFilesystem(readonlyRootFilesystem.get) + if (repositoryCredentials.isPresent) cd.setRepositoryCredentials(repositoryCredentials.get) + if (systemControls.nonEmpty) cd.setSystemControls(systemControls.asJava) + if (ulimits.nonEmpty) cd.setUlimits(ulimits.asJava) + if (user.isPresent) cd.setUser(user.get) + if (volumesFrom.nonEmpty) cd.setVolumesFrom(volumesFrom.asJava) + if (workingDirectory.isPresent) cd.setWorkingDirectory(workingDirectory.get) + + Optional.of(cd) + } + + protected def configureHealthCheck(c: Config): Optional[HealthCheck] = { + if (c.isEmpty) return Optional.absent() + + val command: Seq[String] = c.getList("command", classOf[String]).asScala + val interval: Optional[Int] = c.getOptional("interval", classOf[Int]) + val retries: Optional[Int] = c.getOptional("retries", classOf[Int]) + val startPeriod: Optional[Int] = c.getOptional("start_period", classOf[Int]) + val timeout: Optional[Int] = c.getOptional("timeout", classOf[Int]) + + val hc: HealthCheck = new HealthCheck() + hc.setCommand(command.asJava) + if (interval.isPresent) hc.setInterval(interval.get) + if (retries.isPresent) hc.setRetries(retries.get) + if (startPeriod.isPresent) hc.setStartPeriod(startPeriod.get) + if (timeout.isPresent) hc.setTimeout(timeout.get) + + Optional.of(hc) + } + + protected def configureLinuxParameters(c: Config): Optional[LinuxParameters] = { + if (c.isEmpty)
return Optional.absent() + + val capabilities: Optional[KernelCapabilities] = configureKernelCapabilities(c.getNestedOrGetEmpty("capabilities")) + val devices: Seq[Device] = c.getListOrEmpty("devices", classOf[Config]).asScala.map(configureDevice).map(_.get) + val initProcessEnabled: Optional[Boolean] = c.getOptional("init_process_enabled", classOf[Boolean]) + val sharedMemorySize: Optional[Int] = c.getOptional("shared_memory_size", classOf[Int]) + val tmpfs: Seq[Tmpfs] = c.getListOrEmpty("tmpfs", classOf[Config]).asScala.map(configureTmpfs).map(_.get) + + val lp: LinuxParameters = new LinuxParameters() + if (capabilities.isPresent) lp.setCapabilities(capabilities.get) + if (devices.nonEmpty) lp.setDevices(devices.asJava) + if (initProcessEnabled.isPresent) lp.setInitProcessEnabled(initProcessEnabled.get) + if (sharedMemorySize.isPresent) lp.setSharedMemorySize(sharedMemorySize.get) + if (tmpfs.nonEmpty) lp.setTmpfs(tmpfs.asJava) + + Optional.of(lp) + } + + protected def configureKernelCapabilities(c: Config): Optional[KernelCapabilities] = { + if (c.isEmpty) return Optional.absent() + + val add: Seq[String] = c.getListOrEmpty("add", classOf[String]).asScala + val drop: Seq[String] = c.getListOrEmpty("drop", classOf[String]).asScala + + val kc: KernelCapabilities = new KernelCapabilities() + if (add.nonEmpty) kc.setAdd(add.asJava) + if (drop.nonEmpty) kc.setDrop(drop.asJava) + + Optional.of(kc) + } + + protected def configureDevice(c: Config): Optional[Device] = { + if (c.isEmpty) return Optional.absent() + + val containerPath: Optional[String] = c.getOptional("container_path", classOf[String]) + val hostPath: String = c.get("host_path", classOf[String]) + val permissions: Seq[String] = c.getListOrEmpty("permissions", classOf[String]).asScala + + val d: Device = new Device() + if (containerPath.isPresent) d.setContainerPath(containerPath.get) + d.setHostPath(hostPath) + if (permissions.nonEmpty) d.setPermissions(permissions.asJava) + + Optional.of(d) + } + + protected def configureTmpfs(c: Config): Optional[Tmpfs] = { + if (c.isEmpty) return Optional.absent() + + val containerPath: String = c.get("container_path", classOf[String]) + val mountOptions: Seq[String] = c.getListOrEmpty("mount_options", classOf[String]).asScala + val size: Int = c.get("size", classOf[Int]) + + val tmpfs: Tmpfs = new Tmpfs() + tmpfs.setContainerPath(containerPath) + if (mountOptions.nonEmpty) tmpfs.setMountOptions(mountOptions.asJava) + tmpfs.setSize(size) + + Optional.of(tmpfs) + } + + protected def configureLogConfiguration(c: Config): Optional[LogConfiguration] = { + if (c.isEmpty) return Optional.absent() + + val logDriver: LogDriver = c.get("log_driver", classOf[LogDriver]) // Valid Values: json-file | syslog | journald | gelf | fluentd | awslogs | splunk + val options: Map[String, String] = c.getMapOrEmpty("options", classOf[String], classOf[String]).asScala.toMap + + val lc: LogConfiguration = new LogConfiguration() + lc.setLogDriver(logDriver) + if (options.nonEmpty) lc.setOptions(options.asJava) + + Optional.of(lc) + } + + protected def configureMountPoint(c: Config): Optional[MountPoint] = { + if (c.isEmpty) return Optional.absent() + + val containerPath: Optional[String] = c.getOptional("container_path", classOf[String]) + val readOnly: Optional[Boolean] = c.getOptional("read_only", classOf[Boolean]) + val sourceVolume: Optional[String] = c.getOptional("source_volume", classOf[String]) + + val mp: MountPoint = new MountPoint() + if (containerPath.isPresent) mp.setContainerPath(containerPath.get) + if 
(readOnly.isPresent) mp.setReadOnly(readOnly.get) + if (sourceVolume.isPresent) mp.setSourceVolume(sourceVolume.get) + + Optional.of(mp) + } + + protected def configurePortMapping(c: Config): Optional[PortMapping] = { + if (c.isEmpty) return Optional.absent() + + val containerPort: Optional[Int] = c.getOptional("container_port", classOf[Int]) + val hostPort: Optional[Int] = c.getOptional("host_port", classOf[Int]) + val protocol: Optional[TransportProtocol] = c.getOptional("protocol", classOf[TransportProtocol]) + + val pm: PortMapping = new PortMapping() + if (containerPort.isPresent) pm.setContainerPort(containerPort.get) + if (hostPort.isPresent) pm.setHostPort(hostPort.get) + if (protocol.isPresent) pm.setProtocol(protocol.get) + + Optional.of(pm) + } + + protected def configureRepositoryCredentials(c: Config): Optional[RepositoryCredentials] = { + if (c.isEmpty) return Optional.absent() + + val credentialsParameter: String = c.get("credentials_parameter", classOf[String]) + + val rc: RepositoryCredentials = new RepositoryCredentials() + rc.setCredentialsParameter(credentialsParameter) + + Optional.of(rc) + } + + protected def configureSystemControl(c: Config): Optional[SystemControl] = { + if (c.isEmpty) return Optional.absent() + + val namespace: Optional[String] = c.getOptional("namespace", classOf[String]) + val value: Optional[String] = c.getOptional("value", classOf[String]) + + val sc: SystemControl = new SystemControl() + if (namespace.isPresent) sc.setNamespace(namespace.get) + if (value.isPresent) sc.setValue(value.get) + + Optional.of(sc) + } + + protected def configureUlimit(c: Config): Optional[Ulimit] = { + if (c.isEmpty) return Optional.absent() + + val hardLimit: Int = c.get("hard_limit", classOf[Int]) + val name: UlimitName = c.get("name", classOf[UlimitName]) + val softLimit: Int = c.get("soft_limit", classOf[Int]) + + val u: Ulimit = new Ulimit() + u.setHardLimit(hardLimit) + u.setName(name) + u.setSoftLimit(softLimit) + + Optional.of(u) + } + + protected def configureVolumeFrom(c: Config): Optional[VolumeFrom] = { + if (c.isEmpty) return Optional.absent() + + val readOnly: Optional[Boolean] = c.getOptional("read_only", classOf[Boolean]) + val sourceContainer: Optional[String] = c.getOptional("source_container", classOf[String]) + + val vf: VolumeFrom = new VolumeFrom() + if (readOnly.isPresent) vf.setReadOnly(readOnly.get) + if (sourceContainer.isPresent) vf.setSourceContainer(sourceContainer.get) + + Optional.of(vf) + } + + protected def configureTaskDefinitionPlacementConstraint(c: Config): Optional[TaskDefinitionPlacementConstraint] = { + if (c.isEmpty) return Optional.absent() + + val expression: Optional[String] = c.getOptional("expression", classOf[String]) + val `type`: Optional[String] = c.getOptional("type", classOf[String]) + + val tdpc: TaskDefinitionPlacementConstraint = new TaskDefinitionPlacementConstraint() + if (expression.isPresent) tdpc.setExpression(expression.get) + if (`type`.isPresent) tdpc.setType(`type`.get) + + Optional.of(tdpc) + } + + protected def configureVolume(c: Config): Optional[Volume] = { + if (c.isEmpty) return Optional.absent() + + val dockerVolumeConfiguration: Optional[DockerVolumeConfiguration] = configureDockerVolumeConfiguration( + c.getNestedOrGetEmpty("docker_volume_configuration") + ) + val host: Optional[HostVolumeProperties] = configureHostVolumeProperties(c.getNestedOrGetEmpty("host")) + val name: Optional[String] = c.getOptional("name", classOf[String]) + + val v: Volume = new Volume() + if 
(dockerVolumeConfiguration.isPresent) v.setDockerVolumeConfiguration(dockerVolumeConfiguration.get) + if (host.isPresent) v.setHost(host.get) + if (name.isPresent) v.setName(name.get) + + Optional.of(v) + } + + protected def configureDockerVolumeConfiguration(c: Config): Optional[DockerVolumeConfiguration] = { + if (c.isEmpty) return Optional.absent() + + val autoprovision: Optional[Boolean] = c.getOptional("autoprovision", classOf[Boolean]) + val driver: Optional[String] = c.getOptional("driver", classOf[String]) + val driverOpts: Map[String, String] = c.getMapOrEmpty("driver_opts", classOf[String], classOf[String]).asScala.toMap + val labels: Map[String, String] = c.getMapOrEmpty("labels", classOf[String], classOf[String]).asScala.toMap + val scope: Optional[String] = c.getOptional("scope", classOf[String]) + + val dvc: DockerVolumeConfiguration = new DockerVolumeConfiguration() + if (autoprovision.isPresent) dvc.setAutoprovision(autoprovision.get) + if (driver.isPresent) dvc.setDriver(driver.get) + if (driverOpts.nonEmpty) dvc.setDriverOpts(driverOpts.asJava) + if (labels.nonEmpty) dvc.setLabels(labels.asJava) + if (scope.isPresent) dvc.setScope(scope.get) + + Optional.of(dvc) + } + + protected def configureHostVolumeProperties(c: Config): Optional[HostVolumeProperties] = { + if (c.isEmpty) return Optional.absent() + + val sourcePath: Optional[String] = c.getOptional("source_path", classOf[String]) + + val hvp: HostVolumeProperties = new HostVolumeProperties() + if (sourcePath.isPresent) hvp.setSourcePath(sourcePath.get) + + Optional.of(hvp) + } + + override def runTask(): TaskResult = { + val req: RegisterTaskDefinitionRequest = buildRegisterTaskDefinitionRequest(config) + logger.debug(req.toString) + val result: RegisterTaskDefinitionResult = aws.withEcs(_.registerTaskDefinition(req)) + logger.debug(result.toString) + + val paramsToStore = cf.create() + val last_ecs_task_register: Config = paramsToStore.getNestedOrSetEmpty("last_ecs_task_register") + last_ecs_task_register.set("task_definition_arn", result.getTaskDefinition.getTaskDefinitionArn) + last_ecs_task_register.set("family", result.getTaskDefinition.getFamily) + last_ecs_task_register.set("revision", result.getTaskDefinition.getRevision) + + val builder: ImmutableTaskResult.Builder = TaskResult.defaultBuilder(cf) + builder.resetStoreParams(ImmutableList.of(ConfigKey.of("last_ecs_task_register"))) + builder.storeParams(paramsToStore) + + builder.build() + } + +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/result/EcsTaskResultOperator.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/result/EcsTaskResultOperator.scala new file mode 100644 index 0000000..f06f0ce --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/result/EcsTaskResultOperator.scala @@ -0,0 +1,34 @@ +package pro.civitaspo.digdag.plugin.ecs_task.result +import java.io.File + +import com.amazonaws.services.s3.AmazonS3URI +import com.amazonaws.services.s3.transfer.Download +import io.digdag.client.config.Config +import io.digdag.spi.{OperatorContext, TaskResult, TemplateEngine} +import pro.civitaspo.digdag.plugin.ecs_task.AbstractEcsTaskOperator +import pro.civitaspo.digdag.plugin.ecs_task.aws.AmazonS3UriWrapper + +import scala.io.Source + +class EcsTaskResultOperator(operatorName: String, context: OperatorContext, systemConfig: Config, templateEngine: TemplateEngine) + extends AbstractEcsTaskOperator(operatorName, context, systemConfig, templateEngine) { + + val s3Uri: AmazonS3URI = 
AmazonS3UriWrapper(params.get("_command", classOf[String])) + + override def runTask(): TaskResult = { + val f: String = workspace.createTempFile("ecs_task.result", ".json") + aws.withTransferManager { xfer => + val download: Download = xfer.download(s3Uri.getBucket, s3Uri.getKey, new File(f)) + download.waitForCompletion() + } + val content: String = Source.fromFile(f).getLines.mkString + val data: Config = cf.fromJsonString(content) + + TaskResult + .defaultBuilder(cf) + .subtaskConfig(data.getNestedOrGetEmpty("subtask_config")) + .exportParams(data.getNestedOrGetEmpty("export_params")) + .storeParams(data.getNestedOrGetEmpty("store_params")) + .build + } +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/run/EcsTaskRunInternalOperator.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/run/EcsTaskRunInternalOperator.scala new file mode 100644 index 0000000..3358fb4 --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/run/EcsTaskRunInternalOperator.scala @@ -0,0 +1,168 @@ +package pro.civitaspo.digdag.plugin.ecs_task.run +import com.amazonaws.services.ecs.model.{ + AwsVpcConfiguration, + ContainerOverride, + KeyValuePair, + NetworkConfiguration, + PlacementConstraint, + PlacementConstraintType, + PlacementStrategy, + PlacementStrategyType, + RunTaskRequest, + RunTaskResult, + TaskOverride +} +import com.google.common.base.Optional +import com.google.common.collect.ImmutableList +import io.digdag.client.config.{Config, ConfigKey} +import io.digdag.spi.{ImmutableTaskResult, OperatorContext, TaskResult, TemplateEngine} +import pro.civitaspo.digdag.plugin.ecs_task.AbstractEcsTaskOperator + +import scala.collection.JavaConverters._ + +class EcsTaskRunInternalOperator(operatorName: String, context: OperatorContext, systemConfig: Config, templateEngine: TemplateEngine) + extends AbstractEcsTaskOperator(operatorName, context, systemConfig, templateEngine) { + + val cluster: String = params.get("cluster", classOf[String]) + val count: Optional[Int] = params.getOptional("count", classOf[Int]) + val group: Optional[String] = params.getOptional("group", classOf[String]) + val launchType: Optional[String] = params.getOptional("launch_type", classOf[String]) + val networkConfiguration: Optional[NetworkConfiguration] = configureNetworkConfiguration(params.getNestedOrGetEmpty("network_configuration")) + val overrides: Optional[TaskOverride] = configureTaskOverride(params.getNestedOrGetEmpty("overrides")) + + val placementConstraints: Seq[PlacementConstraint] = + params.getListOrEmpty("placement_constraints", classOf[Config]).asScala.map(configurePlacementConstraint).map(_.get) + + val placementStrategy: Seq[PlacementStrategy] = + params.getListOrEmpty("placement_strategy", classOf[Config]).asScala.map(configurePlacementStrategy).map(_.get) + val platformVersion: Optional[String] = params.getOptional("platform_version", classOf[String]) + val startedBy: Optional[String] = params.getOptional("started_by", classOf[String]) + val taskDefinition: String = params.get("task_definition", classOf[String]) // generated by ecs_task.register> operator if not set. 
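+ // Note: when this operator runs as a subtask generated by ecs_task.run> or a command operator such as
+ // ecs_task.py>, `task_definition` is injected as "${last_ecs_task_register.task_definition_arn}", i.e. the
+ // ARN stored by the preceding ecs_task.register> subtask.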
+ + protected def buildRunTaskRequest(): RunTaskRequest = { + val req: RunTaskRequest = new RunTaskRequest() + + req.setCluster(cluster) + if (count.isPresent) req.setCount(count.get) + if (group.isPresent) req.setGroup(group.get) + if (launchType.isPresent) req.setLaunchType(launchType.get) + if (networkConfiguration.isPresent) req.setNetworkConfiguration(networkConfiguration.get) + if (overrides.isPresent) req.setOverrides(overrides.get) + if (placementConstraints.nonEmpty) req.setPlacementConstraints(placementConstraints.asJava) + if (placementStrategy.nonEmpty) req.setPlacementStrategy(placementStrategy.asJava) + if (platformVersion.isPresent) req.setPlatformVersion(platformVersion.get) + if (startedBy.isPresent) req.setStartedBy(startedBy.get) + req.setTaskDefinition(taskDefinition) + + req + } + + protected def configureNetworkConfiguration(c: Config): Optional[NetworkConfiguration] = { + if (c.isEmpty) return Optional.absent() + + val awsvpcConfiguration: Optional[AwsVpcConfiguration] = configureAwsVpcConfiguration(c.getNestedOrGetEmpty("awsvpc_configuration")) + + val nc: NetworkConfiguration = new NetworkConfiguration() + if (awsvpcConfiguration.isPresent) nc.setAwsvpcConfiguration(awsvpcConfiguration.get) + + Optional.of(nc) + } + + protected def configureAwsVpcConfiguration(c: Config): Optional[AwsVpcConfiguration] = { + if (c.isEmpty) return Optional.absent() + + val assignPublicIp: Optional[String] = c.getOptional("assign_public_ip", classOf[String]) + val securityGroups: Seq[String] = c.getListOrEmpty("security_groups", classOf[String]).asScala + val subnets: Seq[String] = c.getListOrEmpty("subnets", classOf[String]).asScala + + val avc: AwsVpcConfiguration = new AwsVpcConfiguration() + if (assignPublicIp.isPresent) avc.setAssignPublicIp(assignPublicIp.get) + if (securityGroups.nonEmpty) avc.setSecurityGroups(securityGroups.asJava) + if (subnets.nonEmpty) avc.setSubnets(subnets.asJava) + + Optional.of(avc) + } + + protected def configureTaskOverride(c: Config): Optional[TaskOverride] = { + if (c.isEmpty) return Optional.absent() + + val containerOverrides: Seq[ContainerOverride] = + c.getListOrEmpty("container_overrides", classOf[Config]).asScala.map(configureContainerOverride).map(_.get) + val executionRoleArn: Optional[String] = c.getOptional("execution_role_arn", classOf[String]) + val taskRoleArn: Optional[String] = c.getOptional("task_role_arn", classOf[String]) + + val to: TaskOverride = new TaskOverride() + if (containerOverrides.nonEmpty) to.setContainerOverrides(containerOverrides.asJava) + if (executionRoleArn.isPresent) to.setExecutionRoleArn(executionRoleArn.get) + if (taskRoleArn.isPresent) to.setTaskRoleArn(taskRoleArn.get) + + Optional.of(to) + } + + protected def configureContainerOverride(c: Config): Optional[ContainerOverride] = { + if (c.isEmpty) return Optional.absent() + + val command: Seq[String] = c.getListOrEmpty("command", classOf[String]).asScala + val cpu: Optional[Int] = c.getOptional("cpu", classOf[Int]) + val environment: Seq[KeyValuePair] = c + .getMapOrEmpty("environment", classOf[String], classOf[String]) + .asScala + .map { case (k: String, v: String) => new KeyValuePair().withName(k).withValue(v) } + .toSeq // TODO: doc + val memory: Optional[Int] = c.getOptional("memory", classOf[Int]) + val memoryReservation: Optional[Int] = c.getOptional("memory_reservation", classOf[Int]) + val name: Optional[String] = c.getOptional("name", classOf[String]) + + val co: ContainerOverride = new ContainerOverride() + if (command.nonEmpty) 
co.setCommand(command.asJava) + if (cpu.isPresent) co.setCpu(cpu.get) + if (environment.nonEmpty) co.setEnvironment(environment.asJava) + if (memory.isPresent) co.setMemory(memory.get) + if (memoryReservation.isPresent) co.setMemoryReservation(memoryReservation.get) + if (name.isPresent) co.setName(name.get) + + Optional.of(co) + } + + protected def configurePlacementConstraint(c: Config): Optional[PlacementConstraint] = { + if (c.isEmpty) return Optional.absent() + + val expression: Optional[String] = c.getOptional("expression", classOf[String]) + val `type`: Optional[PlacementConstraintType] = c.getOptional("type", classOf[PlacementConstraintType]) + + val pc: PlacementConstraint = new PlacementConstraint() + if (expression.isPresent) pc.setExpression(expression.get) + if (`type`.isPresent) pc.setType(`type`.get) + + Optional.of(pc) + } + + protected def configurePlacementStrategy(c: Config): Optional[PlacementStrategy] = { + if (c.isEmpty) return Optional.absent() + + val field: Optional[String] = c.getOptional("field", classOf[String]) + val `type`: Optional[PlacementStrategyType] = c.getOptional("type", classOf[PlacementStrategyType]) + + val ps: PlacementStrategy = new PlacementStrategy() + if (field.isPresent) ps.setField(field.get) + if (`type`.isPresent) ps.setType(`type`.get) + + Optional.of(ps) + } + + override def runTask(): TaskResult = { + val req: RunTaskRequest = buildRunTaskRequest() + logger.debug(req.toString) + val result: RunTaskResult = aws.withEcs(_.runTask(req)) + logger.debug(result.toString) + + val paramsToStore = cf.create() + val last_ecs_task_run: Config = paramsToStore.getNestedOrSetEmpty("last_ecs_task_run") + last_ecs_task_run.set("task_arns", result.getTasks.asScala.map(_.getTaskArn).asJava) + + val builder: ImmutableTaskResult.Builder = TaskResult.defaultBuilder(cf) + builder.resetStoreParams(ImmutableList.of(ConfigKey.of("last_ecs_task_run"))) + builder.storeParams(paramsToStore) + builder.build() + } +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/run/EcsTaskRunOperator.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/run/EcsTaskRunOperator.scala new file mode 100644 index 0000000..c8858f1 --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/run/EcsTaskRunOperator.scala @@ -0,0 +1,74 @@ +package pro.civitaspo.digdag.plugin.ecs_task.run +import com.google.common.base.Optional +import io.digdag.client.config.Config +import io.digdag.spi.{OperatorContext, TaskResult, TemplateEngine} +import io.digdag.util.DurationParam +import pro.civitaspo.digdag.plugin.ecs_task.AbstractEcsTaskOperator + +class EcsTaskRunOperator(operatorName: String, context: OperatorContext, systemConfig: Config, templateEngine: TemplateEngine) + extends AbstractEcsTaskOperator(operatorName, context, systemConfig, templateEngine) { + + val cluster: String = params.get("cluster", classOf[String]) + val taskDef: Optional[Config] = params.getOptionalNested("def") + val resultS3Uri: Optional[String] = params.getOptional("result_s3_uri", classOf[String]) + val timeout: DurationParam = params.get("timeout", classOf[DurationParam], DurationParam.parse("15m")) + + override def runTask(): TaskResult = { + val subTasks: Config = cf.create() + if (taskDef.isPresent) subTasks.setNested("+register", ecsTaskRegisterSubTask()) + subTasks.setNested("+run", ecsTaskRunInternalSubTask()) + subTasks.setNested("+wait", ecsTaskWaitSubTask()) + if (resultS3Uri.isPresent) subTasks.setNested("+result", ecsTaskResultSubTask()) + + val builder = 
TaskResult.defaultBuilder(cf) + builder.subtaskConfig(subTasks) + builder.build() + } + + protected def ecsTaskRegisterSubTask(): Config = { + withDefaultSubTask { subTask => + subTask.set("_type", "ecs_task.register") + subTask.set("_command", taskDef) + } + } + + protected def ecsTaskRunInternalSubTask(): Config = { + val config: Config = params.deepCopy() + Seq("def", "result_s3_uri_prefix", "timeout").foreach(config.remove) + if (taskDef.isPresent) config.set("task_definition", "${last_ecs_task_register.task_definition_arn}") + withDefaultSubTask { subTask => + subTask.set("_type", "ecs_task.run_internal") + subTask.set("_export", config) + } + } + + protected def ecsTaskWaitSubTask(): Config = { + withDefaultSubTask { subTask => + subTask.set("_type", "ecs_task.wait") + subTask.set("cluster", cluster) + subTask.set("tasks", "${last_ecs_task_run.task_arns}") + subTask.set("timeout", timeout.toString) + } + } + + protected def ecsTaskResultSubTask(): Config = { + withDefaultSubTask { subTask => + subTask.set("_type", "ecs_task.result") + subTask.set("_command", resultS3Uri.get) + } + } + + protected def withDefaultSubTask(f: Config => Config): Config = { + val subTask: Config = cf.create() + + subTask.set("auth_method", aws.conf.authMethod) + subTask.set("profile_name", aws.conf.profileName) + if (aws.conf.profileFile.isPresent) subTask.set("profile_file", aws.conf.profileFile.get()) + subTask.set("use_http_proxy", aws.conf.useHttpProxy) + if (aws.conf.region.isPresent) subTask.set("region", aws.conf.region.get()) + if (aws.conf.endpoint.isPresent) subTask.set("endpoint", aws.conf.endpoint.get()) + + f(subTask) + subTask + } +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/wait/EcsTaskWaitOperator.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/wait/EcsTaskWaitOperator.scala new file mode 100644 index 0000000..87ddf4f --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/wait/EcsTaskWaitOperator.scala @@ -0,0 +1,61 @@ +package pro.civitaspo.digdag.plugin.ecs_task.wait +import com.amazonaws.services.ecs.model.{DescribeTasksRequest, DescribeTasksResult, Failure} +import io.digdag.client.config.Config +import io.digdag.spi.{OperatorContext, TaskResult, TemplateEngine} +import io.digdag.util.DurationParam +import pro.civitaspo.digdag.plugin.ecs_task.AbstractEcsTaskOperator + +import scala.collection.JavaConverters._ + +class EcsTaskWaitOperator(operatorName: String, context: OperatorContext, systemConfig: Config, templateEngine: TemplateEngine) + extends AbstractEcsTaskOperator(operatorName, context, systemConfig, templateEngine) { + + val cluster: String = params.get("cluster", classOf[String]) + val tasks: Seq[String] = params.parseList("tasks", classOf[String]).asScala + val timeout: DurationParam = params.get("timeout", classOf[DurationParam], DurationParam.parse("15m")) + val condition: String = params.get("condition", classOf[String], "all") + val status: String = params.get("status", classOf[String], "STOPPED") + val ignoreFailure: Boolean = params.get("ignore_failure", classOf[Boolean], false) + + override def runTask(): TaskResult = { + val req: DescribeTasksRequest = new DescribeTasksRequest() + .withCluster(cluster) + .withTasks(tasks: _*) + + aws.withEcs { ecs => + val waiter: EcsTaskWaiter = EcsTaskWaiter(logger = logger, ecs = ecs, timeout = timeout, condition = condition, status = status) + try waiter.wait(req) + finally waiter.shutdown() + } + val result: DescribeTasksResult = aws.withEcs(_.describeTasks(req)) + val failures: 
Seq[Failure] = result.getFailures.asScala + if (failures.nonEmpty) { + val failureMessages: String = failures.map(_.toString).mkString(", ") + if (!ignoreFailure) throw new IllegalStateException(s"Some tasks are failed: [$failureMessages]") + else logger.warn(s"Some tasks are failed but ignore them: $failureMessages") + } + + val failedMessages = Seq.newBuilder[String] + result.getTasks.asScala.foreach { task => + task.getContainers.asScala.foreach { container => + Option(container.getExitCode) match { + case Some(code) => + val msg = s"[${task.getTaskArn}] ${container.getName} has stopped with exit_code=$code" + logger.info(msg) + if (!code.equals(0)) failedMessages += msg + case None => + val msg = s"[${task.getTaskArn}] ${container.getName} has stopped without exit_code: reason=${container.getReason}" + logger.info(msg) + failedMessages += msg + } + } + } + if (failedMessages.result().nonEmpty) { + val message: String = failedMessages.result().mkString(", ") + if (!ignoreFailure) throw new IllegalStateException(s"Failure messages: $message") + else logger.warn(s"Some tasks are failed but ignore them: $message") + } + + TaskResult.empty(cf) + } +} diff --git a/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/wait/EcsTaskWaiter.scala b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/wait/EcsTaskWaiter.scala new file mode 100644 index 0000000..0acd590 --- /dev/null +++ b/src/main/scala/pro/civitaspo/digdag/plugin/ecs_task/wait/EcsTaskWaiter.scala @@ -0,0 +1,80 @@ +package pro.civitaspo.digdag.plugin.ecs_task.wait +import java.util.concurrent.{Executors, ExecutorService} + +import com.amazonaws.services.ecs.AmazonECS +import com.amazonaws.services.ecs.model.{DescribeTasksRequest, DescribeTasksResult} +import com.amazonaws.services.ecs.waiters.DescribeTasksFunction +import com.amazonaws.waiters.{ + FixedDelayStrategy, + MaxAttemptsRetryStrategy, + PollingStrategy, + Waiter, + WaiterAcceptor, + WaiterBuilder, + WaiterParameters, + WaiterState, + WaiterTimedOutException +} +import io.digdag.client.config.ConfigException +import io.digdag.util.DurationParam +import org.slf4j.Logger + +import scala.collection.JavaConverters._ + +case class EcsTaskWaiter( + logger: Logger, + ecs: AmazonECS, + executorService: ExecutorService = Executors.newFixedThreadPool(50), + timeout: DurationParam, + condition: String, + status: String +) { + + def wait(req: DescribeTasksRequest): Unit = { + newWaiter().run(new WaiterParameters[DescribeTasksRequest]().withRequest(req)) + } + + def shutdown(): Unit = { + executorService.shutdown() + } + + private def newWaiter(): Waiter[DescribeTasksRequest] = { + new WaiterBuilder[DescribeTasksRequest, DescribeTasksResult] + .withSdkFunction(new DescribeTasksFunction(ecs)) + .withAcceptors(newAcceptor()) + .withDefaultPollingStrategy(newPollingStrategy()) + .withExecutorService(executorService) + .build() + } + + private def newAcceptor(): WaiterAcceptor[DescribeTasksResult] = { + val startAt: Long = System.currentTimeMillis() + + new WaiterAcceptor[DescribeTasksResult] { + override def matches(output: DescribeTasksResult): Boolean = { + val waitingMillis: Long = System.currentTimeMillis() - startAt + logger.info( + s"Waiting ${waitingMillis}ms for that $condition tasks [${output.getTasks.asScala.map(t => s"${t.getTaskArn}:${t.getLastStatus}").mkString(",")}] become $status." 
+ ) + if (waitingMillis > timeout.getDuration.toMillis) { + throw new WaiterTimedOutException(s"Reached timeout ${timeout.getDuration.toMillis}ms without transitioning to the desired state") + } + + condition match { + case "all" => output.getTasks.asScala.forall(t => t.getLastStatus.equals(status)) + case "any" => output.getTasks.asScala.exists(t => t.getLastStatus.equals(status)) + case _ => throw new ConfigException(s"condition: $condition is unsupported.") + } + } + override def getState: WaiterState = WaiterState.SUCCESS + } + } + + private def newPollingStrategy(): PollingStrategy = { + new PollingStrategy( + new MaxAttemptsRetryStrategy(Int.MaxValue), + new FixedDelayStrategy(1) // seconds + ) + } + +}
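+
+// Usage sketch (mirrors EcsTaskWaitOperator above; the parameter values shown are its defaults, not requirements):
+//   val waiter = EcsTaskWaiter(logger = logger, ecs = ecs, timeout = DurationParam.parse("15m"), condition = "all", status = "STOPPED")
+//   try waiter.wait(new DescribeTasksRequest().withCluster(cluster).withTasks(tasks: _*))
+//   finally waiter.shutdown()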