diff --git a/.scalafmt.conf b/.scalafmt.conf deleted file mode 100644 index 14757b9..0000000 --- a/.scalafmt.conf +++ /dev/null @@ -1,9 +0,0 @@ -# https://scalameta.org/scalafmt/#Configuration - -style = IntelliJ -maxColumn = 160 -align = none -newlines.penalizeSingleSelectMultiArgList = false -newlines.alwaysBeforeElseAfterCurlyIf = true -newlines.alwaysBeforeTopLevelStatements = true - diff --git a/CHANGELOG.md b/CHANGELOG.md index d3c45b7..0a4fd65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +0.0.3 (2019-07-17) +================== + +* [New Feature] Add a `catalog` option that registers a table for the data created by the `s3_parquet` plugin. +* [Enhancement] Update dependencies. + 0.0.2 (2019-01-21) ================== diff --git a/README.md b/README.md index 704af82..9dae260 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,13 @@ - **role_external_id**: a unique identifier that is used by third parties when assuming roles in their customers' accounts. this is optionally used for **auth_method**: `"assume_role"`. (string, optional) - **role_session_duration_seconds**: duration, in seconds, of the role session. this is optionally used for **auth_method**: `"assume_role"`. (int, optional) - **scope_down_policy**: an iam policy in json format. this is optionally used for **auth_method**: `"assume_role"`. (string, optional) +- **catalog**: Register a table in the Glue Data Catalog if this option is specified (optional) + - **catalog_id**: Glue Data Catalog ID, if you use a catalog different from the account/region default catalog. (string, optional) + - **database**: The name of the database (string, required) + - **table**: The name of the table (string, required) + - **column_options**: key-value pairs whose key is a column name and whose value is options for that column. (string to options map, default: `{}`) + - **type**: type of a column when this plugin creates new tables (e.g. `STRING`, `BIGINT`) (string, default: depends on input column type. `BIGINT` if input column type is `long`, `BOOLEAN` if `boolean`, `DOUBLE` if `double`, `STRING` if `string`, `STRING` if `timestamp`, `STRING` if `json`) + - **operation_if_exists**: the operation to perform if the table already exists. Available operations are `"delete"` and `"skip"`. (string, default: `"delete"`) - **endpoint**: The AWS Service endpoint (string, optional) - **region**: The AWS region (string, optional) - **http_proxy**: Indicate whether using when accessing AWS via http proxy.
(optional) diff --git a/build.gradle b/build.gradle index cc6d1e7..d385bb1 100644 --- a/build.gradle +++ b/build.gradle @@ -2,7 +2,6 @@ plugins { id "scala" id "com.jfrog.bintray" version "1.1" id "com.github.jruby-gradle.base" version "1.5.0" - id "com.diffplug.gradle.spotless" version "3.13.0" id "com.adarshr.test-logger" version "1.6.0" // For Pretty test logging } import com.github.jrubygradle.JRubyExec @@ -14,30 +13,30 @@ configurations { provided } -version = "0.0.2" +version = "0.0.3" sourceCompatibility = 1.8 targetCompatibility = 1.8 dependencies { - compile "org.embulk:embulk-core:0.9.12" - provided "org.embulk:embulk-core:0.9.12" + compile "org.embulk:embulk-core:0.9.17" + provided "org.embulk:embulk-core:0.9.17" - compile 'org.scala-lang:scala-library:2.12.8' - ['s3', 'sts'].each { v -> - compile "com.amazonaws:aws-java-sdk-${v}:1.11.479" + compile 'org.scala-lang:scala-library:2.13.0' + ['glue', 's3', 'sts'].each { v -> + compile "com.amazonaws:aws-java-sdk-${v}:1.11.592" } ['column', 'common', 'encoding', 'format', 'hadoop', 'jackson'].each { v -> - compile "org.apache.parquet:parquet-${v}:1.10.0" + compile "org.apache.parquet:parquet-${v}:1.10.1" } compile 'org.apache.hadoop:hadoop-common:2.9.2' - compile 'org.xerial.snappy:snappy-java:1.1.7.2' + compile 'org.xerial.snappy:snappy-java:1.1.7.3' - testCompile 'org.scalatest:scalatest_2.12:3.0.5' - testCompile 'org.embulk:embulk-test:0.9.12' - testCompile 'org.embulk:embulk-standards:0.9.12' + testCompile 'org.scalatest:scalatest_2.13:3.0.8' + testCompile 'org.embulk:embulk-test:0.9.17' + testCompile 'org.embulk:embulk-standards:0.9.17' testCompile 'cloud.localstack:localstack-utils:0.1.15' - testCompile 'org.apache.parquet:parquet-tools:1.8.0' + testCompile 'org.apache.parquet:parquet-tools:1.10.1' testCompile 'org.apache.hadoop:hadoop-client:2.9.2' } diff --git a/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala b/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala new file mode 100644 index 0000000..e6817cb --- /dev/null +++ b/src/main/scala/org/embulk/output/s3_parquet/CatalogRegistrator.scala @@ -0,0 +1,178 @@ +package org.embulk.output.s3_parquet + + +import java.util.{Optional, Map => JMap} + +import com.amazonaws.services.glue.model.{Column, CreateTableRequest, DeleteTableRequest, GetTableRequest, SerDeInfo, StorageDescriptor, TableInput} +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.embulk.config.{Config, ConfigDefault, ConfigException} +import org.embulk.output.s3_parquet.aws.Aws +import org.embulk.output.s3_parquet.CatalogRegistrator.ColumnOptions +import org.embulk.spi.Schema +import org.embulk.spi.`type`.{BooleanType, DoubleType, JsonType, LongType, StringType, TimestampType, Type} +import org.slf4j.{Logger, LoggerFactory} + +import scala.jdk.CollectionConverters._ +import scala.util.Try + + +object CatalogRegistrator +{ + trait Task + extends org.embulk.config.Task + { + @Config("catalog_id") + @ConfigDefault("null") + def getCatalogId: Optional[String] + + @Config("database") + def getDatabase: String + + @Config("table") + def getTable: String + + @Config("column_options") + @ConfigDefault("{}") + def getColumnOptions: JMap[String, ColumnOptions] + + @Config("operation_if_exists") + @ConfigDefault("\"delete\"") + def getOperationIfExists: String + } + + trait ColumnOptions + { + @Config("type") + def getType: String + } + + def apply(aws: Aws, + task: Task, + schema: Schema, + location: String, + compressionCodec: CompressionCodecName, + 
loggerOption: Option[Logger] = None): CatalogRegistrator = + { + new CatalogRegistrator(aws, task, schema, location, compressionCodec, loggerOption) + } +} + +class CatalogRegistrator(aws: Aws, + task: CatalogRegistrator.Task, + schema: Schema, + location: String, + compressionCodec: CompressionCodecName, + loggerOption: Option[Logger] = None) +{ + val logger: Logger = loggerOption.getOrElse(LoggerFactory.getLogger(classOf[CatalogRegistrator])) + + def run(): Unit = + { + if (doesTableExists()) { + task.getOperationIfExists match { + case "skip" => + logger.info(s"Skip to register the table: ${task.getDatabase}.${task.getTable}") + return + + case "delete" => + logger.info(s"Delete the table: ${task.getDatabase}.${task.getTable}") + deleteTable() + + case unknown => + throw new ConfigException(s"Unsupported operation: $unknown") + } + } + registerNewParquetTable() + showNewTableInfo() + } + + def showNewTableInfo(): Unit = + { + val req = new GetTableRequest() + task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) + req.setDatabaseName(task.getDatabase) + req.setName(task.getTable) + + val t = aws.withGlue(_.getTable(req)).getTable + logger.info(s"Created a table: ${t.toString}") + } + + def doesTableExists(): Boolean = + { + val req = new GetTableRequest() + task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) + req.setDatabaseName(task.getDatabase) + req.setName(task.getTable) + + Try(aws.withGlue(_.getTable(req))).isSuccess + } + + def deleteTable(): Unit = + { + val req = new DeleteTableRequest() + task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) + req.setDatabaseName(task.getDatabase) + req.setName(task.getTable) + aws.withGlue(_.deleteTable(req)) + } + + def registerNewParquetTable(): Unit = + { + logger.info(s"Create a new table: ${task.getDatabase}.${task.getTable}") + val req = new CreateTableRequest() + task.getCatalogId.ifPresent(cid => req.setCatalogId(cid)) + req.setDatabaseName(task.getDatabase) + req.setTableInput(new TableInput() + .withName(task.getTable) + .withDescription("Created by embulk-output-s3_parquet") + .withTableType("EXTERNAL_TABLE") + .withParameters(Map("EXTERNAL" -> "TRUE", + "classification" -> "parquet", + "parquet.compression" -> compressionCodec.name()).asJava) + .withStorageDescriptor(new StorageDescriptor() + .withColumns(getGlueSchema: _*) + .withLocation(location) + .withCompressed(isCompressed) + .withInputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat") + .withOutputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat") + .withSerdeInfo(new SerDeInfo() + .withSerializationLibrary("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe") + .withParameters(Map("serialization.format" -> "1").asJava) + ) + ) + ) + aws.withGlue(_.createTable(req)) + } + + private def getGlueSchema: Seq[Column] = + { + val columnOptions: Map[String, ColumnOptions] = task.getColumnOptions.asScala.toMap + schema.getColumns.asScala.toSeq.map { c => + val cType: String = + if (columnOptions.contains(c.getName)) columnOptions(c.getName).getType + else convertEmbulkType2GlueType(c.getType) + new Column() + .withName(c.getName) + .withType(cType) + } + } + + private def convertEmbulkType2GlueType(t: Type): String = + { + t match { + case _: BooleanType => "boolean" + case _: LongType => "bigint" + case _: DoubleType => "double" + case _: StringType => "string" + case _: TimestampType => "string" + case _: JsonType => "string" + case unknown => throw new ConfigException(s"Unsupported embulk type: 
${unknown.getName}") + } + } + + private def isCompressed: Boolean = + { + !compressionCodec.equals(CompressionCodecName.UNCOMPRESSED) + } + +} diff --git a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala index 4eaf1b1..2c6b520 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetOutputPlugin.scala @@ -16,184 +16,206 @@ import org.embulk.spi.{Exec, OutputPlugin, PageReader, Schema, TransactionalPage import org.embulk.spi.time.TimestampFormatter import org.embulk.spi.time.TimestampFormatter.TimestampColumnOption import org.embulk.spi.util.Timestamps -import org.slf4j.Logger +import org.slf4j.{Logger, LoggerFactory} -object S3ParquetOutputPlugin { - trait PluginTask - extends Task - with TimestampFormatter.Task - with Aws.Task { +object S3ParquetOutputPlugin +{ - @Config("bucket") - def getBucket: String + trait PluginTask + extends Task + with TimestampFormatter.Task + with Aws.Task + { - @Config("path_prefix") - @ConfigDefault("\"\"") - def getPathPrefix: String + @Config("bucket") + def getBucket: String - @Config("sequence_format") - @ConfigDefault("\"%03d.%02d.\"") - def getSequenceFormat: String + @Config("path_prefix") + @ConfigDefault("\"\"") + def getPathPrefix: String - @Config("file_ext") - @ConfigDefault("\"parquet\"") - def getFileExt: String + @Config("sequence_format") + @ConfigDefault("\"%03d.%02d.\"") + def getSequenceFormat: String - @Config("compression_codec") - @ConfigDefault("\"uncompressed\"") - def getCompressionCodecString: String + @Config("file_ext") + @ConfigDefault("\"parquet\"") + def getFileExt: String - def setCompressionCodec(v: CompressionCodecName): Unit + @Config("compression_codec") + @ConfigDefault("\"uncompressed\"") + def getCompressionCodecString: String - def getCompressionCodec: CompressionCodecName + def setCompressionCodec(v: CompressionCodecName): Unit - @Config("column_options") - @ConfigDefault("{}") - def getColumnOptions: JMap[String, TimestampColumnOption] + def getCompressionCodec: CompressionCodecName - @Config("canned_acl") - @ConfigDefault("\"private\"") - def getCannedAclString: String + @Config("column_options") + @ConfigDefault("{}") + def getColumnOptions: JMap[String, TimestampColumnOption] - def setCannedAcl(v: CannedAccessControlList): Unit + @Config("canned_acl") + @ConfigDefault("\"private\"") + def getCannedAclString: String - def getCannedAcl: CannedAccessControlList + def setCannedAcl(v: CannedAccessControlList): Unit - @Config("block_size") - @ConfigDefault("null") - def getBlockSize: Optional[Int] + def getCannedAcl: CannedAccessControlList - @Config("page_size") - @ConfigDefault("null") - def getPageSize: Optional[Int] + @Config("block_size") + @ConfigDefault("null") + def getBlockSize: Optional[Int] - @Config("max_padding_size") - @ConfigDefault("null") - def getMaxPaddingSize: Optional[Int] + @Config("page_size") + @ConfigDefault("null") + def getPageSize: Optional[Int] - @Config("enable_dictionary_encoding") - @ConfigDefault("null") - def getEnableDictionaryEncoding: Optional[Boolean] + @Config("max_padding_size") + @ConfigDefault("null") + def getMaxPaddingSize: Optional[Int] - @Config("buffer_dir") - @ConfigDefault("null") - def getBufferDir: Optional[String] + @Config("enable_dictionary_encoding") + @ConfigDefault("null") + def getEnableDictionaryEncoding: Optional[Boolean] - } + @Config("buffer_dir") + @ConfigDefault("null") + def 
getBufferDir: Optional[String] + + @Config("catalog") + @ConfigDefault("null") + def getCatalog: Optional[CatalogRegistrator.Task] + } } class S3ParquetOutputPlugin - extends OutputPlugin { - - val logger: Logger = Exec.getLogger(classOf[S3ParquetOutputPlugin]) - - private def withPluginContextClassLoader[A](f: => A): A = { - val original: ClassLoader = Thread.currentThread.getContextClassLoader - Thread.currentThread.setContextClassLoader(classOf[S3ParquetOutputPlugin].getClassLoader) - try f - finally Thread.currentThread.setContextClassLoader(original) - } - - override def transaction(config: ConfigSource, - schema: Schema, - taskCount: Int, - control: OutputPlugin.Control): ConfigDiff = { - val task: PluginTask = config.loadConfig(classOf[PluginTask]) - - withPluginContextClassLoader { - configure(task, schema) - control.run(task.dump) - } + extends OutputPlugin +{ - Exec.newConfigDiff - } + val logger: Logger = LoggerFactory.getLogger(classOf[S3ParquetOutputPlugin]) + + private def withPluginContextClassLoader[A](f: => A): A = + { + val original: ClassLoader = Thread.currentThread.getContextClassLoader + Thread.currentThread.setContextClassLoader(classOf[S3ParquetOutputPlugin].getClassLoader) + try f + finally Thread.currentThread.setContextClassLoader(original) + } - private def configure(task: PluginTask, - schema: Schema): Unit = { - // sequence_format - try String.format(task.getSequenceFormat, 0: Integer, 0: Integer) - catch { - case e: IllegalFormatException => throw new ConfigException(s"Invalid sequence_format: ${task.getSequenceFormat}", e) + override def transaction(config: ConfigSource, + schema: Schema, + taskCount: Int, + control: OutputPlugin.Control): ConfigDiff = + { + val task: PluginTask = config.loadConfig(classOf[PluginTask]) + + withPluginContextClassLoader { + configure(task, schema) + control.run(task.dump) + } + task.getCatalog.ifPresent { catalog => + val location = s"s3://${task.getBucket}/${task.getPathPrefix.replaceFirst("(.*/)[^/]+$", "$1")}" + val cr = CatalogRegistrator(aws = Aws(task), + task = catalog, + schema = schema, + location = location, + compressionCodec = task.getCompressionCodec) + cr.run() + } + + Exec.newConfigDiff } - // compression_codec - CompressionCodecName.values().find(v => v.name().toLowerCase(Locale.ENGLISH).equals(task.getCompressionCodecString)) match { - case Some(v) => task.setCompressionCodec(v) - case None => - val unsupported: String = task.getCompressionCodecString - val supported: String = CompressionCodecName.values().map(v => s"'${v.name().toLowerCase}'").mkString(", ") - throw new ConfigException(s"'$unsupported' is unsupported: `compression_codec` must be one of [$supported].") + private def configure(task: PluginTask, + schema: Schema): Unit = + { + // sequence_format + try String.format(task.getSequenceFormat, 0: Integer, 0: Integer) + catch { + case e: IllegalFormatException => throw new ConfigException(s"Invalid sequence_format: ${task.getSequenceFormat}", e) + } + + // compression_codec + CompressionCodecName.values().find(v => v.name().toLowerCase(Locale.ENGLISH).equals(task.getCompressionCodecString)) match { + case Some(v) => task.setCompressionCodec(v) + case None => + val unsupported: String = task.getCompressionCodecString + val supported: String = CompressionCodecName.values().map(v => s"'${v.name().toLowerCase}'").mkString(", ") + throw new ConfigException(s"'$unsupported' is unsupported: `compression_codec` must be one of [$supported].") + } + + // column_options + task.getColumnOptions.forEach { (k: 
String, + _) => + val c = schema.lookupColumn(k) + if (!c.getType.getName.equals("timestamp")) throw new ConfigException(s"column:$k is not 'timestamp' type.") + } + + // canned_acl + CannedAccessControlList.values().find(v => v.toString.equals(task.getCannedAclString)) match { + case Some(v) => task.setCannedAcl(v) + case None => + val unsupported: String = task.getCannedAclString + val supported: String = CannedAccessControlList.values().map(v => s"'${v.toString}'").mkString(", ") + throw new ConfigException(s"'$unsupported' is unsupported: `canned_acl` must be one of [$supported].") + } } - // column_options - task.getColumnOptions.forEach { (k: String, - _) => - val c = schema.lookupColumn(k) - if (!c.getType.getName.equals("timestamp")) throw new ConfigException(s"column:$k is not 'timestamp' type.") + override def resume(taskSource: TaskSource, + schema: Schema, + taskCount: Int, + control: OutputPlugin.Control): ConfigDiff = + { + throw new UnsupportedOperationException("s3_parquet output plugin does not support resuming") } - // canned_acl - CannedAccessControlList.values().find(v => v.toString.equals(task.getCannedAclString)) match { - case Some(v) => task.setCannedAcl(v) - case None => - val unsupported: String = task.getCannedAclString - val supported: String = CannedAccessControlList.values().map(v => s"'${v.toString}'").mkString(", ") - throw new ConfigException(s"'$unsupported' is unsupported: `canned_acl` must be one of [$supported].") + override def cleanup(taskSource: TaskSource, + schema: Schema, + taskCount: Int, + successTaskReports: JList[TaskReport]): Unit = + { + successTaskReports.forEach { tr => + logger.info( + s"Created: s3://${tr.get(classOf[String], "bucket")}/${tr.get(classOf[String], "key")}, " + + s"version_id: ${tr.get(classOf[String], "version_id", null)}, " + + s"etag: ${tr.get(classOf[String], "etag", null)}") + } } - } - override def resume(taskSource: TaskSource, + override def open(taskSource: TaskSource, schema: Schema, - taskCount: Int, - control: OutputPlugin.Control): ConfigDiff = { - throw new UnsupportedOperationException("s3_parquet output plugin does not support resuming") - } - - override def cleanup(taskSource: TaskSource, - schema: Schema, - taskCount: Int, - successTaskReports: JList[TaskReport]): Unit = { - successTaskReports.forEach { tr => - logger.info( - s"Created: s3://${tr.get(classOf[String], "bucket")}/${tr.get(classOf[String], "key")}, " - + s"version_id: ${tr.get(classOf[String], "version_id", null)}, " - + s"etag: ${tr.get(classOf[String], "etag", null)}") + taskIndex: Int): TransactionalPageOutput = + { + val task = taskSource.loadTask(classOf[PluginTask]) + val bufferDir: String = task.getBufferDir.orElse(Files.createTempDirectory("embulk-output-s3_parquet-").toString) + val bufferFile: String = Paths.get(bufferDir, s"embulk-output-s3_parquet-task-$taskIndex-0.parquet").toString + val destS3bucket: String = task.getBucket + val destS3Key: String = task.getPathPrefix + String.format(task.getSequenceFormat, taskIndex: Integer, 0: Integer) + task.getFileExt + + + val pageReader: PageReader = new PageReader(schema) + val aws: Aws = Aws(task) + val timestampFormatters: Seq[TimestampFormatter] = Timestamps.newTimestampColumnFormatters(task, schema, task.getColumnOptions).toSeq + val parquetWriter: ParquetWriter[PageReader] = ParquetFileWriter.builder() + .withPath(bufferFile) + .withSchema(schema) + .withTimestampFormatters(timestampFormatters) + .withCompressionCodec(task.getCompressionCodec) + 
.withDictionaryEncoding(task.getEnableDictionaryEncoding.orElse(ParquetProperties.DEFAULT_IS_DICTIONARY_ENABLED)) + .withDictionaryPageSize(task.getPageSize.orElse(ParquetProperties.DEFAULT_DICTIONARY_PAGE_SIZE)) + .withMaxPaddingSize(task.getMaxPaddingSize.orElse(ParquetWriter.MAX_PADDING_SIZE_DEFAULT)) + .withPageSize(task.getPageSize.orElse(ParquetProperties.DEFAULT_PAGE_SIZE)) + .withRowGroupSize(task.getBlockSize.orElse(ParquetWriter.DEFAULT_BLOCK_SIZE)) + .withValidation(ParquetWriter.DEFAULT_IS_VALIDATING_ENABLED) + .withWriteMode(org.apache.parquet.hadoop.ParquetFileWriter.Mode.CREATE) + .withWriterVersion(ParquetProperties.DEFAULT_WRITER_VERSION) + .build() + + logger.info(s"Local Buffer File: $bufferFile, Destination: s3://$destS3bucket/$destS3Key") + + S3ParquetPageOutput(bufferFile, pageReader, parquetWriter, aws, destS3bucket, destS3Key) } - } - - override def open(taskSource: TaskSource, - schema: Schema, - taskIndex: Int): TransactionalPageOutput = { - val task = taskSource.loadTask(classOf[PluginTask]) - val bufferDir: String = task.getBufferDir.orElse(Files.createTempDirectory("embulk-output-s3_parquet-").toString) - val bufferFile: String = Paths.get(bufferDir, s"embulk-output-s3_parquet-task-$taskIndex-0.parquet").toString - val destS3bucket: String = task.getBucket - val destS3Key: String = task.getPathPrefix + String.format(task.getSequenceFormat, taskIndex: Integer, 0: Integer) + task.getFileExt - - - val pageReader: PageReader = new PageReader(schema) - val aws: Aws = Aws(task) - val timestampFormatters: Seq[TimestampFormatter] = Timestamps.newTimestampColumnFormatters(task, schema, task.getColumnOptions) - val parquetWriter: ParquetWriter[PageReader] = ParquetFileWriter.builder() - .withPath(bufferFile) - .withSchema(schema) - .withTimestampFormatters(timestampFormatters) - .withCompressionCodec(task.getCompressionCodec) - .withDictionaryEncoding(task.getEnableDictionaryEncoding.orElse(ParquetProperties.DEFAULT_IS_DICTIONARY_ENABLED)) - .withDictionaryPageSize(task.getPageSize.orElse(ParquetProperties.DEFAULT_DICTIONARY_PAGE_SIZE)) - .withMaxPaddingSize(task.getMaxPaddingSize.orElse(ParquetWriter.MAX_PADDING_SIZE_DEFAULT)) - .withPageSize(task.getPageSize.orElse(ParquetProperties.DEFAULT_PAGE_SIZE)) - .withRowGroupSize(task.getBlockSize.orElse(ParquetWriter.DEFAULT_BLOCK_SIZE)) - .withValidation(ParquetWriter.DEFAULT_IS_VALIDATING_ENABLED) - .withWriteMode(org.apache.parquet.hadoop.ParquetFileWriter.Mode.CREATE) - .withWriterVersion(ParquetProperties.DEFAULT_WRITER_VERSION) - .build() - - logger.info(s"Local Buffer File: $bufferFile, Destination: s3://$destS3bucket/$destS3Key") - - S3ParquetPageOutput(bufferFile, pageReader, parquetWriter, aws, destS3bucket, destS3Key) - } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala index 8ad16e4..e3e0776 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/S3ParquetPageOutput.scala @@ -11,55 +11,63 @@ import org.embulk.config.TaskReport import org.embulk.output.s3_parquet.aws.Aws import org.embulk.spi.{Exec, Page, PageReader, TransactionalPageOutput} + case class S3ParquetPageOutput(outputLocalFile: String, reader: PageReader, writer: ParquetWriter[PageReader], aws: Aws, destBucket: String, destKey: String) - extends TransactionalPageOutput { + extends TransactionalPageOutput +{ - private var isClosed: Boolean = false + private var isClosed: 
Boolean = false - override def add(page: Page): Unit = { - reader.setPage(page) - while (reader.nextRecord()) { - writer.write(reader) + override def add(page: Page): Unit = + { + reader.setPage(page) + while (reader.nextRecord()) { + writer.write(reader) + } } - } - override def finish(): Unit = { - } + override def finish(): Unit = + { + } - override def close(): Unit = { - synchronized { - if (!isClosed) { - writer.close() - isClosed = true - } + override def close(): Unit = + { + synchronized { + if (!isClosed) { + writer.close() + isClosed = true + } + } } - } - override def abort(): Unit = { - close() - cleanup() - } + override def abort(): Unit = + { + close() + cleanup() + } - override def commit(): TaskReport = { - close() - val result: UploadResult = aws.withTransferManager { xfer: TransferManager => - val upload: Upload = xfer.upload(destBucket, destKey, new File(outputLocalFile)) - upload.waitForUploadResult() + override def commit(): TaskReport = + { + close() + val result: UploadResult = aws.withTransferManager { xfer: TransferManager => + val upload: Upload = xfer.upload(destBucket, destKey, new File(outputLocalFile)) + upload.waitForUploadResult() + } + cleanup() + Exec.newTaskReport() + .set("bucket", result.getBucketName) + .set("key", result.getKey) + .set("etag", result.getETag) + .set("version_id", result.getVersionId) } - cleanup() - Exec.newTaskReport() - .set("bucket", result.getBucketName) - .set("key", result.getKey) - .set("etag", result.getETag) - .set("version_id", result.getVersionId) - } - private def cleanup(): Unit = { - Files.delete(Paths.get(outputLocalFile)) - } + private def cleanup(): Unit = + { + Files.delete(Paths.get(outputLocalFile)) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/Aws.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/Aws.scala index e66e51d..a388aaa 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/Aws.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/Aws.scala @@ -2,44 +2,62 @@ package org.embulk.output.s3_parquet.aws import com.amazonaws.client.builder.AwsClientBuilder +import com.amazonaws.services.glue.{AWSGlue, AWSGlueClientBuilder} import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder} import com.amazonaws.services.s3.transfer.{TransferManager, TransferManagerBuilder} -object Aws { - trait Task - extends AwsCredentials.Task - with AwsEndpointConfiguration.Task - with AwsClientConfiguration.Task - with AwsS3Configuration.Task +object Aws +{ - def apply(task: Task): Aws = new Aws(task) + trait Task + extends AwsCredentials.Task + with AwsEndpointConfiguration.Task + with AwsClientConfiguration.Task + with AwsS3Configuration.Task + + def apply(task: Task): Aws = + { + new Aws(task) + } } -class Aws(task: Aws.Task) { - - def withS3[A](f: AmazonS3 => A): A = { - val builder: AmazonS3ClientBuilder = AmazonS3ClientBuilder.standard() - AwsS3Configuration(task).configureAmazonS3ClientBuilder(builder) - val svc = createService(builder) - try f(svc) - finally svc.shutdown() - } - - def withTransferManager[A](f: TransferManager => A): A = { - withS3 { s3 => - val svc = TransferManagerBuilder.standard().withS3Client(s3).build() - try f(svc) - finally svc.shutdownNow(false) +class Aws(task: Aws.Task) +{ + + def withS3[A](f: AmazonS3 => A): A = + { + val builder: AmazonS3ClientBuilder = AmazonS3ClientBuilder.standard() + AwsS3Configuration(task).configureAmazonS3ClientBuilder(builder) + val svc = createService(builder) + try f(svc) + finally svc.shutdown() + } + + def 
withTransferManager[A](f: TransferManager => A): A = + { + withS3 { s3 => + val svc = TransferManagerBuilder.standard().withS3Client(s3).build() + try f(svc) + finally svc.shutdownNow(false) + } + } + + def withGlue[A](f: AWSGlue => A): A = + { + val builder: AWSGlueClientBuilder = AWSGlueClientBuilder.standard() + val svc = createService(builder) + try f(svc) + finally svc.shutdown() } - } - def createService[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): T = { - AwsEndpointConfiguration(task).configureAwsClientBuilder(builder) - AwsClientConfiguration(task).configureAwsClientBuilder(builder) - builder.setCredentials(AwsCredentials(task).createAwsCredentialsProvider) + def createService[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): T = + { + AwsEndpointConfiguration(task).configureAwsClientBuilder(builder) + AwsClientConfiguration(task).configureAwsClientBuilder(builder) + builder.setCredentials(AwsCredentials(task).createAwsCredentialsProvider) - builder.build() - } + builder.build() + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsClientConfiguration.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsClientConfiguration.scala index bf42618..6f0e975 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsClientConfiguration.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsClientConfiguration.scala @@ -8,27 +8,35 @@ import com.amazonaws.client.builder.AwsClientBuilder import org.embulk.config.{Config, ConfigDefault} import org.embulk.output.s3_parquet.aws.AwsClientConfiguration.Task -object AwsClientConfiguration { - trait Task { +object AwsClientConfiguration +{ - @Config("http_proxy") - @ConfigDefault("null") - def getHttpProxy: Optional[HttpProxy.Task] + trait Task + { - } + @Config("http_proxy") + @ConfigDefault("null") + def getHttpProxy: Optional[HttpProxy.Task] - def apply(task: Task): AwsClientConfiguration = new AwsClientConfiguration(task) + } + + def apply(task: Task): AwsClientConfiguration = + { + new AwsClientConfiguration(task) + } } -class AwsClientConfiguration(task: Task) { +class AwsClientConfiguration(task: Task) +{ - def configureAwsClientBuilder[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): Unit = { - task.getHttpProxy.ifPresent { v => - val cc = new ClientConfiguration - HttpProxy(v).configureClientConfiguration(cc) - builder.setClientConfiguration(cc) + def configureAwsClientBuilder[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): Unit = + { + task.getHttpProxy.ifPresent { v => + val cc = new ClientConfiguration + HttpProxy(v).configureClientConfiguration(cc) + builder.setClientConfiguration(cc) + } } - } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsCredentials.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsCredentials.scala index 463d2d0..19f823d 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsCredentials.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsCredentials.scala @@ -9,120 +9,129 @@ import org.embulk.config.{Config, ConfigDefault, ConfigException} import org.embulk.output.s3_parquet.aws.AwsCredentials.Task import org.embulk.spi.unit.LocalFile -object AwsCredentials { - trait Task { +object AwsCredentials +{ - @Config("auth_method") - @ConfigDefault("\"default\"") - def getAuthMethod: String + trait Task + { - @Config("access_key_id") - @ConfigDefault("null") - def getAccessKeyId: Optional[String] + @Config("auth_method") + @ConfigDefault("\"default\"") + def 
getAuthMethod: String - @Config("secret_access_key") - @ConfigDefault("null") - def getSecretAccessKey: Optional[String] + @Config("access_key_id") + @ConfigDefault("null") + def getAccessKeyId: Optional[String] - @Config("session_token") - @ConfigDefault("null") - def getSessionToken: Optional[String] + @Config("secret_access_key") + @ConfigDefault("null") + def getSecretAccessKey: Optional[String] - @Config("profile_file") - @ConfigDefault("null") - def getProfileFile: Optional[LocalFile] + @Config("session_token") + @ConfigDefault("null") + def getSessionToken: Optional[String] - @Config("profile_name") - @ConfigDefault("\"default\"") - def getProfileName: String + @Config("profile_file") + @ConfigDefault("null") + def getProfileFile: Optional[LocalFile] - @Config("role_arn") - @ConfigDefault("null") - def getRoleArn: Optional[String] + @Config("profile_name") + @ConfigDefault("\"default\"") + def getProfileName: String - @Config("role_session_name") - @ConfigDefault("null") - def getRoleSessionName: Optional[String] + @Config("role_arn") + @ConfigDefault("null") + def getRoleArn: Optional[String] - @Config("role_external_id") - @ConfigDefault("null") - def getRoleExternalId: Optional[String] + @Config("role_session_name") + @ConfigDefault("null") + def getRoleSessionName: Optional[String] - @Config("role_session_duration_seconds") - @ConfigDefault("null") - def getRoleSessionDurationSeconds: Optional[Int] + @Config("role_external_id") + @ConfigDefault("null") + def getRoleExternalId: Optional[String] - @Config("scope_down_policy") - @ConfigDefault("null") - def getScopeDownPolicy: Optional[String] + @Config("role_session_duration_seconds") + @ConfigDefault("null") + def getRoleSessionDurationSeconds: Optional[Int] - } + @Config("scope_down_policy") + @ConfigDefault("null") + def getScopeDownPolicy: Optional[String] - def apply(task: Task): AwsCredentials = new AwsCredentials(task) -} - -class AwsCredentials(task: Task) { - - def createAwsCredentialsProvider: AWSCredentialsProvider = { - task.getAuthMethod match { - case "basic" => - new AWSStaticCredentialsProvider(new BasicAWSCredentials( - getRequiredOption(task.getAccessKeyId, "access_key_id"), - getRequiredOption(task.getAccessKeyId, "secret_access_key") - )) - - case "env" => - new EnvironmentVariableCredentialsProvider + } - case "instance" => - // NOTE: combination of InstanceProfileCredentialsProvider and ContainerCredentialsProvider - new EC2ContainerCredentialsProviderWrapper + def apply(task: Task): AwsCredentials = + { + new AwsCredentials(task) + } +} - case "profile" => - if (task.getProfileFile.isPresent) { - val pf: ProfilesConfigFile = new ProfilesConfigFile(task.getProfileFile.get().getFile) - new ProfileCredentialsProvider(pf, task.getProfileName) +class AwsCredentials(task: Task) +{ + + def createAwsCredentialsProvider: AWSCredentialsProvider = + { + task.getAuthMethod match { + case "basic" => + new AWSStaticCredentialsProvider(new BasicAWSCredentials( + getRequiredOption(task.getAccessKeyId, "access_key_id"), + getRequiredOption(task.getAccessKeyId, "secret_access_key") + )) + + case "env" => + new EnvironmentVariableCredentialsProvider + + case "instance" => + // NOTE: combination of InstanceProfileCredentialsProvider and ContainerCredentialsProvider + new EC2ContainerCredentialsProviderWrapper + + case "profile" => + if (task.getProfileFile.isPresent) { + val pf: ProfilesConfigFile = new ProfilesConfigFile(task.getProfileFile.get().getFile) + new ProfileCredentialsProvider(pf, task.getProfileName) + } + else new 
ProfileCredentialsProvider(task.getProfileName) + + case "properties" => + new SystemPropertiesCredentialsProvider + + case "anonymous" => + new AWSStaticCredentialsProvider(new AnonymousAWSCredentials) + + case "session" => + new AWSStaticCredentialsProvider(new BasicSessionCredentials( + getRequiredOption(task.getAccessKeyId, "access_key_id"), + getRequiredOption(task.getSecretAccessKey, "secret_access_key"), + getRequiredOption(task.getSessionToken, "session_token") + )) + + case "assume_role" => + // NOTE: Are http_proxy, endpoint, region required when assuming role? + val builder = new STSAssumeRoleSessionCredentialsProvider.Builder( + getRequiredOption(task.getRoleArn, "role_arn"), + getRequiredOption(task.getRoleSessionName, "role_session_name") + ) + task.getRoleExternalId.ifPresent(v => builder.withExternalId(v)) + task.getRoleSessionDurationSeconds.ifPresent(v => builder.withRoleSessionDurationSeconds(v)) + task.getScopeDownPolicy.ifPresent(v => builder.withScopeDownPolicy(v)) + + builder.build() + + case "default" => + new DefaultAWSCredentialsProviderChain + + case am => + throw new ConfigException(s"'$am' is unsupported: `auth_method` must be one of ['basic', 'env', 'instance', 'profile', 'properties', 'anonymous', 'session', 'assume_role', 'default'].") } - else new ProfileCredentialsProvider(task.getProfileName) - - case "properties" => - new SystemPropertiesCredentialsProvider - - case "anonymous" => - new AWSStaticCredentialsProvider(new AnonymousAWSCredentials) - - case "session" => - new AWSStaticCredentialsProvider(new BasicSessionCredentials( - getRequiredOption(task.getAccessKeyId, "access_key_id"), - getRequiredOption(task.getSecretAccessKey, "secret_access_key"), - getRequiredOption(task.getSessionToken, "session_token") - )) - - case "assume_role" => - // NOTE: Are http_proxy, endpoint, region required when assuming role? 
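// Descriptive note on the assume_role branch (same logic on both sides of this hunk): it
// requires role_arn and role_session_name, applies role_external_id,
// role_session_duration_seconds and scope_down_policy only when they are present, and then
// builds the STS-backed provider. Whether http_proxy/endpoint/region should also be applied
// to the underlying STS call (the question in the NOTE above) is left open here.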
- val builder = new STSAssumeRoleSessionCredentialsProvider.Builder( - getRequiredOption(task.getRoleArn, "role_arn"), - getRequiredOption(task.getRoleSessionName, "role_session_name") - ) - task.getRoleExternalId.ifPresent(v => builder.withExternalId(v)) - task.getRoleSessionDurationSeconds.ifPresent(v => builder.withRoleSessionDurationSeconds(v)) - task.getScopeDownPolicy.ifPresent(v => builder.withScopeDownPolicy(v)) - - builder.build() - - case "default" => - new DefaultAWSCredentialsProviderChain - - case am => - throw new ConfigException(s"'$am' is unsupported: `auth_method` must be one of ['basic', 'env', 'instance', 'profile', 'properties', 'anonymous', 'session', 'assume_role', 'default'].") } - } - private def getRequiredOption[A](o: Optional[A], - name: String): A = { - o.orElseThrow(() => new ConfigException(s"`$name` must be set when `auth_method` is ${task.getAuthMethod}.")) - } + private def getRequiredOption[A](o: Optional[A], + name: String): A = + { + o.orElseThrow(() => new ConfigException(s"`$name` must be set when `auth_method` is ${task.getAuthMethod}.")) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsEndpointConfiguration.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsEndpointConfiguration.scala index 37b1f8a..e0303aa 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsEndpointConfiguration.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsEndpointConfiguration.scala @@ -11,39 +11,47 @@ import org.embulk.output.s3_parquet.aws.AwsEndpointConfiguration.Task import scala.util.Try -object AwsEndpointConfiguration { - trait Task { +object AwsEndpointConfiguration +{ - @Config("endpoint") - @ConfigDefault("null") - def getEndpoint: Optional[String] + trait Task + { - @Config("region") - @ConfigDefault("null") - def getRegion: Optional[String] + @Config("endpoint") + @ConfigDefault("null") + def getEndpoint: Optional[String] - } + @Config("region") + @ConfigDefault("null") + def getRegion: Optional[String] - def apply(task: Task): AwsEndpointConfiguration = new AwsEndpointConfiguration(task) -} - -class AwsEndpointConfiguration(task: Task) { - - def configureAwsClientBuilder[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): Unit = { - if (task.getRegion.isPresent && task.getEndpoint.isPresent) { - val ec = new EndpointConfiguration(task.getEndpoint.get, task.getRegion.get) - builder.setEndpointConfiguration(ec) } - else if (task.getRegion.isPresent && !task.getEndpoint.isPresent) { - builder.setRegion(task.getRegion.get) + + def apply(task: Task): AwsEndpointConfiguration = + { + new AwsEndpointConfiguration(task) } - else if (!task.getRegion.isPresent && task.getEndpoint.isPresent) { - val r: String = Try(new DefaultAwsRegionProviderChain().getRegion).getOrElse(Regions.DEFAULT_REGION.getName) - val e: String = task.getEndpoint.get - val ec = new EndpointConfiguration(e, r) - builder.setEndpointConfiguration(ec) +} + +class AwsEndpointConfiguration(task: Task) +{ + + def configureAwsClientBuilder[S <: AwsClientBuilder[S, T], T](builder: AwsClientBuilder[S, T]): Unit = + { + if (task.getRegion.isPresent && task.getEndpoint.isPresent) { + val ec = new EndpointConfiguration(task.getEndpoint.get, task.getRegion.get) + builder.setEndpointConfiguration(ec) + } + else if (task.getRegion.isPresent && !task.getEndpoint.isPresent) { + builder.setRegion(task.getRegion.get) + } + else if (!task.getRegion.isPresent && task.getEndpoint.isPresent) { + val r: String = Try(new 
DefaultAwsRegionProviderChain().getRegion).getOrElse(Regions.DEFAULT_REGION.getName) + val e: String = task.getEndpoint.get + val ec = new EndpointConfiguration(e, r) + builder.setEndpointConfiguration(ec) + } } - } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsS3Configuration.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsS3Configuration.scala index 2dd8b37..2e306f3 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/AwsS3Configuration.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/AwsS3Configuration.scala @@ -7,50 +7,58 @@ import com.amazonaws.services.s3.AmazonS3ClientBuilder import org.embulk.config.{Config, ConfigDefault} import org.embulk.output.s3_parquet.aws.AwsS3Configuration.Task + /* * These are advanced settings, so write no documentation. */ -object AwsS3Configuration { - trait Task { +object AwsS3Configuration +{ + trait Task + { - @Config("accelerate_mode_enabled") - @ConfigDefault("null") - def getAccelerateModeEnabled: Optional[Boolean] + @Config("accelerate_mode_enabled") + @ConfigDefault("null") + def getAccelerateModeEnabled: Optional[Boolean] - @Config("chunked_encoding_disabled") - @ConfigDefault("null") - def getChunkedEncodingDisabled: Optional[Boolean] + @Config("chunked_encoding_disabled") + @ConfigDefault("null") + def getChunkedEncodingDisabled: Optional[Boolean] - @Config("dualstack_enabled") - @ConfigDefault("null") - def getDualstackEnabled: Optional[Boolean] + @Config("dualstack_enabled") + @ConfigDefault("null") + def getDualstackEnabled: Optional[Boolean] - @Config("force_global_bucket_access_enabled") - @ConfigDefault("null") - def getForceGlobalBucketAccessEnabled: Optional[Boolean] + @Config("force_global_bucket_access_enabled") + @ConfigDefault("null") + def getForceGlobalBucketAccessEnabled: Optional[Boolean] - @Config("path_style_access_enabled") - @ConfigDefault("null") - def getPathStyleAccessEnabled: Optional[Boolean] + @Config("path_style_access_enabled") + @ConfigDefault("null") + def getPathStyleAccessEnabled: Optional[Boolean] - @Config("payload_signing_enabled") - @ConfigDefault("null") - def getPayloadSigningEnabled: Optional[Boolean] + @Config("payload_signing_enabled") + @ConfigDefault("null") + def getPayloadSigningEnabled: Optional[Boolean] - } + } - def apply(task: Task): AwsS3Configuration = new AwsS3Configuration(task) + def apply(task: Task): AwsS3Configuration = + { + new AwsS3Configuration(task) + } } -class AwsS3Configuration(task: Task) { +class AwsS3Configuration(task: Task) +{ - def configureAmazonS3ClientBuilder(builder: AmazonS3ClientBuilder): Unit = { - task.getAccelerateModeEnabled.ifPresent(v => builder.setAccelerateModeEnabled(v)) - task.getChunkedEncodingDisabled.ifPresent(v => builder.setChunkedEncodingDisabled(v)) - task.getDualstackEnabled.ifPresent(v => builder.setDualstackEnabled(v)) - task.getForceGlobalBucketAccessEnabled.ifPresent(v => builder.setForceGlobalBucketAccessEnabled(v)) - task.getPathStyleAccessEnabled.ifPresent(v => builder.setPathStyleAccessEnabled(v)) - task.getPayloadSigningEnabled.ifPresent(v => builder.setPayloadSigningEnabled(v)) - } + def configureAmazonS3ClientBuilder(builder: AmazonS3ClientBuilder): Unit = + { + task.getAccelerateModeEnabled.ifPresent(v => builder.setAccelerateModeEnabled(v)) + task.getChunkedEncodingDisabled.ifPresent(v => builder.setChunkedEncodingDisabled(v)) + task.getDualstackEnabled.ifPresent(v => builder.setDualstackEnabled(v)) + task.getForceGlobalBucketAccessEnabled.ifPresent(v => 
builder.setForceGlobalBucketAccessEnabled(v)) + task.getPathStyleAccessEnabled.ifPresent(v => builder.setPathStyleAccessEnabled(v)) + task.getPayloadSigningEnabled.ifPresent(v => builder.setPayloadSigningEnabled(v)) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/aws/HttpProxy.scala b/src/main/scala/org/embulk/output/s3_parquet/aws/HttpProxy.scala index 379aa33..4318538 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/aws/HttpProxy.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/aws/HttpProxy.scala @@ -7,50 +7,58 @@ import com.amazonaws.{ClientConfiguration, Protocol} import org.embulk.config.{Config, ConfigDefault, ConfigException} import org.embulk.output.s3_parquet.aws.HttpProxy.Task -object HttpProxy { - trait Task { +object HttpProxy +{ - @Config("host") - @ConfigDefault("null") - def getHost: Optional[String] + trait Task + { - @Config("port") - @ConfigDefault("null") - def getPort: Optional[Int] + @Config("host") + @ConfigDefault("null") + def getHost: Optional[String] - @Config("protocol") - @ConfigDefault("\"https\"") - def getProtocol: String + @Config("port") + @ConfigDefault("null") + def getPort: Optional[Int] - @Config("user") - @ConfigDefault("null") - def getUser: Optional[String] + @Config("protocol") + @ConfigDefault("\"https\"") + def getProtocol: String - @Config("password") - @ConfigDefault("null") - def getPassword: Optional[String] + @Config("user") + @ConfigDefault("null") + def getUser: Optional[String] - } + @Config("password") + @ConfigDefault("null") + def getPassword: Optional[String] - def apply(task: Task): HttpProxy = new HttpProxy(task) + } + + def apply(task: Task): HttpProxy = + { + new HttpProxy(task) + } } -class HttpProxy(task: Task) { +class HttpProxy(task: Task) +{ - def configureClientConfiguration(cc: ClientConfiguration): Unit = { - task.getHost.ifPresent(v => cc.setProxyHost(v)) - task.getPort.ifPresent(v => cc.setProxyPort(v)) + def configureClientConfiguration(cc: ClientConfiguration): Unit = + { + task.getHost.ifPresent(v => cc.setProxyHost(v)) + task.getPort.ifPresent(v => cc.setProxyPort(v)) - Protocol.values.find(p => p.name().equals(task.getProtocol)) match { - case Some(v) => - cc.setProtocol(v) - case None => - throw new ConfigException(s"'${task.getProtocol}' is unsupported: `protocol` must be one of [${Protocol.values.map(v => s"'$v'").mkString(", ")}].") - } + Protocol.values.find(p => p.name().equals(task.getProtocol)) match { + case Some(v) => + cc.setProtocol(v) + case None => + throw new ConfigException(s"'${task.getProtocol}' is unsupported: `protocol` must be one of [${Protocol.values.map(v => s"'$v'").mkString(", ")}].") + } - task.getUser.ifPresent(v => cc.setProxyUsername(v)) - task.getPassword.ifPresent(v => cc.setProxyPassword(v)) - } + task.getUser.ifPresent(v => cc.setProxyUsername(v)) + task.getPassword.ifPresent(v => cc.setProxyPassword(v)) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala index a6b48d0..31906d4 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/EmbulkMessageType.scala @@ -6,54 +6,74 @@ import org.apache.parquet.schema.{MessageType, OriginalType, PrimitiveType, Type import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.embulk.spi.{Column, ColumnVisitor, Schema} -object EmbulkMessageType { - def builder(): Builder = Builder() - - case 
class Builder(name: String = "embulk", - schema: Schema = Schema.builder().build()) { - - def withName(name: String): Builder = Builder(name = name, schema = schema) - - def withSchema(schema: Schema): Builder = Builder(name = name, schema = schema) - - def build(): MessageType = { - val builder: ImmutableList.Builder[Type] = ImmutableList.builder[Type]() - schema.visitColumns(EmbulkMessageTypeColumnVisitor(builder)) - new MessageType("embulk", builder.build()) +object EmbulkMessageType +{ + def builder(): Builder = + { + Builder() } - } + case class Builder(name: String = "embulk", + schema: Schema = Schema.builder().build()) + { - private case class EmbulkMessageTypeColumnVisitor(builder: ImmutableList.Builder[Type]) - extends ColumnVisitor { + def withName(name: String): Builder = + { + Builder(name = name, schema = schema) + } - override def booleanColumn(column: Column): Unit = { - builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BOOLEAN, column.getName)) - } + def withSchema(schema: Schema): Builder = + { + Builder(name = name, schema = schema) + } - override def longColumn(column: Column): Unit = { - builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.INT64, column.getName)) - } - - override def doubleColumn(column: Column): Unit = { - builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.DOUBLE, column.getName)) - } + def build(): MessageType = + { + val builder: ImmutableList.Builder[Type] = ImmutableList.builder[Type]() + schema.visitColumns(EmbulkMessageTypeColumnVisitor(builder)) + new MessageType("embulk", builder.build()) - override def stringColumn(column: Column): Unit = { - builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, column.getName, OriginalType.UTF8)) - } + } - override def timestampColumn(column: Column): Unit = { - // TODO: Support OriginalType.TIME* ? - builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, column.getName, OriginalType.UTF8)) } - override def jsonColumn(column: Column): Unit = { - // TODO: does this work? - builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, column.getName, OriginalType.UTF8)) + private case class EmbulkMessageTypeColumnVisitor(builder: ImmutableList.Builder[Type]) + extends ColumnVisitor + { + + override def booleanColumn(column: Column): Unit = + { + builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BOOLEAN, column.getName)) + } + + override def longColumn(column: Column): Unit = + { + builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.INT64, column.getName)) + } + + override def doubleColumn(column: Column): Unit = + { + builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.DOUBLE, column.getName)) + } + + override def stringColumn(column: Column): Unit = + { + builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, column.getName, OriginalType.UTF8)) + } + + override def timestampColumn(column: Column): Unit = + { + // TODO: Support OriginalType.TIME* ? + builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, column.getName, OriginalType.UTF8)) + } + + override def jsonColumn(column: Column): Unit = + { + // TODO: does this work? 
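// For orientation (an illustrative sketch with made-up column names): every Embulk column is
// declared as an OPTIONAL Parquet field; boolean -> BOOLEAN, long -> INT64, double -> DOUBLE,
// while string, timestamp and json all become BINARY annotated with OriginalType.UTF8, i.e.
// they are stored as text. For a hypothetical schema (id: long, name: string,
// created_at: timestamp, payload: json) the resulting MessageType would print roughly as:
//
//   message embulk {
//     optional int64 id;
//     optional binary name (UTF8);
//     optional binary created_at (UTF8);
//     optional binary payload (UTF8);
//   }
//
// This is consistent with CatalogRegistrator.convertEmbulkType2GlueType, which by default
// registers timestamp and json columns as Glue `string` columns.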
+ builder.add(new PrimitiveType(Type.Repetition.OPTIONAL, PrimitiveTypeName.BINARY, column.getName, OriginalType.UTF8)) + } } - } } \ No newline at end of file diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala index e2f7b83..b140ad8 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriteSupport.scala @@ -9,25 +9,32 @@ import org.apache.parquet.schema.MessageType import org.embulk.spi.{PageReader, Schema} import org.embulk.spi.time.TimestampFormatter -import scala.collection.JavaConverters._ +import scala.jdk.CollectionConverters._ -private[parquet] case class ParquetFileWriteSupport(schema: Schema, - timestampFormatters: Seq[TimestampFormatter]) - extends WriteSupport[PageReader] { - - private var currentParquetFileWriter: ParquetFileWriter = _ - - override def init(configuration: Configuration): WriteContext = { - val messageType: MessageType = EmbulkMessageType.builder() - .withSchema(schema) - .build() - val metadata: Map[String, String] = Map.empty // NOTE: When is this used? - new WriteContext(messageType, metadata.asJava) - } - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { - currentParquetFileWriter = ParquetFileWriter(recordConsumer, schema, timestampFormatters) - } - - override def write(record: PageReader): Unit = currentParquetFileWriter.write(record) +private[parquet] case class ParquetFileWriteSupport(schema: Schema, + timestampFormatters: Seq[TimestampFormatter]) + extends WriteSupport[PageReader] +{ + + private var currentParquetFileWriter: ParquetFileWriter = _ + + override def init(configuration: Configuration): WriteContext = + { + val messageType: MessageType = EmbulkMessageType.builder() + .withSchema(schema) + .build() + val metadata: Map[String, String] = Map.empty // NOTE: When is this used? 
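// The extra metadata map handed to WriteContext is written into the Parquet file footer as
// application-defined key/value metadata (visible via e.g. `parquet-tools meta`); this plugin
// currently has nothing to record there, so only the schema derived above is passed along.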
+ new WriteContext(messageType, metadata.asJava) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = + { + currentParquetFileWriter = ParquetFileWriter(recordConsumer, schema, timestampFormatters) + } + + override def write(record: PageReader): Unit = + { + currentParquetFileWriter.write(record) + } } diff --git a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala index 772c0a3..0d0dd26 100644 --- a/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala +++ b/src/main/scala/org/embulk/output/s3_parquet/parquet/ParquetFileWriter.scala @@ -9,117 +9,151 @@ import org.apache.parquet.io.api.{Binary, RecordConsumer} import org.embulk.spi.{Column, ColumnVisitor, PageReader, Schema} import org.embulk.spi.time.TimestampFormatter -object ParquetFileWriter { - case class Builder(path: Path = null, - schema: Schema = null, - timestampFormatters: Seq[TimestampFormatter] = null) - extends ParquetWriter.Builder[PageReader, Builder](path) { +object ParquetFileWriter +{ - def withPath(path: Path): Builder = copy(path = path) + case class Builder(path: Path = null, + schema: Schema = null, + timestampFormatters: Seq[TimestampFormatter] = null) + extends ParquetWriter.Builder[PageReader, Builder](path) + { - def withPath(pathString: String): Builder = copy(path = new Path(pathString)) + def withPath(path: Path): Builder = + { + copy(path = path) + } - def withSchema(schema: Schema): Builder = copy(schema = schema) + def withPath(pathString: String): Builder = + { + copy(path = new Path(pathString)) + } - def withTimestampFormatters(timestampFormatters: Seq[TimestampFormatter]): Builder = copy(timestampFormatters = timestampFormatters) + def withSchema(schema: Schema): Builder = + { + copy(schema = schema) + } - override def self(): Builder = this + def withTimestampFormatters(timestampFormatters: Seq[TimestampFormatter]): Builder = + { + copy(timestampFormatters = timestampFormatters) + } - override def getWriteSupport(conf: Configuration): WriteSupport[PageReader] = { - ParquetFileWriteSupport(schema, timestampFormatters) + override def self(): Builder = + { + this + } + + override def getWriteSupport(conf: Configuration): WriteSupport[PageReader] = + { + ParquetFileWriteSupport(schema, timestampFormatters) + } } - } - def builder(): Builder = Builder() + def builder(): Builder = + { + Builder() + } } private[parquet] case class ParquetFileWriter(recordConsumer: RecordConsumer, - schema: Schema, - timestampFormatters: Seq[TimestampFormatter]) { - - def write(record: PageReader): Unit = { - recordConsumer.startMessage() - writeRecord(record) - recordConsumer.endMessage() - } - - private def writeRecord(record: PageReader): Unit = { - - schema.visitColumns(new ColumnVisitor() { - - override def booleanColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - recordConsumer.addBoolean(record.getBoolean(column)) - }) - }) - } - - override def longColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - recordConsumer.addLong(record.getLong(column)) - }) - }) - } + schema: Schema, + timestampFormatters: Seq[TimestampFormatter]) +{ + + def write(record: PageReader): Unit = + { + recordConsumer.startMessage() + writeRecord(record) + recordConsumer.endMessage() + } - override def doubleColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - 
recordConsumer.addDouble(record.getDouble(column)) - }) - }) - } + private def writeRecord(record: PageReader): Unit = + { + + schema.visitColumns(new ColumnVisitor() + { + + override def booleanColumn(column: Column): Unit = + { + nullOr(column, { + withWriteFieldContext(column, { + recordConsumer.addBoolean(record.getBoolean(column)) + }) + }) + } + + override def longColumn(column: Column): Unit = + { + nullOr(column, { + withWriteFieldContext(column, { + recordConsumer.addLong(record.getLong(column)) + }) + }) + } + + override def doubleColumn(column: Column): Unit = + { + nullOr(column, { + withWriteFieldContext(column, { + recordConsumer.addDouble(record.getDouble(column)) + }) + }) + } + + override def stringColumn(column: Column): Unit = + { + nullOr(column, { + withWriteFieldContext(column, { + val bin = Binary.fromString(record.getString(column)) + recordConsumer.addBinary(bin) + }) + }) + } + + override def timestampColumn(column: Column): Unit = + { + nullOr(column, { + withWriteFieldContext(column, { + // TODO: is a correct way to convert for parquet ? + val t = record.getTimestamp(column) + val ft = timestampFormatters(column.getIndex).format(t) + val bin = Binary.fromString(ft) + recordConsumer.addBinary(bin) + }) + }) + } + + override def jsonColumn(column: Column): Unit = + { + nullOr(column, { + withWriteFieldContext(column, { + // TODO: is a correct way to convert for parquet ? + val msgPack = record.getJson(column) + val bin = Binary.fromString(msgPack.toJson) + recordConsumer.addBinary(bin) + }) + }) + } + + private def nullOr(column: Column, + f: => Unit): Unit = + { + if (!record.isNull(column)) f + } + + private def withWriteFieldContext(column: Column, + f: => Unit): Unit = + { + recordConsumer.startField(column.getName, column.getIndex) + f + recordConsumer.endField(column.getName, column.getIndex) + } - override def stringColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - val bin = Binary.fromString(record.getString(column)) - recordConsumer.addBinary(bin) - }) }) - } - override def timestampColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - // TODO: is a correct way to convert for parquet ? - val t = record.getTimestamp(column) - val ft = timestampFormatters(column.getIndex).format(t) - val bin = Binary.fromString(ft) - recordConsumer.addBinary(bin) - }) - }) - } - - override def jsonColumn(column: Column): Unit = { - nullOr(column, { - withWriteFieldContext(column, { - // TODO: is a correct way to convert for parquet ? 
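// For context on the TODO above: both the previous and the reformatted code serialize
// timestamps through the per-column TimestampFormatter and json values through msgPack.toJson,
// then write the resulting text as a UTF8 Binary, which matches the BINARY/UTF8 fields
// declared in EmbulkMessageType. Readers therefore see these columns as plain strings.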
- val msgPack = record.getJson(column) - val bin = Binary.fromString(msgPack.toJson) - recordConsumer.addBinary(bin) - }) - }) - } - - private def nullOr(column: Column, - f: => Unit): Unit = { - if (!record.isNull(column)) f - } - - private def withWriteFieldContext(column: Column, - f: => Unit): Unit = { - recordConsumer.startField(column.getName, column.getIndex) - f - recordConsumer.endField(column.getName, column.getIndex) - } - - }) - - } + } } \ No newline at end of file diff --git a/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala b/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala index 0f6bfb2..ee7f5c8 100644 --- a/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala +++ b/src/test/scala/org/embulk/output/s3_parquet/TestS3ParquetOutputPlugin.scala @@ -1,8 +1,8 @@ package org.embulk.output.s3_parquet -import java.io.{File, PrintWriter} -import java.nio.file.{FileSystems, Path} +import java.io.File +import java.nio.file.FileSystems import cloud.localstack.{DockerTestUtils, Localstack, TestUtils} import cloud.localstack.docker.LocalstackDocker @@ -21,120 +21,129 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, DiagrammedAssertions, F import org.scalatest.junit.JUnitRunner import scala.annotation.meta.getter -import scala.collection.JavaConverters._ +import scala.jdk.CollectionConverters._ + @RunWith(classOf[JUnitRunner]) class TestS3ParquetOutputPlugin - extends FunSuite - with BeforeAndAfter - with BeforeAndAfterAll - with DiagrammedAssertions { - - val RESOURCE_NAME_PREFIX: String = "org/embulk/output/s3_parquet/" - val BUCKET_NAME: String = "my-bucket" - - val LOCALSTACK_DOCKER: LocalstackDocker = LocalstackDocker.INSTANCE - - override protected def beforeAll(): Unit = { - Localstack.teardownInfrastructure() - LOCALSTACK_DOCKER.startup(LocalstackDockerConfiguration.DEFAULT) - super.beforeAll() - } - - override protected def afterAll(): Unit = { - LOCALSTACK_DOCKER.stop() - super.afterAll() - } - - @(Rule@getter) - val embulk: TestingEmbulk = TestingEmbulk.builder() - .registerPlugin(classOf[OutputPlugin], "s3_parquet", classOf[S3ParquetOutputPlugin]) - .build() - - before { - DockerTestUtils.getClientS3.createBucket(BUCKET_NAME) - } - - def defaultOutConfig(): ConfigSource = { - embulk.newConfig() - .set("type", "s3_parquet") - .set("endpoint", "http://localhost:4572") // See https://github.com/localstack/localstack#overview - .set("bucket", BUCKET_NAME) - .set("path_prefix", "path/to/p") - .set("auth_method", "basic") - .set("access_key_id", TestUtils.TEST_ACCESS_KEY) - .set("secret_access_key", TestUtils.TEST_SECRET_KEY) - .set("path_style_access_enabled", true) - .set("default_timezone", "Asia/Tokyo") - } - - - test("first test") { - val inPath = toPath("in1.csv") - val outConfig = defaultOutConfig() - - val result: TestingEmbulk.RunResult = embulk.runOutput(outConfig, inPath) - - - val outRecords: Seq[Map[String, String]] = result.getOutputTaskReports.asScala.map { tr => - val b = tr.get(classOf[String], "bucket") - val k = tr.get(classOf[String], "key") - readParquetFile(b, k) - }.foldLeft(Seq[Map[String, String]]()) { (merged, - records) => - merged ++ records + extends FunSuite + with BeforeAndAfter + with BeforeAndAfterAll + with DiagrammedAssertions +{ + + val RESOURCE_NAME_PREFIX: String = "org/embulk/output/s3_parquet/" + val BUCKET_NAME: String = "my-bucket" + + val LOCALSTACK_DOCKER: LocalstackDocker = LocalstackDocker.INSTANCE + + override protected def beforeAll(): Unit = + { + 
Localstack.teardownInfrastructure() + LOCALSTACK_DOCKER.startup(LocalstackDockerConfiguration.DEFAULT) + super.beforeAll() } - val inRecords: Seq[Seq[String]] = EmbulkTests.readResource(RESOURCE_NAME_PREFIX + "out1.tsv") - .stripLineEnd - .split("\n") - .map(record => record.split("\t").toSeq) + override protected def afterAll(): Unit = + { + LOCALSTACK_DOCKER.stop() + super.afterAll() + } - inRecords.zipWithIndex.foreach { - case (record, recordIndex) => - 0.to(5).foreach { columnIndex => - val columnName = s"c$columnIndex" - val inData: String = inRecords(recordIndex)(columnIndex) - val outData: String = outRecords(recordIndex).getOrElse(columnName, "") + @(Rule@getter) + val embulk: TestingEmbulk = TestingEmbulk.builder() + .registerPlugin(classOf[OutputPlugin], "s3_parquet", classOf[S3ParquetOutputPlugin]) + .build() - assert(outData === inData, s"record: $recordIndex, column: $columnName") - } + before { + DockerTestUtils.getClientS3.createBucket(BUCKET_NAME) } - } - - def readParquetFile(bucket: String, - key: String): Seq[Map[String, String]] = { - val xfer = TransferManagerBuilder.standard() - .withS3Client(DockerTestUtils.getClientS3) - .build() - val createdParquetFile = embulk.createTempFile("in") - try xfer.download(bucket, key, createdParquetFile.toFile).waitForCompletion() - finally xfer.shutdownNow() - - val reader: ParquetReader[SimpleRecord] = ParquetReader - .builder(new SimpleReadSupport(), new HadoopPath(createdParquetFile.toString)) - .build() - - def read(reader: ParquetReader[SimpleRecord], - records: Seq[Map[String, String]] = Seq()): Seq[Map[String, String]] = { - val simpleRecord: SimpleRecord = reader.read() - if (simpleRecord != null) { - val r: Map[String, String] = simpleRecord.getValues.asScala.map(v => v.getName -> v.getValue.toString).toMap - return read(reader, records :+ r) - } - records + + def defaultOutConfig(): ConfigSource = + { + embulk.newConfig() + .set("type", "s3_parquet") + .set("endpoint", "http://localhost:4572") // See https://github.com/localstack/localstack#overview + .set("bucket", BUCKET_NAME) + .set("path_prefix", "path/to/p") + .set("auth_method", "basic") + .set("access_key_id", TestUtils.TEST_ACCESS_KEY) + .set("secret_access_key", TestUtils.TEST_SECRET_KEY) + .set("path_style_access_enabled", true) + .set("default_timezone", "Asia/Tokyo") } - try read(reader) - finally { - reader.close() + test("first test") { + val inPath = toPath("in1.csv") + val outConfig = defaultOutConfig() + + val result: TestingEmbulk.RunResult = embulk.runOutput(outConfig, inPath) + + + val outRecords: Seq[Map[String, String]] = result.getOutputTaskReports.asScala.map { tr => + val b = tr.get(classOf[String], "bucket") + val k = tr.get(classOf[String], "key") + readParquetFile(b, k) + }.foldLeft(Seq[Map[String, String]]()) { (merged, + records) => + merged ++ records + } + + val inRecords: Seq[Seq[String]] = EmbulkTests.readResource(RESOURCE_NAME_PREFIX + "out1.tsv") + .stripLineEnd + .split("\n") + .map(record => record.split("\t").toSeq) + .toSeq + + inRecords.zipWithIndex.foreach { + case (record, recordIndex) => + 0.to(5).foreach { columnIndex => + val columnName = s"c$columnIndex" + val inData: String = inRecords(recordIndex)(columnIndex) + val outData: String = outRecords(recordIndex).getOrElse(columnName, "") + + assert(outData === inData, s"record: $recordIndex, column: $columnName") + } + } + } + + def readParquetFile(bucket: String, + key: String): Seq[Map[String, String]] = + { + val xfer = TransferManagerBuilder.standard() + 
.withS3Client(DockerTestUtils.getClientS3) + .build() + val createdParquetFile = embulk.createTempFile("in") + try xfer.download(bucket, key, createdParquetFile.toFile).waitForCompletion() + finally xfer.shutdownNow() + + val reader: ParquetReader[SimpleRecord] = ParquetReader + .builder(new SimpleReadSupport(), new HadoopPath(createdParquetFile.toString)) + .build() + + def read(reader: ParquetReader[SimpleRecord], + records: Seq[Map[String, String]] = Seq()): Seq[Map[String, String]] = + { + val simpleRecord: SimpleRecord = reader.read() + if (simpleRecord != null) { + val r: Map[String, String] = simpleRecord.getValues.asScala.map(v => v.getName -> v.getValue.toString).toMap + return read(reader, records :+ r) + } + records + } + + try read(reader) + finally { + reader.close() + + } } - } - private def toPath(fileName: String) = { - val url = Resources.getResource(RESOURCE_NAME_PREFIX + fileName) - FileSystems.getDefault.getPath(new File(url.toURI).getAbsolutePath) - } + private def toPath(fileName: String) = + { + val url = Resources.getResource(RESOURCE_NAME_PREFIX + fileName) + FileSystems.getDefault.getPath(new File(url.toURI).getAbsolutePath) + } }
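
Usage sketch: the ParquetFileWriter changes above are a brace-style refactoring only, so the builder API (withPath, withSchema, withTimestampFormatters, and the inherited build()) is unchanged. The snippet below is a minimal sketch of the intended call pattern; `schema`, `timestampFormatters`, and `pageReader` are placeholders assumed to come from the surrounding Embulk transaction, and the output path is likewise a placeholder.

    import org.apache.parquet.hadoop.ParquetWriter
    import org.embulk.output.s3_parquet.parquet.ParquetFileWriter
    import org.embulk.spi.PageReader

    // Assemble a writer via the Builder defined in ParquetFileWriter.scala.
    // `schema`, `timestampFormatters`, and `pageReader` are assumed to be
    // provided by the Embulk task context; the path is a placeholder.
    val writer: ParquetWriter[PageReader] =
        ParquetFileWriter.builder()
            .withPath("/tmp/example.parquet") // placeholder output path
            .withSchema(schema)
            .withTimestampFormatters(timestampFormatters)
            .build()

    try {
        // Each record of the current page is handed to the write support,
        // which delegates to ParquetFileWriter.write(record) as shown above.
        while (pageReader.nextRecord()) {
            writer.write(pageReader)
        }
    }
    finally {
        writer.close()
    }

In the plugin itself this wiring happens inside the output transaction rather than in user code; the sketch only illustrates how the builder and write support introduced in this diff fit together.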