diff --git a/src/main/java/org/embulk/output/SnowflakeOutputPlugin.java b/src/main/java/org/embulk/output/SnowflakeOutputPlugin.java index a7e201e..bddc9cc 100644 --- a/src/main/java/org/embulk/output/SnowflakeOutputPlugin.java +++ b/src/main/java/org/embulk/output/SnowflakeOutputPlugin.java @@ -6,6 +6,7 @@ import java.util.*; import org.embulk.config.ConfigDiff; import org.embulk.config.ConfigException; +import org.embulk.config.ConfigSource; import org.embulk.config.TaskSource; import org.embulk.output.jdbc.*; import org.embulk.output.snowflake.PrivateKeyReader; @@ -22,8 +23,6 @@ import org.embulk.util.config.ConfigDefault; public class SnowflakeOutputPlugin extends AbstractJdbcOutputPlugin { - private StageIdentifier stageIdentifier; - public interface SnowflakePluginTask extends PluginTask { @Config("driver_path") @ConfigDefault("null") @@ -65,6 +64,10 @@ public interface SnowflakePluginTask extends PluginTask { @Config("empty_field_as_null") @ConfigDefault("true") public boolean getEmtpyFieldAsNull(); + + @Config("delete_stage_on_error") + @ConfigDefault("false") + public boolean getDeleteStageOnError(); } @Override @@ -130,25 +133,39 @@ protected JdbcOutputConnector getConnector(PluginTask task, boolean retryableMet } @Override - public ConfigDiff resume( - TaskSource taskSource, Schema schema, int taskCount, OutputPlugin.Control control) { - throw new UnsupportedOperationException("snowflake output plugin does not support resuming"); - } - - @Override - protected void doCommit(JdbcOutputConnection con, PluginTask task, int taskCount) - throws SQLException { - super.doCommit(con, task, taskCount); - SnowflakeOutputConnection snowflakeCon = (SnowflakeOutputConnection) con; - + public ConfigDiff transaction( + ConfigSource config, Schema schema, int taskCount, OutputPlugin.Control control) { + PluginTask task = CONFIG_MAPPER.map(config, this.getTaskClass()); SnowflakePluginTask t = (SnowflakePluginTask) task; - if (this.stageIdentifier == null) { - this.stageIdentifier = StageIdentifierHolder.getStageIdentifier(t); + StageIdentifier stageIdentifier = StageIdentifierHolder.getStageIdentifier(t); + ConfigDiff configDiff; + SnowflakeOutputConnection snowflakeCon = null; + + try { + snowflakeCon = (SnowflakeOutputConnection) getConnector(task, true).connect(true); + snowflakeCon.runCreateStage(stageIdentifier); + configDiff = super.transaction(config, schema, taskCount, control); + if (t.getDeleteStage()) { + snowflakeCon.runDropStage(stageIdentifier); + } + } catch (Exception e) { + if (t.getDeleteStage() && t.getDeleteStageOnError()) { + try { + snowflakeCon.runDropStage(stageIdentifier); + } catch (SQLException ex) { + throw new RuntimeException(ex); + } + } + throw new RuntimeException(e); } - if (t.getDeleteStage()) { - snowflakeCon.runDropStage(this.stageIdentifier); - } + return configDiff; + } + + @Override + public ConfigDiff resume( + TaskSource taskSource, Schema schema, int taskCount, OutputPlugin.Control control) { + throw new UnsupportedOperationException("snowflake output plugin does not support resuming"); } @Override @@ -165,20 +182,11 @@ protected BatchInsert newBatchInsert(PluginTask task, Optional merg throw new UnsupportedOperationException( "Snowflake output plugin doesn't support 'merge_direct' mode."); } - - SnowflakePluginTask t = (SnowflakePluginTask) task; - // TODO: put some where executes once - if (this.stageIdentifier == null) { - SnowflakeOutputConnection snowflakeCon = - (SnowflakeOutputConnection) getConnector(task, true).connect(true); - this.stageIdentifier = StageIdentifierHolder.getStageIdentifier(t); - snowflakeCon.runCreateStage(this.stageIdentifier); - } SnowflakePluginTask pluginTask = (SnowflakePluginTask) task; return new SnowflakeCopyBatchInsert( getConnector(task, true), - this.stageIdentifier, + StageIdentifierHolder.getStageIdentifier(pluginTask), false, pluginTask.getMaxUploadRetries(), pluginTask.getEmtpyFieldAsNull()); diff --git a/src/main/java/org/embulk/output/snowflake/SnowflakeCopyBatchInsert.java b/src/main/java/org/embulk/output/snowflake/SnowflakeCopyBatchInsert.java index 18f8155..a1ebd8c 100644 --- a/src/main/java/org/embulk/output/snowflake/SnowflakeCopyBatchInsert.java +++ b/src/main/java/org/embulk/output/snowflake/SnowflakeCopyBatchInsert.java @@ -61,7 +61,6 @@ public SnowflakeCopyBatchInsert( @Override public void prepare(TableIdentifier loadTable, JdbcSchema insertSchema) throws SQLException { this.connection = (SnowflakeOutputConnection) connector.connect(true); - this.connection.runCreateStage(stageIdentifier); this.tableIdentifier = loadTable; } diff --git a/src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java b/src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java index 5bc161b..1168629 100644 --- a/src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java +++ b/src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java @@ -11,8 +11,12 @@ import org.embulk.output.jdbc.JdbcSchema; import org.embulk.output.jdbc.MergeConfig; import org.embulk.output.jdbc.TableIdentifier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class SnowflakeOutputConnection extends JdbcOutputConnection { + private final Logger logger = LoggerFactory.getLogger(SnowflakeOutputConnection.class); + public SnowflakeOutputConnection(Connection connection) throws SQLException { super(connection, null); } @@ -32,11 +36,13 @@ public void runCopy( public void runCreateStage(StageIdentifier stageIdentifier) throws SQLException { String sql = buildCreateStageSQL(stageIdentifier); runUpdate(sql); + logger.info("SQL: {}", sql); } public void runDropStage(StageIdentifier stageIdentifier) throws SQLException { String sql = buildDropStageSQL(stageIdentifier); runUpdate(sql); + logger.info("SQL: {}", sql); } public void runUploadFile(