diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/Directive.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/Directive.java
index 2a90dbb78..46f53b8b8 100644
--- a/wrangler-api/src/main/java/io/cdap/wrangler/api/Directive.java
+++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/Directive.java
@@ -141,35 +141,11 @@ default List getCountMetrics() {
    * given a field name and value (Java object)
    * @param output output data generated by directive after its execution
    * @return {@link Schema} of the output data
-   * @throws SchemaGenerationException when there is a problem with generating schema for a Java object
-   * @implNote By default, the implementation has the following steps:
-   * <ol>
-   *   <li>Get map of renamed fields and map of type changed fields</li>
-   *   <li>Iterate over list of fields in each output row</li>
-   *   <li>If output field is renamed, get original name.</li>
-   *   <li>Check if it is in map of type changed fields or if it exists in inputSchema, then use that schema</li>
-   *   <li>If it is a new field, generate the schema using provided {@link SchemaFieldGenerator}</li>
-   *   <li>Finally, add the created field to set of output fields</li>
-   * </ol>
-   */
+   * @implNote By default, the implementation returns {@code null} (no-op); the caller then generates the schema
+   */
   @Override
-  default Schema getOutputSchema(Schema inputSchema, SchemaFieldGenerator generator, List<Row> output)
-    throws SchemaGenerationException {
-    Set<Schema.Field> outputFields = new HashSet<>(); // Use a set to avoid duplicate fields
-    Map<String, String> renamedFields = getRenamedFields(inputSchema);
-    Map<String, Schema> typeChangedFields = getTypeChangedFields(inputSchema);
-    for (Row row : output) {
-      for (Pair<String, Object> field : row.getFields()) {
-        String originalName = renamedFields.getOrDefault(field.getFirst(), field.getFirst());
-        // If field's datatype was changed or field's already exists in inputSchema, use that schema
-        Schema schema = typeChangedFields.containsKey(originalName) ? typeChangedFields.get(originalName)
-          : (inputSchema.getField(originalName) != null ? inputSchema.getField(originalName).getSchema() : null);
-        // If schema exists, use it, otherwise generate a new one
-        Schema.Field outputField = (schema != null) ? Schema.Field.of(field.getFirst(), schema)
-          : generator.generateSchemaField(field.getFirst(), field.getSecond());
-        outputFields.add(outputField);
-      }
-    }
-    return Schema.recordOf(outputFields);
+  default Schema getOutputSchema(Schema inputSchema, SchemaFieldGenerator generator, List<Row> output) {
+    // no op
+    return null;
   }
 }
diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java
index 8421a79d7..6e6e6a265 100644
--- a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java
+++ b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java
@@ -24,12 +24,14 @@
 import io.cdap.wrangler.api.ErrorRowException;
 import io.cdap.wrangler.api.Executor;
 import io.cdap.wrangler.api.ExecutorContext;
+import io.cdap.wrangler.api.Pair;
 import io.cdap.wrangler.api.RecipeException;
 import io.cdap.wrangler.api.RecipeParser;
 import io.cdap.wrangler.api.RecipePipeline;
 import io.cdap.wrangler.api.ReportErrorAndProceed;
 import io.cdap.wrangler.api.Row;
 import io.cdap.wrangler.api.TransientVariableScope;
+import io.cdap.wrangler.api.schema.SchemaFieldGenerator;
 import io.cdap.wrangler.api.schema.SchemaGenerationException;
 import io.cdap.wrangler.utils.RecordConvertor;
 import io.cdap.wrangler.utils.RecordConvertorException;
@@ -40,6 +42,7 @@
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
 
@@ -49,9 +52,11 @@
 public final class RecipePipelineExecutor implements RecipePipeline {
 
   private static final Logger LOG = LoggerFactory.getLogger(RecipePipelineExecutor.class);
+  private static final String TRANSIENT_STORE_OUTPUT_SCHEMA_KEY = "ws_schema_output";
 
   private final ErrorRecordCollector collector = new ErrorRecordCollector();
   private final RecordConvertor convertor = new RecordConvertor();
+  private final SchemaFieldGenerator schemaFieldGenerator = new SchemaConverter();
   private final RecipeParser recipeParser;
   private final ExecutorContext context;
   private List directives;
@@ -116,8 +121,8 @@ public List execute(List rows) throws RecipeException {
           context.getTransientStore().reset(TransientVariableScope.LOCAL);
         }
 
-        // Get input schema from transientStore
-        Schema schema = context.getEnvironment() != null &&
+        // Initialize outputSchema with input schema from TransientStore if running in service env (design-time)
+        Schema outputSchema = context.getEnvironment() != null &&
           context.getEnvironment().equals(ExecutorContext.Environment.SERVICE) ?
           context.getTransientStore().get("ws_schema") : null;
 
@@ -131,8 +136,8 @@ public List execute(List rows) throws RecipeException {
               if (cumulativeRows.size() < 1) {
                 break;
               }
-              if (schema != null) {
-                schema = directive.getOutputSchema(schema, new SchemaConverter(), cumulativeRows);
+              if (outputSchema != null) {
+                outputSchema = getOrGenerateOutputSchema(outputSchema, cumulativeRows, directive);
               }
             } catch (ReportErrorAndProceed e) {
               messages.add(String.format("%s (ecode: %d)", e.getMessage(), e.getCode()));
@@ -144,10 +149,10 @@ public List execute(List rows) throws RecipeException {
               throw new RecipeException("Problem generating the schema for field: " + e.getMessage(), e);
             }
           }
-          if (schema != null) {
-            Schema previousUpdatedSchema = context.getTransientStore().get("ws_schema_updated");
-            context.getTransientStore().set(TransientVariableScope.GLOBAL, "ws_schema_updated",
-              getSchemaUnion(previousUpdatedSchema, schema));
+          if (outputSchema != null) {
+            Schema previousOutputSchema = context.getTransientStore().get(TRANSIENT_STORE_OUTPUT_SCHEMA_KEY);
+            context.getTransientStore().set(TransientVariableScope.GLOBAL, TRANSIENT_STORE_OUTPUT_SCHEMA_KEY,
+              getSchemaUnion(previousOutputSchema, outputSchema));
           }
           results.addAll(cumulativeRows);
         } catch (ErrorRowException e) {
@@ -186,4 +191,46 @@ private Schema getSchemaUnion(Schema first, Schema second) {
     firstFields.addAll(second.getFields());
     return Schema.recordOf(firstFields);
   }
+
+  /**
+   * Returns the directive's own output schema when it provides one; otherwise generates it using these steps:
+   * <ol>
+   *   <li>Get map of renamed fields and map of type changed fields returned by directive</li>
+   *   <li>Iterate over list of fields in each output row</li>
+   *   <li>If output field is renamed, get original name.</li>
+   *   <li>Check if it is in map of type changed fields or if it exists in inputSchema, then use that schema</li>
+   *   <li>If it is a new field, generate the schema using provided {@link SchemaFieldGenerator}</li>
+   *   <li>Finally, add the created field to set of output fields</li>
+   * </ol>
+   * @param inputSchema input schema before applying this directive's transformation
+   * @param output output rows after applying the directive's transformation
+   * @param directive directive to get/generate output schema for
+   * @return {@link Schema} output after applying specified directive
+   * @throws SchemaGenerationException if there is an issue with generating {@link Schema} for a field
+   */
+  private Schema getOrGenerateOutputSchema(Schema inputSchema, List<Row> output,
+                                           Executor<List<Row>, List<Row>> directive)
+    throws SchemaGenerationException {
+    Schema directiveSchema = directive.getOutputSchema(inputSchema, schemaFieldGenerator, output);
+    if (directiveSchema != null) {
+      return directiveSchema;
+    }
+    // Schema generation
+    Set<Schema.Field> outputFields = new HashSet<>(); // Use a set to avoid duplicate fields
+    Map<String, String> renamedFields = directive.getRenamedFields(inputSchema);
+    Map<String, Schema> typeChangedFields = directive.getTypeChangedFields(inputSchema);
+    for (Row row : output) {
+      for (Pair<String, Object> field : row.getFields()) {
+        String originalName = renamedFields.getOrDefault(field.getFirst(), field.getFirst());
+        // If field's datatype was changed or field's already exists in inputSchema, use that schema
+        Schema schema = typeChangedFields.containsKey(originalName) ? typeChangedFields.get(originalName)
+          : (inputSchema.getField(originalName) != null ? inputSchema.getField(originalName).getSchema() : null);
+        // If schema exists, use it, otherwise generate a new one
+        Schema.Field outputField = (schema != null) ? Schema.Field.of(field.getFirst(), schema)
+          : schemaFieldGenerator.generateSchemaField(field.getFirst(), field.getSecond());
+        outputFields.add(outputField);
+      }
+    }
+    return Schema.recordOf(outputFields);
+  }
 }