From 3ae493bbf325b189de9f0669a7d3d96260e5bbb2 Mon Sep 17 00:00:00 2001 From: vanathig Date: Fri, 14 Jul 2023 10:48:36 +0530 Subject: [PATCH 1/6] Add basic schema handling to Wrangler --- .../java/io/cdap/wrangler/api/Executor.java | 17 +++- .../executor/RecipePipelineExecutor.java | 78 +++++++++++++++++++ .../wrangler/utils/TransientStoreKeys.java | 29 +++++++ .../cdap/wrangler/TestingPipelineContext.java | 4 +- .../java/io/cdap/wrangler/TestingRig.java | 4 +- .../executor/RecipePipelineExecutorTest.java | 70 +++++++++++++++++ .../directive/AbstractDirectiveHandler.java | 52 ++++++++----- .../service/directive/WorkspaceHandler.java | 22 ++++-- 8 files changed, 246 insertions(+), 30 deletions(-) create mode 100644 wrangler-core/src/main/java/io/cdap/wrangler/utils/TransientStoreKeys.java diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java index abc726295..c90374c39 100644 --- a/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java @@ -16,9 +16,11 @@ package io.cdap.wrangler.api; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.wrangler.api.annotations.PublicEvolving; import java.io.Serializable; +import javax.annotation.Nullable; /** * A interface defining the wrangle Executor in the wrangling {@link RecipePipeline}. @@ -80,5 +82,18 @@ O execute(I rows, ExecutorContext context) * correct at this phase of invocation. */ void destroy(); -} + /** + * This method is used to get the updated schema of the data after the directive's transformation has been applied. + * @implNote By default, returns a null and the schema is inferred from the data when necessary. + *

For consistent handling, override this method for directives that rename columns,
+ * change column data types, or add columns with specific schemas.<br>

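+ * <p>As an illustration (not part of this change), a hypothetical directive that renames
+ * column {@code a} to {@code b} could override it roughly as follows:
+ * <pre>{@code
+ * public Schema getOutputSchema(Schema inputSchema) {
+ *   List<Schema.Field> fields = new ArrayList<>();
+ *   for (Schema.Field field : inputSchema.getFields()) {
+ *     fields.add("a".equals(field.getName()) ? Schema.Field.of("b", field.getSchema()) : field);
+ *   }
+ *   return Schema.recordOf("output", fields);
+ * }
+ * }</pre>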
+ * @param inputSchema input {@link Schema} of the data before transformation + * @return output {@link Schema} of the transformed data + */ + @Nullable + default Schema getOutputSchema(Schema inputSchema) { + // no op + return null; + } +} diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java index 7202ec90a..2fa103f8f 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java @@ -18,12 +18,14 @@ import io.cdap.cdap.api.data.format.StructuredRecord; import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.cdap.api.data.schema.Schema.Field; import io.cdap.wrangler.api.Directive; import io.cdap.wrangler.api.DirectiveExecutionException; import io.cdap.wrangler.api.ErrorRecord; import io.cdap.wrangler.api.ErrorRowException; import io.cdap.wrangler.api.Executor; import io.cdap.wrangler.api.ExecutorContext; +import io.cdap.wrangler.api.Pair; import io.cdap.wrangler.api.RecipeException; import io.cdap.wrangler.api.RecipeParser; import io.cdap.wrangler.api.RecipePipeline; @@ -32,11 +34,16 @@ import io.cdap.wrangler.api.TransientVariableScope; import io.cdap.wrangler.utils.RecordConvertor; import io.cdap.wrangler.utils.RecordConvertorException; +import io.cdap.wrangler.utils.SchemaConverter; +import io.cdap.wrangler.utils.TransientStoreKeys; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import javax.annotation.Nullable; /** @@ -45,9 +52,11 @@ public final class RecipePipelineExecutor implements RecipePipeline { private static final Logger LOG = LoggerFactory.getLogger(RecipePipelineExecutor.class); + private static final String TEMP_SCHEMA_FIELD_NAME = "temporarySchemaField"; private final ErrorRecordCollector collector = new ErrorRecordCollector(); private final RecordConvertor convertor = new RecordConvertor(); + private final SchemaConverter generator = new SchemaConverter(); private final RecipeParser recipeParser; private final ExecutorContext context; private List directives; @@ -112,6 +121,12 @@ public List execute(List rows) throws RecipeException { context.getTransientStore().reset(TransientVariableScope.LOCAL); } + // Initialize schema with input schema from TransientStore if running in service env (design-time) / testing env + boolean designTime = context.getEnvironment() != null && + context.getEnvironment().equals(ExecutorContext.Environment.SERVICE) || + context.getEnvironment().equals(ExecutorContext.Environment.TESTING); + Schema schema = designTime ? context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA) : null; + List cumulativeRows = rows.subList(i, i + 1); directiveIndex = 0; try { @@ -122,14 +137,26 @@ public List execute(List rows) throws RecipeException { if (cumulativeRows.size() < 1) { break; } + if (designTime && schema != null) { + Schema directiveOutputSchema = directive.getOutputSchema(schema); + schema = directiveOutputSchema != null ? 
directiveOutputSchema + : generateOutputSchema(schema, cumulativeRows); + } } catch (ReportErrorAndProceed e) { messages.add(String.format("%s (ecode: %d)", e.getMessage(), e.getCode())); collector .add(new ErrorRecord(rows.subList(i, i + 1).get(0), String.join(",", messages), e.getCode(), true)); cumulativeRows = new ArrayList<>(); break; + } catch (RecordConvertorException e) { + throw new RecipeException("Error while generating schema: " + e.getMessage(), e); } } + if (designTime && schema != null) { + Schema previousRowSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); + schema = previousRowSchema != null ? getSchemaUnion(previousRowSchema, schema) : schema; + context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA, schema); + } results.addAll(cumulativeRows); } catch (ErrorRowException e) { messages.add(String.format("%s", e.getMessage())); @@ -161,4 +188,55 @@ private List getDirectives() throws RecipeException { } return directives; } + + private Schema generateOutputSchema(Schema inputSchema, List output) throws RecordConvertorException { + Map outputFieldMap = new LinkedHashMap<>(); + for (Row row : output) { + for (Pair rowField : row.getFields()) { + String fieldName = rowField.getFirst(); + Object fieldValue = rowField.getSecond(); + + Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null; + Schema generated = fieldValue != null && !isValidSchemaForValue(existing, fieldValue) ? + generator.getSchema(fieldValue, fieldName) : null; + + if (generated != null) { + outputFieldMap.put(fieldName, generated); + } else if (existing != null) { + outputFieldMap.put(fieldName, existing); + } + } + } + List outputFields = outputFieldMap.entrySet().stream() + .map(e -> Schema.Field.of(e.getKey(), e.getValue())) + .collect(Collectors.toList()); + return Schema.recordOf("output", outputFields); + } + + // Checks whether the provided input schema is of valid type for given object + private boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException { + if (schema == null) { + return false; + } + Schema generated = generator.getSchema(value, TEMP_SCHEMA_FIELD_NAME); + generated = generated.isNullable() ? generated.getNonNullable() : generated; + schema = schema.isNullable() ? schema.getNonNullable() : schema; + return generated.getType().equals(schema.getType()); + } + + // Gets the union of fields in two schemas while maintaining insertion order and uniqueness of fields. If the same + // field exists with two different schemas, the second schema overwrites first one + private Schema getSchemaUnion(Schema first, Schema second) { + Map fieldMap = new LinkedHashMap<>(); + for (Field field : first.getFields()) { + fieldMap.put(field.getName(), field.getSchema()); + } + for (Field field : second.getFields()) { + fieldMap.put(field.getName(), field.getSchema()); + } + List outputFields = fieldMap.entrySet().stream() + .map(e -> Schema.Field.of(e.getKey(), e.getValue())) + .collect(Collectors.toList()); + return Schema.recordOf("union", outputFields); + } } diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/utils/TransientStoreKeys.java b/wrangler-core/src/main/java/io/cdap/wrangler/utils/TransientStoreKeys.java new file mode 100644 index 000000000..d393e6656 --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/wrangler/utils/TransientStoreKeys.java @@ -0,0 +1,29 @@ +/* + * Copyright © 2023 Cask Data, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.utils; + +/** + * TransientStoreKeys for storing Workspace schema in TransientStore + */ +public final class TransientStoreKeys { + public static final String INPUT_SCHEMA = "ws_input_schema"; + public static final String OUTPUT_SCHEMA = "ws_output_schema"; + + private TransientStoreKeys() { + throw new AssertionError("Cannot instantiate a static utility class."); + } +} diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java b/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java index 53fd7c465..32b3614fd 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java @@ -34,13 +34,13 @@ * This class {@link TestingPipelineContext} is a runtime context that is provided for each * {@link Executor} execution. */ -class TestingPipelineContext implements ExecutorContext { +public class TestingPipelineContext implements ExecutorContext { private StageMetrics metrics; private String name; private TransientStore store; private Map properties; - TestingPipelineContext() { + public TestingPipelineContext() { properties = new HashedMap(); store = new DefaultTransientStore(); } diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java b/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java index 9a5712b5d..10a6da4e2 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java @@ -60,7 +60,7 @@ private TestingRig() { */ public static List execute(String[] recipe, List rows) throws RecipeException, DirectiveParseException, DirectiveLoadException { - return execute(recipe, rows, null); + return execute(recipe, rows, new TestingPipelineContext()); } public static List execute(String[] recipe, List rows, ExecutorContext context) @@ -83,7 +83,7 @@ public static List execute(String[] recipe, List rows, ExecutorContext */ public static Pair, List> executeWithErrors(String[] recipe, List rows) throws RecipeException, DirectiveParseException, DirectiveLoadException, DirectiveNotFoundException { - return executeWithErrors(recipe, rows, null); + return executeWithErrors(recipe, rows, new TestingPipelineContext()); } public static Pair, List> executeWithErrors(String[] recipe, List rows, ExecutorContext context) diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java index 53f1b9e78..12cd96fe0 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java @@ -18,13 +18,21 @@ import io.cdap.cdap.api.data.format.StructuredRecord; import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.wrangler.TestingPipelineContext; 
import io.cdap.wrangler.TestingRig; +import io.cdap.wrangler.api.ExecutorContext; import io.cdap.wrangler.api.RecipePipeline; import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.TransientVariableScope; +import io.cdap.wrangler.utils.TransientStoreKeys; import org.junit.Assert; import org.junit.Test; +import java.math.BigDecimal; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.List; /** * Tests {@link RecipePipelineExecutor}. @@ -96,4 +104,66 @@ public void testPipelineWithMoreSimpleTypes() throws Exception { Assert.assertEquals(1481666448L, record.get("timestamp").longValue()); Assert.assertEquals(186.66f, record.get("weight"), 0.0001f); } + + @Test + public void testOutputSchemaGeneration() throws Exception { + String[] commands = new String[]{ + "parse-as-csv :body ,", + "drop :body", + "set-headers :decimal_col,:name,:timestamp,:weight,:date", + "set-type :timestamp double", + }; + Schema inputSchema = Schema.recordOf( + "input", + Schema.Field.of("body", Schema.of(Schema.Type.STRING)), + Schema.Field.of("decimal_col", Schema.decimalOf(10, 2)) + ); + Schema expectedSchema = Schema.recordOf( + "expected", + Schema.Field.of("decimal_col", Schema.decimalOf(10, 2)), + Schema.Field.of("name", Schema.nullableOf(Schema.of(Schema.Type.STRING))), + Schema.Field.of("timestamp", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))), + Schema.Field.of("weight", Schema.nullableOf(Schema.of(Schema.Type.STRING))), + Schema.Field.of("date", Schema.nullableOf(Schema.of(Schema.Type.STRING))) + ); + List inputRows = new ArrayList<>(); + inputRows.add(new Row("body", "Larry,,186.66,01/01/2000").add("decimal_col", new BigDecimal("123.45"))); + inputRows.add(new Row("body", "Barry,1481666448,,05/01/2000").add("decimal_col", new BigDecimal("234235456.0000"))); + ExecutorContext context = new TestingPipelineContext(); + context.getTransientStore().set( + TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema); + + TestingRig.execute(commands, inputRows, context); + Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); + + for (Schema.Field field : expectedSchema.getFields()) { + Assert.assertEquals(field.getName(), outputSchema.getField(field.getName()).getName()); + Assert.assertEquals(field.getSchema(), outputSchema.getField(field.getName()).getSchema()); + } + } + + @Test + public void testOutputSchemaGeneration_doesNotDropNullColumn() throws Exception { + Schema inputSchema = Schema.recordOf( + "input", + Schema.Field.of("id", Schema.of(Schema.Type.STRING)), + Schema.Field.of("null_col", Schema.of(Schema.Type.STRING)) + ); + String[] commands = new String[]{"set-type :id int"}; + Schema expectedSchema = Schema.recordOf( + "expected", + Schema.Field.of("id", Schema.of(Schema.Type.INT)), + Schema.Field.of("null_col", Schema.of(Schema.Type.STRING)) + ); + Row row = new Row(); + row.add("id", "123"); + row.add("null_col", null); + ExecutorContext context = new TestingPipelineContext(); + context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema); + + TestingRig.execute(commands, Collections.singletonList(row), context); + Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); + + Assert.assertEquals(outputSchema.getField("null_col").getSchema(), expectedSchema.getField("null_col").getSchema()); + } } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java 
b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java index 9ae32bc04..55730e1ca 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java @@ -32,6 +32,7 @@ import io.cdap.wrangler.api.RecipeException; import io.cdap.wrangler.api.RecipeParser; import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.TransientStore; import io.cdap.wrangler.executor.RecipePipelineExecutor; import io.cdap.wrangler.parser.ConfigDirectiveContext; import io.cdap.wrangler.parser.GrammarBasedParser; @@ -51,7 +52,7 @@ import io.cdap.wrangler.service.common.AbstractWranglerHandler; import io.cdap.wrangler.statistics.BasicStatistics; import io.cdap.wrangler.statistics.Statistics; -import io.cdap.wrangler.utils.SchemaConverter; +import io.cdap.wrangler.utils.TransientStoreKeys; import io.cdap.wrangler.validator.ColumnNameValidator; import io.cdap.wrangler.validator.Validator; import io.cdap.wrangler.validator.ValidatorException; @@ -62,6 +63,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -81,6 +83,7 @@ public class AbstractDirectiveHandler extends AbstractWranglerHandler { protected static final String COLUMN_NAME = "body"; protected static final String RECORD_DELIMITER_HEADER = "recorddelimiter"; protected static final String DELIMITER_HEADER = "delimiter"; + protected static final TransientStore TRANSIENT_STORE = new DefaultTransientStore(); protected DirectiveRegistry composite; @@ -133,7 +136,7 @@ protected List executeDirectives( try (RecipePipelineExecutor executor = new RecipePipelineExecutor(parser, new ServicePipelineContext( namespace, ExecutorContext.Environment.SERVICE, - getContext(), new DefaultTransientStore()))) { + getContext(), TRANSIENT_STORE))) { List result = executor.execute(sample); List errors = executor.errors() @@ -154,10 +157,19 @@ protected List executeDirectives( protected DirectiveExecutionResponse generateExecutionResponse( List rows, int limit) throws Exception { List> values = new ArrayList<>(rows.size()); - Map types = new HashMap<>(); - Set headers = new LinkedHashSet<>(); - SchemaConverter convertor = new SchemaConverter(); - + Map types = new LinkedHashMap<>(); + + Schema outputSchema = TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) != null ? + TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) : TRANSIENT_STORE.get(TransientStoreKeys.INPUT_SCHEMA); + + for (Schema.Field field : outputSchema.getFields()) { + Schema schema = field.getSchema(); + schema = schema.isNullable() ? schema.getNonNullable() : schema; + String type = schema.getLogicalType() == null ? schema.getType().name() : schema.getLogicalType().name(); + // for backward compatibility, make the characters except the first one to lower case + type = type.substring(0, 1).toUpperCase() + type.substring(1).toLowerCase(); + types.put(field.getName(), type); + } // Iterate through all the new rows. for (Row row : rows) { // If output array has more than return result values, we terminate. @@ -170,20 +182,9 @@ protected DirectiveExecutionResponse generateExecutionResponse( // Iterate through all the fields of the row. 
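+      // With schema handling in place, field types come from the schema stored in the transient
+      // store (computed above), so this loop only collects the display values for each row.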
for (Pair field : row.getFields()) { String fieldName = field.getFirst(); - headers.add(fieldName); Object object = field.getSecond(); if (object != null) { - Schema schema = convertor.getSchema(object, fieldName); - String type = object.getClass().getSimpleName(); - if (schema != null) { - schema = schema.isNullable() ? schema.getNonNullable() : schema; - type = schema.getLogicalType() == null ? schema.getType().name() : schema.getLogicalType().name(); - // for backward compatibility, make the characters except the first one to lower case - type = type.substring(0, 1).toUpperCase() + type.substring(1).toLowerCase(); - } - types.put(fieldName, type); - if ((object instanceof Iterable) || (object instanceof Row)) { value.put(fieldName, GSON.toJson(object)); @@ -201,7 +202,7 @@ protected DirectiveExecutionResponse generateExecutionResponse( } values.add(value); } - return new DirectiveExecutionResponse(values, headers, types, getWorkspaceSummary(rows)); + return new DirectiveExecutionResponse(values, types.keySet(), types, getWorkspaceSummary(rows)); } /** @@ -264,6 +265,21 @@ protected WorkspaceValidationResult getWorkspaceSummary(List rows) throws E return new WorkspaceValidationResult(columnValidationResults, statistics); } + /** + * Method to get the list of columns across all the given rows + * @param rows list of rows + * @return list of columns (union across columns in all rows) + */ + public static List getAllColumns(List rows) { + Set columns = new LinkedHashSet<>(); + for (Row row : rows) { + for (int i = 0; i < row.width(); i++) { + columns.add(row.getColumn(i)); + } + } + return new ArrayList<>(columns); + } + /** * Creates a uber record after iterating through all rows. * diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java index 728ffeef6..b9ef1b970 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java @@ -49,6 +49,7 @@ import io.cdap.wrangler.api.GrammarMigrator; import io.cdap.wrangler.api.RecipeException; import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.TransientVariableScope; import io.cdap.wrangler.parser.ConfigDirectiveContext; import io.cdap.wrangler.parser.DirectiveClass; import io.cdap.wrangler.parser.GrammarWalker; @@ -76,8 +77,8 @@ import io.cdap.wrangler.store.recipe.RecipeStore; import io.cdap.wrangler.store.workspace.WorkspaceStore; import io.cdap.wrangler.utils.ObjectSerDe; -import io.cdap.wrangler.utils.SchemaConverter; import io.cdap.wrangler.utils.StructuredToRowTransformer; +import io.cdap.wrangler.utils.TransientStoreKeys; import org.apache.commons.lang3.StringEscapeUtils; import java.net.HttpURLConnection; @@ -400,14 +401,19 @@ public void specification(HttpServiceRequest request, HttpServiceResponder respo WorkspaceDetail detail = wsStore.getWorkspaceDetail(wsId); List directives = new ArrayList<>(detail.getWorkspace().getDirectives()); UserDirectivesCollector userDirectivesCollector = new UserDirectivesCollector(); - List result = executeDirectives(ns.getName(), directives, detail, - userDirectivesCollector); + List output = executeDirectives(ns.getName(), directives, detail, userDirectivesCollector); userDirectivesCollector.addLoadDirectivesPragma(directives); - SchemaConverter schemaConvertor = new SchemaConverter(); + Schema outputSchema = 
TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) != null ? + TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) : TRANSIENT_STORE.get(TransientStoreKeys.INPUT_SCHEMA); + + List fields = new ArrayList<>(); + for (String column : getAllColumns(output)) { + fields.add(outputSchema.getField(column)); + } + // check if the rows are empty before going to create a record schema, it will result in a 400 if empty fields // are passed to a record type schema - Schema schema = result.isEmpty() ? null : schemaConvertor.toSchema("record", createUberRecord(result)); Map properties = ImmutableMap.of("directives", String.join("\n", directives), "field", "*", "precondition", "false", @@ -417,8 +423,7 @@ public void specification(HttpServiceRequest request, HttpServiceResponder respo ArtifactSummary wrangler = composite.getLatestWranglerArtifact(); responder.sendString(GSON.toJson(new WorkspaceSpec( - srcSpecs, new StageSpec( - schema, new Plugin("Wrangler", "transform", properties, + srcSpecs, new StageSpec(Schema.recordOf(fields), new Plugin("Wrangler", "transform", properties, wrangler == null ? null : new Artifact(wrangler.getName(), wrangler.getVersion(), wrangler.getScope().name().toLowerCase())))))); @@ -540,6 +545,9 @@ private List executeDirectives(String namespace, GrammarWalker.Visitor grammarVisitor) throws Exception { // Remove all the #pragma from the existing directives. New ones will be generated. directives.removeIf(d -> PRAGMA_PATTERN.matcher(d).find()); + Schema inputSchema = detail.getWorkspace().getSampleSpec().getRelatedPlugins().iterator().next().getSchema(); + TRANSIENT_STORE.reset(TransientVariableScope.GLOBAL); + TRANSIENT_STORE.set(TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema); return getContext().isRemoteTaskEnabled() ? executeRemotely(namespace, directives, detail, grammarVisitor) : From 4406bf49f56281c233f621ef2e86e7464044b252 Mon Sep 17 00:00:00 2001 From: vanathig Date: Fri, 14 Jul 2023 11:01:37 +0530 Subject: [PATCH 2/6] Cleanup test code --- .../cdap/wrangler/TestingPipelineContext.java | 58 ++++++++++--------- .../executor/RecipePipelineExecutorTest.java | 9 ++- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java b/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java index 32b3614fd..1bbfc4c33 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java @@ -24,10 +24,10 @@ import io.cdap.wrangler.api.ExecutorContext; import io.cdap.wrangler.api.TransientStore; import io.cdap.wrangler.proto.Contexts; -import org.apache.commons.collections.map.HashedMap; import java.net.URL; import java.util.Collections; +import java.util.HashMap; import java.util.Map; /** @@ -35,35 +35,16 @@ * {@link Executor} execution. */ public class TestingPipelineContext implements ExecutorContext { - private StageMetrics metrics; - private String name; - private TransientStore store; - private Map properties; + private final StageMetrics metrics; + private final String name; + private final TransientStore store; + private final Map properties; public TestingPipelineContext() { - properties = new HashedMap(); + name = "testing"; + properties = new HashMap<>(); store = new DefaultTransientStore(); - } - - /** - * @return Environment this context is prepared for. 
- */ - @Override - public Environment getEnvironment() { - return Environment.TESTING; - } - - @Override - public String getNamespace() { - return Contexts.SYSTEM; - } - - /** - * @return Measurements context. - */ - @Override - public StageMetrics getMetrics() { - return new StageMetrics() { + metrics = new StageMetrics() { @Override public void count(String s, int i) { @@ -96,12 +77,33 @@ public Map getTags() { }; } + /** + * @return Environment this context is prepared for. + */ + @Override + public Environment getEnvironment() { + return Environment.TESTING; + } + + @Override + public String getNamespace() { + return Contexts.SYSTEM; + } + + /** + * @return Measurements context. + */ + @Override + public StageMetrics getMetrics() { + return metrics; + } + /** * @return Context name. */ @Override public String getContextName() { - return "testing"; + return name; } /** diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java index 12cd96fe0..a2b90653a 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java @@ -30,7 +30,6 @@ import java.math.BigDecimal; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -62,8 +61,8 @@ public void testPipeline() throws Exception { RecipePipeline pipeline = TestingRig.execute(commands); - Row row = new Row("__col", new String("a,b,c,d,e,f,1.0")); - StructuredRecord record = (StructuredRecord) pipeline.execute(Arrays.asList(row), schema).get(0); + Row row = new Row("__col", "a,b,c,d,e,f,1.0"); + StructuredRecord record = (StructuredRecord) pipeline.execute(Collections.singletonList(row), schema).get(0); // Validate the {@link StructuredRecord} Assert.assertEquals("a", record.get("first")); @@ -94,8 +93,8 @@ public void testPipelineWithMoreSimpleTypes() throws Exception { ); RecipePipeline pipeline = TestingRig.execute(commands); - Row row = new Row("__col", new String("Larry,Perez,lperezqt@umn.edu,1481666448,186.66")); - StructuredRecord record = (StructuredRecord) pipeline.execute(Arrays.asList(row), schema).get(0); + Row row = new Row("__col", "Larry,Perez,lperezqt@umn.edu,1481666448,186.66"); + StructuredRecord record = (StructuredRecord) pipeline.execute(Collections.singletonList(row), schema).get(0); // Validate the {@link StructuredRecord} Assert.assertEquals("Larry", record.get("first")); From e70a715e67487d456042568a2b19c77210da51f9 Mon Sep 17 00:00:00 2001 From: vanathig Date: Thu, 20 Jul 2023 19:06:43 +0530 Subject: [PATCH 3/6] fix bug --- .../executor/RecipePipelineExecutor.java | 91 ++++++---------- .../wrangler/utils/OutputSchemaGenerator.java | 103 ++++++++++++++++++ .../executor/RecipePipelineExecutorTest.java | 36 +++++- .../service/directive/WorkspaceHandler.java | 9 +- 4 files changed, 172 insertions(+), 67 deletions(-) create mode 100644 wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java index 2fa103f8f..6835e0b14 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java @@ -18,7 +18,6 @@ import 
io.cdap.cdap.api.data.format.StructuredRecord; import io.cdap.cdap.api.data.schema.Schema; -import io.cdap.cdap.api.data.schema.Schema.Field; import io.cdap.wrangler.api.Directive; import io.cdap.wrangler.api.DirectiveExecutionException; import io.cdap.wrangler.api.ErrorRecord; @@ -32,9 +31,9 @@ import io.cdap.wrangler.api.ReportErrorAndProceed; import io.cdap.wrangler.api.Row; import io.cdap.wrangler.api.TransientVariableScope; +import io.cdap.wrangler.utils.OutputSchemaGenerator; import io.cdap.wrangler.utils.RecordConvertor; import io.cdap.wrangler.utils.RecordConvertorException; -import io.cdap.wrangler.utils.SchemaConverter; import io.cdap.wrangler.utils.TransientStoreKeys; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,7 +42,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import javax.annotation.Nullable; /** @@ -52,11 +50,9 @@ public final class RecipePipelineExecutor implements RecipePipeline { private static final Logger LOG = LoggerFactory.getLogger(RecipePipelineExecutor.class); - private static final String TEMP_SCHEMA_FIELD_NAME = "temporarySchemaField"; private final ErrorRecordCollector collector = new ErrorRecordCollector(); private final RecordConvertor convertor = new RecordConvertor(); - private final SchemaConverter generator = new SchemaConverter(); private final RecipeParser recipeParser; private final ExecutorContext context; private List directives; @@ -112,6 +108,11 @@ public List execute(List rows) throws RecipeException { List results = new ArrayList<>(); int i = 0; int directiveIndex = 0; + // Initialize schema with input schema from TransientStore if running in service env (design-time) / testing env + boolean designTime = context != null && context.getEnvironment() != null && + (context.getEnvironment().equals(ExecutorContext.Environment.SERVICE) || + context.getEnvironment().equals(ExecutorContext.Environment.TESTING)); + try { collector.reset(); while (i < rows.size()) { @@ -120,11 +121,6 @@ public List execute(List rows) throws RecipeException { if (context != null) { context.getTransientStore().reset(TransientVariableScope.LOCAL); } - - // Initialize schema with input schema from TransientStore if running in service env (design-time) / testing env - boolean designTime = context.getEnvironment() != null && - context.getEnvironment().equals(ExecutorContext.Environment.SERVICE) || - context.getEnvironment().equals(ExecutorContext.Environment.TESTING); Schema schema = designTime ? context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA) : null; List cumulativeRows = rows.subList(i, i + 1); @@ -140,7 +136,7 @@ public List execute(List rows) throws RecipeException { if (designTime && schema != null) { Schema directiveOutputSchema = directive.getOutputSchema(schema); schema = directiveOutputSchema != null ? 
directiveOutputSchema - : generateOutputSchema(schema, cumulativeRows); + : OutputSchemaGenerator.generateOutputSchema(schema, getRowUnion(cumulativeRows)); } } catch (ReportErrorAndProceed e) { messages.add(String.format("%s (ecode: %d)", e.getMessage(), e.getCode())); @@ -152,10 +148,26 @@ public List execute(List rows) throws RecipeException { throw new RecipeException("Error while generating schema: " + e.getMessage(), e); } } + // After executing all directives on a row, take union of previous row's schema and this one's schema + // to ensure all output columns are included in final output schema if (designTime && schema != null) { Schema previousRowSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); - schema = previousRowSchema != null ? getSchemaUnion(previousRowSchema, schema) : schema; - context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA, schema); + // If this is the first row, initialize previousRowSchema with input schema fields for all columns in output + if (previousRowSchema == null) { + Schema inputSchema = context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA); + List inputFields = new ArrayList<>(); + for (Pair field : getRowUnion(cumulativeRows).getFields()) { + Schema.Field existing = inputSchema.getField(field.getFirst()); + if (existing != null) { + inputFields.add(existing); + } + } + if (!inputFields.isEmpty()) { + previousRowSchema = Schema.recordOf(inputFields); + } + } + context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA, + OutputSchemaGenerator.getSchemaUnion(previousRowSchema, schema)); } results.addAll(cumulativeRows); } catch (ErrorRowException e) { @@ -189,54 +201,15 @@ private List getDirectives() throws RecipeException { return directives; } - private Schema generateOutputSchema(Schema inputSchema, List output) throws RecordConvertorException { - Map outputFieldMap = new LinkedHashMap<>(); - for (Row row : output) { - for (Pair rowField : row.getFields()) { - String fieldName = rowField.getFirst(); - Object fieldValue = rowField.getSecond(); - - Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null; - Schema generated = fieldValue != null && !isValidSchemaForValue(existing, fieldValue) ? - generator.getSchema(fieldValue, fieldName) : null; - - if (generated != null) { - outputFieldMap.put(fieldName, generated); - } else if (existing != null) { - outputFieldMap.put(fieldName, existing); + public static Row getRowUnion(List rows) { + Row union = new Row(); + for (Row row : rows) { + for (int i = 0; i < row.width(); ++i) { + if (union.find(row.getColumn(i)) == -1) { + union.add(row.getColumn(i), row.getValue(i)); } } } - List outputFields = outputFieldMap.entrySet().stream() - .map(e -> Schema.Field.of(e.getKey(), e.getValue())) - .collect(Collectors.toList()); - return Schema.recordOf("output", outputFields); - } - - // Checks whether the provided input schema is of valid type for given object - private boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException { - if (schema == null) { - return false; - } - Schema generated = generator.getSchema(value, TEMP_SCHEMA_FIELD_NAME); - generated = generated.isNullable() ? generated.getNonNullable() : generated; - schema = schema.isNullable() ? 
schema.getNonNullable() : schema;
-    return generated.getType().equals(schema.getType());
-  }
-
-  // Gets the union of fields in two schemas while maintaining insertion order and uniqueness of fields. If the same
-  // field exists with two different schemas, the second schema overwrites first one
-  private Schema getSchemaUnion(Schema first, Schema second) {
-    Map fieldMap = new LinkedHashMap<>();
-    for (Field field : first.getFields()) {
-      fieldMap.put(field.getName(), field.getSchema());
-    }
-    for (Field field : second.getFields()) {
-      fieldMap.put(field.getName(), field.getSchema());
+  public static Row getRowUnion(List<Row> rows) {
+    Row union = new Row();
+    for (Row row : rows) {
+      for (int i = 0; i < row.width(); ++i) {
+        if (union.find(row.getColumn(i)) == -1) {
+          union.add(row.getColumn(i), row.getValue(i));
        }
      }
    }
-    List outputFields = fieldMap.entrySet().stream()
-      .map(e -> Schema.Field.of(e.getKey(), e.getValue()))
-      .collect(Collectors.toList());
-    return Schema.recordOf("union", outputFields);
+    return union;
  }
}
diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java b/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java
new file mode 100644
index 000000000..822022b2f
--- /dev/null
+++ b/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright © 2023 Cask Data, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package io.cdap.wrangler.utils;
+
+import io.cdap.cdap.api.data.schema.Schema;
+import io.cdap.wrangler.api.Pair;
+import io.cdap.wrangler.api.Row;
+
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+/**
+ * Utility class for generating the output {@link Schema} of data transformed by a recipe of directives.
+ */
+public final class OutputSchemaGenerator {
+  private static final String TEMP_SCHEMA_FIELD_NAME = "temporarySchemaField";
+
+  private static final SchemaConverter SCHEMA_GENERATOR = new SchemaConverter();
+
+  /**
+   * Method to generate the output schema for the given output row
+   * @param inputSchema {@link Schema} of the data before transformation
+   * @param output union {@link Row} of the data after transformation
+   * @return generated {@link Schema} of the output data
+   * @throws RecordConvertorException if a schema cannot be generated for a field value
+   */
+  public static Schema generateOutputSchema(Schema inputSchema, Row output) throws RecordConvertorException {
+    List<Schema.Field> outputFields = new LinkedList<>();
+    for (Pair<String, Object> rowField : output.getFields()) {
+      String fieldName = rowField.getFirst();
+      Object fieldValue = rowField.getSecond();
+
+      Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null;
+      Schema generated = fieldValue == null ? Schema.of(Schema.Type.NULL) :
+        (!isValidSchemaForValue(existing, fieldValue) ? 
SCHEMA_GENERATOR.getSchema(fieldValue, fieldName) : null); + + if (generated != null) { + outputFields.add(Schema.Field.of(fieldName, generated)); + } else if (existing != null) { + outputFields.add(Schema.Field.of(fieldName, existing)); + } + } + return Schema.recordOf("output", outputFields); + } + + /** + * + * @param first + * @param second + * @return + */ + public static Schema getSchemaUnion(@Nullable Schema first, @Nullable Schema second) { + if (first == null) { + return second; + } + if (second == null) { + return first; + } + Map fieldMap = new LinkedHashMap<>(); + for (Schema.Field field : first.getFields()) { + fieldMap.put(field.getName(), field.getSchema()); + } + for (Schema.Field field : second.getFields()) { + if (field.getSchema().getType().equals(Schema.Type.NULL) && fieldMap.containsKey(field.getName())) { + continue; + } + fieldMap.put(field.getName(), field.getSchema()); + } + List outputFields = fieldMap.entrySet().stream() + .map(e -> Schema.Field.of(e.getKey(), e.getValue())) + .collect(Collectors.toList()); + return Schema.recordOf(TEMP_SCHEMA_FIELD_NAME, outputFields); + } + + // Checks whether the provided input schema is of valid type for given object + private static boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException { + if (schema == null) { + return false; + } + Schema generated = SCHEMA_GENERATOR.getSchema(value, TEMP_SCHEMA_FIELD_NAME); + generated = generated.isNullable() ? generated.getNonNullable() : generated; + schema = schema.isNullable() ? schema.getNonNullable() : schema; + return generated.getType().equals(schema.getType()); + } +} diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java index a2b90653a..495904005 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java @@ -30,6 +30,7 @@ import java.math.BigDecimal; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -163,6 +164,39 @@ public void testOutputSchemaGeneration_doesNotDropNullColumn() throws Exception TestingRig.execute(commands, Collections.singletonList(row), context); Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); - Assert.assertEquals(outputSchema.getField("null_col").getSchema(), expectedSchema.getField("null_col").getSchema()); + Assert.assertEquals(expectedSchema.getField("null_col").getSchema(), outputSchema.getField("null_col").getSchema()); + } + + @Test + public void testOutputSchemaGeneration_columnOrdering() throws Exception { + Schema inputSchema = Schema.recordOf( + "input", + Schema.Field.of("body", Schema.of(Schema.Type.STRING)), + Schema.Field.of("value", Schema.of(Schema.Type.INT)) + ); + String[] commands = new String[] { + "parse-as-json :body 1", + "set-type :value long" + }; + List expectedFields = Arrays.asList( + Schema.Field.of("value", Schema.nullableOf(Schema.of(Schema.Type.LONG))), + Schema.Field.of("body_A", Schema.nullableOf(Schema.of(Schema.Type.LONG))), + Schema.Field.of("body_B", Schema.nullableOf(Schema.of(Schema.Type.STRING))), + Schema.Field.of("body_C", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))) + ); + Row row1 = new Row().add("body", "{\"A\":1, \"B\":\"hello\"}").add("value", 10L); + Row row2 = new Row().add("body", "{\"C\":1.23, 
\"A\":1, \"B\":\"world\"}").add("value", 20L); + ExecutorContext context = new TestingPipelineContext(); + context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema); + + TestingRig.execute(commands, Arrays.asList(row1, row2), context); + Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); + List outputFields = outputSchema.getFields(); + + Assert.assertEquals(expectedFields.size(), outputFields.size()); + for (int i = 0; i < expectedFields.size(); i++) { + Assert.assertEquals(expectedFields.get(i).getName(), outputFields.get(i).getName()); + Assert.assertEquals(expectedFields.get(i).getSchema(), outputFields.get(i).getSchema()); + } } } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java index b9ef1b970..afee665b2 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java @@ -401,17 +401,12 @@ public void specification(HttpServiceRequest request, HttpServiceResponder respo WorkspaceDetail detail = wsStore.getWorkspaceDetail(wsId); List directives = new ArrayList<>(detail.getWorkspace().getDirectives()); UserDirectivesCollector userDirectivesCollector = new UserDirectivesCollector(); - List output = executeDirectives(ns.getName(), directives, detail, userDirectivesCollector); + executeDirectives(ns.getName(), directives, detail, userDirectivesCollector); userDirectivesCollector.addLoadDirectivesPragma(directives); Schema outputSchema = TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) != null ? TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) : TRANSIENT_STORE.get(TransientStoreKeys.INPUT_SCHEMA); - List fields = new ArrayList<>(); - for (String column : getAllColumns(output)) { - fields.add(outputSchema.getField(column)); - } - // check if the rows are empty before going to create a record schema, it will result in a 400 if empty fields // are passed to a record type schema Map properties = ImmutableMap.of("directives", String.join("\n", directives), @@ -423,7 +418,7 @@ public void specification(HttpServiceRequest request, HttpServiceResponder respo ArtifactSummary wrangler = composite.getLatestWranglerArtifact(); responder.sendString(GSON.toJson(new WorkspaceSpec( - srcSpecs, new StageSpec(Schema.recordOf(fields), new Plugin("Wrangler", "transform", properties, + srcSpecs, new StageSpec(outputSchema, new Plugin("Wrangler", "transform", properties, wrangler == null ? 
null : new Artifact(wrangler.getName(), wrangler.getVersion(), wrangler.getScope().name().toLowerCase())))))); From f596f7eeba6e3697c9c40c7653271680fe278b74 Mon Sep 17 00:00:00 2001 From: vanathig Date: Fri, 21 Jul 2023 16:51:23 +0530 Subject: [PATCH 4/6] Optimize --- .../executor/RecipePipelineExecutor.java | 47 +++---- .../wrangler/utils/OutputSchemaGenerator.java | 127 +++++++++++------- .../executor/RecipePipelineExecutorTest.java | 4 +- 3 files changed, 97 insertions(+), 81 deletions(-) diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java index 6835e0b14..15cd1169a 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java @@ -39,9 +39,7 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import javax.annotation.Nullable; /** @@ -112,6 +110,10 @@ public List execute(List rows) throws RecipeException { boolean designTime = context != null && context.getEnvironment() != null && (context.getEnvironment().equals(ExecutorContext.Environment.SERVICE) || context.getEnvironment().equals(ExecutorContext.Environment.TESTING)); + Schema inputSchema = designTime ? context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA) : null; + + OutputSchemaGenerator outputSchemaGenerator = designTime && inputSchema != null ? + new OutputSchemaGenerator(inputSchema, directives) : null; try { collector.reset(); @@ -121,7 +123,6 @@ public List execute(List rows) throws RecipeException { if (context != null) { context.getTransientStore().reset(TransientVariableScope.LOCAL); } - Schema schema = designTime ? context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA) : null; List cumulativeRows = rows.subList(i, i + 1); directiveIndex = 0; @@ -133,10 +134,10 @@ public List execute(List rows) throws RecipeException { if (cumulativeRows.size() < 1) { break; } - if (designTime && schema != null) { - Schema directiveOutputSchema = directive.getOutputSchema(schema); - schema = directiveOutputSchema != null ? 
directiveOutputSchema - : OutputSchemaGenerator.generateOutputSchema(schema, getRowUnion(cumulativeRows)); + if (designTime && inputSchema != null) { + for (Pair field : getRowUnion(cumulativeRows).getFields()) { + outputSchemaGenerator.addDirectiveField(directiveIndex - 1, field.getFirst(), field.getSecond()); + } } } catch (ReportErrorAndProceed e) { messages.add(String.format("%s (ecode: %d)", e.getMessage(), e.getCode())); @@ -144,30 +145,7 @@ public List execute(List rows) throws RecipeException { .add(new ErrorRecord(rows.subList(i, i + 1).get(0), String.join(",", messages), e.getCode(), true)); cumulativeRows = new ArrayList<>(); break; - } catch (RecordConvertorException e) { - throw new RecipeException("Error while generating schema: " + e.getMessage(), e); - } - } - // After executing all directives on a row, take union of previous row's schema and this one's schema - // to ensure all output columns are included in final output schema - if (designTime && schema != null) { - Schema previousRowSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); - // If this is the first row, initialize previousRowSchema with input schema fields for all columns in output - if (previousRowSchema == null) { - Schema inputSchema = context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA); - List inputFields = new ArrayList<>(); - for (Pair field : getRowUnion(cumulativeRows).getFields()) { - Schema.Field existing = inputSchema.getField(field.getFirst()); - if (existing != null) { - inputFields.add(existing); - } - } - if (!inputFields.isEmpty()) { - previousRowSchema = Schema.recordOf(inputFields); - } } - context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA, - OutputSchemaGenerator.getSchemaUnion(previousRowSchema, schema)); } results.addAll(cumulativeRows); } catch (ErrorRowException e) { @@ -181,6 +159,15 @@ public List execute(List rows) throws RecipeException { } catch (DirectiveExecutionException e) { throw new RecipeException(e.getMessage(), e, i, directiveIndex); } + // Schema generation + if (designTime && inputSchema != null) { + try { + context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA, + outputSchemaGenerator.generateOutputSchema()); + } catch (RecordConvertorException e) { + throw new RuntimeException(e); + } + } return results; } diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java b/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java index 822022b2f..8741a1e45 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java @@ -17,85 +17,114 @@ package io.cdap.wrangler.utils; import io.cdap.cdap.api.data.schema.Schema; -import io.cdap.wrangler.api.Pair; -import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.Directive; +import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import javax.annotation.Nullable; /** - * JAVADOC + * This class can be used to generate the output schema after executing a set of directives. A list is maintained + * where each element is a map of fieldName --> value. Each element of the list corresponds to a directive in the list + * provided during initialization. 
Hence, each map represents the fields present across all output rows generated by a + * directive after execution. */ -public final class OutputSchemaGenerator { - private static final String TEMP_SCHEMA_FIELD_NAME = "temporarySchemaField"; - +public class OutputSchemaGenerator { private static final SchemaConverter SCHEMA_GENERATOR = new SchemaConverter(); + private final Schema inputSchema; + private final List directives; + private final int directiveCount; + private final List> directiveOutputFieldMaps; + + public OutputSchemaGenerator(Schema inputSchema, List directives) { + this.inputSchema = inputSchema; + this.directives = directives; + this.directiveCount = directives.size(); + directiveOutputFieldMaps = new ArrayList<>(); + for (int i = 0; i < directiveCount; i++) { + directiveOutputFieldMaps.add(new LinkedHashMap<>()); + } + } + + /** + * Method to add a field in output generated by a directive. The field is added if not already present. + * Existing value is only overwritten when given fieldValue is not null and current value is null. + * @param directiveIndex index of the directive in recipe + * @param fieldName name of the field (column name) + * @param fieldValue value (can be null) + */ + public void addDirectiveField(int directiveIndex, String fieldName, @Nullable Object fieldValue) { + Map directiveOutputFields = directiveOutputFieldMaps.get(directiveIndex); + if (directiveOutputFields.containsKey(fieldName)) { + // If existing value is null, override with this non-null value + if (fieldValue != null && directiveOutputFields.get(fieldName) == null) { + directiveOutputFields.put(fieldName, fieldValue); + } + } else { + directiveOutputFields.putIfAbsent(fieldName, fieldValue); + } + } + /** - * Method to generate the output schema for the given output rows - * @param inputSchema {@link Schema} of the data before transformation - * @param output rows of data after transformation - * @return generated {@link Schema} of the output data - * @throws RecordConvertorException + * Method to generate the output schema after applying all directives. Intermediate schema is generated using + * the fields in corresponding map only if the directive does not provide an implementation + * of the {@link Directive#getOutputSchema(Schema)} method + * @return {@link Schema} outputSchema after all directives are applied */ - public static Schema generateOutputSchema(Schema inputSchema, Row output) throws RecordConvertorException { + public Schema generateOutputSchema() throws RecordConvertorException { + return generateIntermediateSchema(directiveCount); + } + + /** + * Method to generate the intermediate schema after applying all directives until (and including) i'th directive in + * the recipe. Intermediate schema is generated using + * the fields in corresponding map only if the directive does not provide an implementation + * of the {@link Directive#getOutputSchema(Schema)} method + * @param directiveIndex index until (and including) which schema should be generated + * @return {@link Schema} intermediate schema after the directives are applied + */ + public Schema generateIntermediateSchema(int directiveIndex) throws RecordConvertorException { + Schema schema = inputSchema; + for (int i = 0; i < directiveIndex; i++) { + Schema directiveOutputSchema = directives.get(i).getOutputSchema(schema); + schema = directiveOutputSchema != null ? 
directiveOutputSchema : + generateDirectiveOutputSchema(schema, directiveOutputFieldMaps.get(i)); + } + return schema; + } + + // Given the schema from previous step and output of current directive, generates the directive output schema. + private Schema generateDirectiveOutputSchema(Schema inputSchema, Map output) + throws RecordConvertorException { List outputFields = new LinkedList<>(); - for (Pair rowField : output.getFields()) { - String fieldName = rowField.getFirst(); - Object fieldValue = rowField.getSecond(); + for (String fieldName : output.keySet()) { + Object fieldValue = output.get(fieldName); Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null; - Schema generated = fieldValue == null ? Schema.of(Schema.Type.NULL) : - (!isValidSchemaForValue(existing, fieldValue) ? SCHEMA_GENERATOR.getSchema(fieldValue, fieldName) : null); + Schema generated = fieldValue != null && !isValidSchemaForValue(existing, fieldValue) ? + SCHEMA_GENERATOR.getSchema(fieldValue, fieldName) : null; if (generated != null) { outputFields.add(Schema.Field.of(fieldName, generated)); } else if (existing != null) { outputFields.add(Schema.Field.of(fieldName, existing)); + } else { + outputFields.add(Schema.Field.of(fieldName, Schema.of(Schema.Type.NULL))); } } return Schema.recordOf("output", outputFields); } - /** - * - * @param first - * @param second - * @return - */ - public static Schema getSchemaUnion(@Nullable Schema first, @Nullable Schema second) { - if (first == null) { - return second; - } - if (second == null) { - return first; - } - Map fieldMap = new LinkedHashMap<>(); - for (Schema.Field field : first.getFields()) { - fieldMap.put(field.getName(), field.getSchema()); - } - for (Schema.Field field : second.getFields()) { - if (field.getSchema().getType().equals(Schema.Type.NULL) && fieldMap.containsKey(field.getName())) { - continue; - } - fieldMap.put(field.getName(), field.getSchema()); - } - List outputFields = fieldMap.entrySet().stream() - .map(e -> Schema.Field.of(e.getKey(), e.getValue())) - .collect(Collectors.toList()); - return Schema.recordOf(TEMP_SCHEMA_FIELD_NAME, outputFields); - } - // Checks whether the provided input schema is of valid type for given object - private static boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException { + private boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException { if (schema == null) { return false; } - Schema generated = SCHEMA_GENERATOR.getSchema(value, TEMP_SCHEMA_FIELD_NAME); + Schema generated = SCHEMA_GENERATOR.getSchema(value, "temp_field_name"); generated = generated.isNullable() ? generated.getNonNullable() : generated; schema = schema.isNullable() ? 
schema.getNonNullable() : schema; return generated.getType().equals(schema.getType()); diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java index 495904005..84b96ba6e 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java @@ -127,8 +127,8 @@ public void testOutputSchemaGeneration() throws Exception { Schema.Field.of("date", Schema.nullableOf(Schema.of(Schema.Type.STRING))) ); List inputRows = new ArrayList<>(); - inputRows.add(new Row("body", "Larry,,186.66,01/01/2000").add("decimal_col", new BigDecimal("123.45"))); - inputRows.add(new Row("body", "Barry,1481666448,,05/01/2000").add("decimal_col", new BigDecimal("234235456.0000"))); + inputRows.add(new Row("body", "Larry,1481666448,01/01/2000").add("decimal_col", new BigDecimal("123.45"))); + inputRows.add(new Row("body", "Barry,,172.3,05/01/2000").add("decimal_col", new BigDecimal("234235456.0000"))); ExecutorContext context = new TestingPipelineContext(); context.getTransientStore().set( TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema); From 9c4533798ac7c794443546ddbf1171e8d1ea0203 Mon Sep 17 00:00:00 2001 From: vanathig Date: Mon, 24 Jul 2023 21:40:39 +0530 Subject: [PATCH 5/6] Refactor output schema generation --- .../executor/RecipePipelineExecutor.java | 42 +++--- .../utils/DirectiveOutputSchemaGenerator.java | 111 +++++++++++++++ .../wrangler/utils/OutputSchemaGenerator.java | 132 ------------------ 3 files changed, 132 insertions(+), 153 deletions(-) create mode 100644 wrangler-core/src/main/java/io/cdap/wrangler/utils/DirectiveOutputSchemaGenerator.java delete mode 100644 wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java index 15cd1169a..106c77292 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java @@ -24,16 +24,16 @@ import io.cdap.wrangler.api.ErrorRowException; import io.cdap.wrangler.api.Executor; import io.cdap.wrangler.api.ExecutorContext; -import io.cdap.wrangler.api.Pair; import io.cdap.wrangler.api.RecipeException; import io.cdap.wrangler.api.RecipeParser; import io.cdap.wrangler.api.RecipePipeline; import io.cdap.wrangler.api.ReportErrorAndProceed; import io.cdap.wrangler.api.Row; import io.cdap.wrangler.api.TransientVariableScope; -import io.cdap.wrangler.utils.OutputSchemaGenerator; +import io.cdap.wrangler.utils.DirectiveOutputSchemaGenerator; import io.cdap.wrangler.utils.RecordConvertor; import io.cdap.wrangler.utils.RecordConvertorException; +import io.cdap.wrangler.utils.SchemaConverter; import io.cdap.wrangler.utils.TransientStoreKeys; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,6 +51,7 @@ public final class RecipePipelineExecutor implements RecipePipeline directives; @@ -112,8 +113,12 @@ public List execute(List rows) throws RecipeException { context.getEnvironment().equals(ExecutorContext.Environment.TESTING)); Schema inputSchema = designTime ? 
context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA) : null; - OutputSchemaGenerator outputSchemaGenerator = designTime && inputSchema != null ? - new OutputSchemaGenerator(inputSchema, directives) : null; + List outputSchemaGenerators = new ArrayList<>(); + if (designTime && inputSchema != null) { + for (Directive directive : directives) { + outputSchemaGenerators.add(new DirectiveOutputSchemaGenerator(directive, generator)); + } + } try { collector.reset(); @@ -135,9 +140,7 @@ public List execute(List rows) throws RecipeException { break; } if (designTime && inputSchema != null) { - for (Pair field : getRowUnion(cumulativeRows).getFields()) { - outputSchemaGenerator.addDirectiveField(directiveIndex - 1, field.getFirst(), field.getSecond()); - } + outputSchemaGenerators.get(directiveIndex - 1).addNewOutputFields(cumulativeRows); } } catch (ReportErrorAndProceed e) { messages.add(String.format("%s (ecode: %d)", e.getMessage(), e.getCode())); @@ -161,12 +164,8 @@ public List execute(List rows) throws RecipeException { } // Schema generation if (designTime && inputSchema != null) { - try { - context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA, - outputSchemaGenerator.generateOutputSchema()); - } catch (RecordConvertorException e) { - throw new RuntimeException(e); - } + context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA, + getOutputSchema(inputSchema, outputSchemaGenerators)); } return results; } @@ -188,15 +187,16 @@ private List getDirectives() throws RecipeException { return directives; } - public static Row getRowUnion(List rows) { - Row union = new Row(); - for (Row row : rows) { - for (int i = 0; i < row.width(); ++i) { - if (union.find(row.getColumn(i)) == -1) { - union.add(row.getColumn(i), row.getValue(i)); - } + private Schema getOutputSchema(Schema inputSchema, List outputSchemaGenerators) + throws RecipeException { + Schema schema = inputSchema; + for (DirectiveOutputSchemaGenerator outputSchemaGenerator : outputSchemaGenerators) { + try { + schema = outputSchemaGenerator.getDirectiveOutputSchema(schema); + } catch (RecordConvertorException e) { + throw new RecipeException("Error while generating output schema for a directive: " + e, e); } } - return union; + return schema; } } diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/utils/DirectiveOutputSchemaGenerator.java b/wrangler-core/src/main/java/io/cdap/wrangler/utils/DirectiveOutputSchemaGenerator.java new file mode 100644 index 000000000..eb90ab2c4 --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/wrangler/utils/DirectiveOutputSchemaGenerator.java @@ -0,0 +1,111 @@ +/* + * Copyright © 2023 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.cdap.wrangler.utils; + +import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.wrangler.api.Directive; +import io.cdap.wrangler.api.Pair; +import io.cdap.wrangler.api.Row; + +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; + +/** + * This class can be used to generate the output schema for the output data of a directive. It maintains a map of + * output fields present across all output rows after applying a directive. This map is used to generate the schema + * if the directive does not return a custom output schema. + */ +public class DirectiveOutputSchemaGenerator { + private final SchemaConverter schemaGenerator; + private final Map outputFieldMap; + private final Directive directive; + + public DirectiveOutputSchemaGenerator(Directive directive, SchemaConverter schemaGenerator) { + this.directive = directive; + this.schemaGenerator = schemaGenerator; + outputFieldMap = new LinkedHashMap<>(); + } + + /** + * Method to add new fields from the given output to the map of fieldName --> value maintained for schema generation. + * A value is added to the map only if it is absent, or if the existing value is null and the given value is non-null. + * @param output list of output {@link Row}s after applying the directive + */ + public void addNewOutputFields(List output) { + for (Row row : output) { + for (Pair field : row.getFields()) { + String fieldName = field.getFirst(); + Object fieldValue = field.getSecond(); + if (outputFieldMap.containsKey(fieldName)) { + // If existing value is null, override with this non-null value + if (fieldValue != null && outputFieldMap.get(fieldName) == null) { + outputFieldMap.put(fieldName, fieldValue); + } + } else { + outputFieldMap.putIfAbsent(fieldName, fieldValue); + } + } + } + } + + /** + * Method to get the output schema of the directive. Returns a generated schema based on maintained map of fields + * only if directive does not return a custom output schema. + * @param inputSchema input {@link Schema} of the data before applying the directive + * @return {@link Schema} corresponding to the output data + */ + public Schema getDirectiveOutputSchema(Schema inputSchema) throws RecordConvertorException { + Schema directiveOutputSchema = directive.getOutputSchema(inputSchema); + return directiveOutputSchema != null ? directiveOutputSchema : generateDirectiveOutputSchema(inputSchema); + } + + // Given the schema from previous step and output of current directive, generates the directive output schema. + private Schema generateDirectiveOutputSchema(Schema inputSchema) + throws RecordConvertorException { + List outputFields = new LinkedList<>(); + for (String fieldName : outputFieldMap.keySet()) { + Object fieldValue = outputFieldMap.get(fieldName); + + Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null; + Schema generated = fieldValue != null && !isValidSchemaForValue(existing, fieldValue) ?
+ schemaGenerator.getSchema(fieldValue, fieldName) : null; + + if (generated != null) { + outputFields.add(Schema.Field.of(fieldName, generated)); + } else if (existing != null) { + outputFields.add(Schema.Field.of(fieldName, existing)); + } else { + outputFields.add(Schema.Field.of(fieldName, Schema.of(Schema.Type.NULL))); + } + } + return Schema.recordOf("output", outputFields); + } + + // Checks whether the provided input schema is of valid type for given object + private boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException { + if (schema == null) { + return false; + } + Schema generated = schemaGenerator.getSchema(value, "temp_field_name"); + generated = generated.isNullable() ? generated.getNonNullable() : generated; + schema = schema.isNullable() ? schema.getNonNullable() : schema; + return generated.getLogicalType() == schema.getLogicalType() && generated.getType() == schema.getType(); + } +} diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java b/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java deleted file mode 100644 index 8741a1e45..000000000 --- a/wrangler-core/src/main/java/io/cdap/wrangler/utils/OutputSchemaGenerator.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright © 2023 Cask Data, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. - */ - -package io.cdap.wrangler.utils; - -import io.cdap.cdap.api.data.schema.Schema; -import io.cdap.wrangler.api.Directive; - -import java.util.ArrayList; -import java.util.LinkedHashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import javax.annotation.Nullable; - -/** - * This class can be used to generate the output schema after executing a set of directives. A list is maintained - * where each element is a map of fieldName --> value. Each element of the list corresponds to a directive in the list - * provided during initialization. Hence, each map represents the fields present across all output rows generated by a - * directive after execution. - */ -public class OutputSchemaGenerator { - private static final SchemaConverter SCHEMA_GENERATOR = new SchemaConverter(); - - private final Schema inputSchema; - private final List directives; - private final int directiveCount; - private final List> directiveOutputFieldMaps; - - public OutputSchemaGenerator(Schema inputSchema, List directives) { - this.inputSchema = inputSchema; - this.directives = directives; - this.directiveCount = directives.size(); - directiveOutputFieldMaps = new ArrayList<>(); - for (int i = 0; i < directiveCount; i++) { - directiveOutputFieldMaps.add(new LinkedHashMap<>()); - } - } - - /** - * Method to add a field in output generated by a directive. The field is added if not already present. - * Existing value is only overwritten when given fieldValue is not null and current value is null. 
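A minimal sketch of how this merge rule behaves, assuming the DirectiveOutputSchemaGenerator and Row APIs introduced in this patch series; the directive variable and the column values are invented for illustration:

// Sketch: a null sample from one row is replaced by a later non-null sample,
// so type inference sees real data instead of defaulting to Schema.Type.NULL.
DirectiveOutputSchemaGenerator generator =
    new DirectiveOutputSchemaGenerator(directive, new SchemaConverter());
Row first = new Row().add("weight", null);       // value unknown in this row
Row second = new Row().add("weight", 186.66d);   // non-null sample overrides the null
generator.addNewOutputFields(Arrays.asList(first, second));
// The field map now holds weight=186.66, so the generated field type is DOUBLE.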
- * @param directiveIndex index of the directive in recipe - * @param fieldName name of the field (column name) - * @param fieldValue value (can be null) - */ - public void addDirectiveField(int directiveIndex, String fieldName, @Nullable Object fieldValue) { - Map directiveOutputFields = directiveOutputFieldMaps.get(directiveIndex); - if (directiveOutputFields.containsKey(fieldName)) { - // If existing value is null, override with this non-null value - if (fieldValue != null && directiveOutputFields.get(fieldName) == null) { - directiveOutputFields.put(fieldName, fieldValue); - } - } else { - directiveOutputFields.putIfAbsent(fieldName, fieldValue); - } - } - - /** - * Method to generate the output schema after applying all directives. Intermediate schema is generated using - * the fields in corresponding map only if the directive does not provide an implementation - * of the {@link Directive#getOutputSchema(Schema)} method - * @return {@link Schema} outputSchema after all directives are applied - */ - public Schema generateOutputSchema() throws RecordConvertorException { - return generateIntermediateSchema(directiveCount); - } - - /** - * Method to generate the intermediate schema after applying all directives until (and including) i'th directive in - * the recipe. Intermediate schema is generated using - * the fields in corresponding map only if the directive does not provide an implementation - * of the {@link Directive#getOutputSchema(Schema)} method - * @param directiveIndex index until (and including) which schema should be generated - * @return {@link Schema} intermediate schema after the directives are applied - */ - public Schema generateIntermediateSchema(int directiveIndex) throws RecordConvertorException { - Schema schema = inputSchema; - for (int i = 0; i < directiveIndex; i++) { - Schema directiveOutputSchema = directives.get(i).getOutputSchema(schema); - schema = directiveOutputSchema != null ? directiveOutputSchema : - generateDirectiveOutputSchema(schema, directiveOutputFieldMaps.get(i)); - } - return schema; - } - - // Given the schema from previous step and output of current directive, generates the directive output schema. - private Schema generateDirectiveOutputSchema(Schema inputSchema, Map output) - throws RecordConvertorException { - List outputFields = new LinkedList<>(); - for (String fieldName : output.keySet()) { - Object fieldValue = output.get(fieldName); - - Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null; - Schema generated = fieldValue != null && !isValidSchemaForValue(existing, fieldValue) ? - SCHEMA_GENERATOR.getSchema(fieldValue, fieldName) : null; - - if (generated != null) { - outputFields.add(Schema.Field.of(fieldName, generated)); - } else if (existing != null) { - outputFields.add(Schema.Field.of(fieldName, existing)); - } else { - outputFields.add(Schema.Field.of(fieldName, Schema.of(Schema.Type.NULL))); - } - } - return Schema.recordOf("output", outputFields); - } - - // Checks whether the provided input schema is of valid type for given object - private boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException { - if (schema == null) { - return false; - } - Schema generated = SCHEMA_GENERATOR.getSchema(value, "temp_field_name"); - generated = generated.isNullable() ? generated.getNonNullable() : generated; - schema = schema.isNullable() ? 
schema.getNonNullable() : schema; - return generated.getType().equals(schema.getType()); - } -} From f7fd80dad8652ff45626bfa86e5e40478a7686fa Mon Sep 17 00:00:00 2001 From: vanathig Date: Tue, 25 Jul 2023 00:34:24 +0530 Subject: [PATCH 6/6] Resolve comments --- .../java/io/cdap/wrangler/api/Executor.java | 7 ++-- .../wrangler/api/SchemaResolutionContext.java | 29 +++++++++++++++ .../executor/RecipePipelineExecutor.java | 7 ++-- .../DirectiveOutputSchemaGenerator.java | 28 +++++++++------ .../DirectiveSchemaResolutionContext.java | 36 +++++++++++++++++++ .../{utils => schema}/TransientStoreKeys.java | 2 +- .../executor/RecipePipelineExecutorTest.java | 8 ++--- .../directive/AbstractDirectiveHandler.java | 19 +--------- .../service/directive/WorkspaceHandler.java | 2 +- 9 files changed, 98 insertions(+), 40 deletions(-) create mode 100644 wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java rename wrangler-core/src/main/java/io/cdap/wrangler/{utils => schema}/DirectiveOutputSchemaGenerator.java (82%) create mode 100644 wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java rename wrangler-core/src/main/java/io/cdap/wrangler/{utils => schema}/TransientStoreKeys.java (96%) diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java index c90374c39..8c85319e4 100644 --- a/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java @@ -85,14 +85,15 @@ O execute(I rows, ExecutorContext context) /** * This method is used to get the updated schema of the data after the directive's transformation has been applied. + * + * @param schemaResolutionContext context containing necessary information for getting output schema + * @return output {@link Schema} of the transformed data * @implNote By default, returns a null and the schema is inferred from the data when necessary. *

<p>For consistent handling, override for directives that perform column renames,
* column data type changes or column additions with specific schemas.</p>
- * @param inputSchema input {@link Schema} of the data before transformation - * @return output {@link Schema} of the transformed data */ @Nullable - default Schema getOutputSchema(Schema inputSchema) { + default Schema getOutputSchema(SchemaResolutionContext schemaResolutionContext) { // no op return null; } diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java new file mode 100644 index 000000000..015f8bdc6 --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java @@ -0,0 +1,29 @@ +/* + * Copyright © 2023 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.api; + +import io.cdap.cdap.api.data.schema.Schema; + +/** + * Interface to pass contextual information related to getting or generating the output schema of a {@link Executor} + */ +public interface SchemaResolutionContext { + /** + * @return {@link Schema} of the input data before transformation + */ + Schema getInputSchema(); +} diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java index 106c77292..a41206e31 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java @@ -30,11 +30,12 @@ import io.cdap.wrangler.api.ReportErrorAndProceed; import io.cdap.wrangler.api.Row; import io.cdap.wrangler.api.TransientVariableScope; -import io.cdap.wrangler.utils.DirectiveOutputSchemaGenerator; +import io.cdap.wrangler.schema.DirectiveOutputSchemaGenerator; +import io.cdap.wrangler.schema.DirectiveSchemaResolutionContext; +import io.cdap.wrangler.schema.TransientStoreKeys; import io.cdap.wrangler.utils.RecordConvertor; import io.cdap.wrangler.utils.RecordConvertorException; import io.cdap.wrangler.utils.SchemaConverter; -import io.cdap.wrangler.utils.TransientStoreKeys; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -192,7 +193,7 @@ private Schema getOutputSchema(Schema inputSchema, List output) { outputFieldMap.put(fieldName, fieldValue); } } else { - outputFieldMap.putIfAbsent(fieldName, fieldValue); + outputFieldMap.put(fieldName, fieldValue); } } } @@ -68,20 +71,22 @@ public void addNewOutputFields(List output) { /** * Method to get the output schema of the directive. Returns a generated schema based on maintained map of fields * only if directive does not return a custom output schema. 
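As an illustration of the hook this refactor introduces, a hypothetical rename directive could publish its output schema directly instead of relying on value-based inference; the RenameColumn semantics and the oldName/newName arguments are invented for this sketch:

// Sketch: copy every input field, renaming one column while preserving its
// original schema so design time does not have to re-infer the type.
@Override
@Nullable
public Schema getOutputSchema(SchemaResolutionContext context) {
  Schema inputSchema = context.getInputSchema();
  List<Schema.Field> fields = new ArrayList<>();
  for (Schema.Field field : inputSchema.getFields()) {
    // 'oldName' and 'newName' are hypothetical directive arguments.
    String name = field.getName().equals(oldName) ? newName : field.getName();
    fields.add(Schema.Field.of(name, field.getSchema()));
  }
  return Schema.recordOf("output", fields);
}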
- * @param inputSchema input {@link Schema} of the data before applying the directive + * @param context {@link SchemaResolutionContext} carrying the input {@link Schema} of the data before applying the directive * @return {@link Schema} corresponding to the output data */ - public Schema getDirectiveOutputSchema(Schema inputSchema) throws RecordConvertorException { - Schema directiveOutputSchema = directive.getOutputSchema(inputSchema); - return directiveOutputSchema != null ? directiveOutputSchema : generateDirectiveOutputSchema(inputSchema); + public Schema getDirectiveOutputSchema(SchemaResolutionContext context) throws RecordConvertorException { + Schema directiveOutputSchema = directive.getOutputSchema(context); + return directiveOutputSchema != null ? directiveOutputSchema : + generateDirectiveOutputSchema(context.getInputSchema()); } // Given the schema from previous step and output of current directive, generates the directive output schema. private Schema generateDirectiveOutputSchema(Schema inputSchema) throws RecordConvertorException { - List outputFields = new LinkedList<>(); - for (String fieldName : outputFieldMap.keySet()) { - Object fieldValue = outputFieldMap.get(fieldName); + List outputFields = new ArrayList<>(); + for (Map.Entry field : outputFieldMap.entrySet()) { + String fieldName = field.getKey(); + Object fieldValue = field.getValue(); Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null; Schema generated = fieldValue != null && !isValidSchemaForValue(existing, fieldValue) ? @@ -90,6 +95,9 @@ private Schema generateDirectiveOutputSchema(Schema inputSchema) if (generated != null) { outputFields.add(Schema.Field.of(fieldName, generated)); } else if (existing != null) { + if (!existing.isNullable()) { + existing = Schema.nullableOf(existing); + } outputFields.add(Schema.Field.of(fieldName, existing)); } else { outputFields.add(Schema.Field.of(fieldName, Schema.of(Schema.Type.NULL))); diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java new file mode 100644 index 000000000..9c4c702fb --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java @@ -0,0 +1,36 @@ +/* + * Copyright © 2023 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License.
+ */ + +package io.cdap.wrangler.schema; + +import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.wrangler.api.Directive; +import io.cdap.wrangler.api.SchemaResolutionContext; + +/** + * Context to pass information related to getting or generating the output schema of a {@link Directive} + */ +public class DirectiveSchemaResolutionContext implements SchemaResolutionContext { + private final Schema inputSchema; + public DirectiveSchemaResolutionContext(Schema inputSchema) { + this.inputSchema = inputSchema; + } + + @Override + public Schema getInputSchema() { + return inputSchema; + } +} diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/utils/TransientStoreKeys.java b/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java similarity index 96% rename from wrangler-core/src/main/java/io/cdap/wrangler/utils/TransientStoreKeys.java rename to wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java index d393e6656..e35ef803f 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/utils/TransientStoreKeys.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java @@ -14,7 +14,7 @@ * the License. */ -package io.cdap.wrangler.utils; +package io.cdap.wrangler.schema; /** * TransientStoreKeys for storing Workspace schema in TransientStore diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java index 84b96ba6e..8b858d50c 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java @@ -24,7 +24,7 @@ import io.cdap.wrangler.api.RecipePipeline; import io.cdap.wrangler.api.Row; import io.cdap.wrangler.api.TransientVariableScope; -import io.cdap.wrangler.utils.TransientStoreKeys; +import io.cdap.wrangler.schema.TransientStoreKeys; import org.junit.Assert; import org.junit.Test; @@ -120,7 +120,7 @@ public void testOutputSchemaGeneration() throws Exception { ); Schema expectedSchema = Schema.recordOf( "expected", - Schema.Field.of("decimal_col", Schema.decimalOf(10, 2)), + Schema.Field.of("decimal_col", Schema.nullableOf(Schema.decimalOf(10, 2))), Schema.Field.of("name", Schema.nullableOf(Schema.of(Schema.Type.STRING))), Schema.Field.of("timestamp", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))), Schema.Field.of("weight", Schema.nullableOf(Schema.of(Schema.Type.STRING))), @@ -152,8 +152,8 @@ public void testOutputSchemaGeneration_doesNotDropNullColumn() throws Exception String[] commands = new String[]{"set-type :id int"}; Schema expectedSchema = Schema.recordOf( "expected", - Schema.Field.of("id", Schema.of(Schema.Type.INT)), - Schema.Field.of("null_col", Schema.of(Schema.Type.STRING)) + Schema.Field.of("id", Schema.nullableOf(Schema.of(Schema.Type.INT))), + Schema.Field.of("null_col", Schema.nullableOf(Schema.of(Schema.Type.STRING))) ); Row row = new Row(); row.add("id", "123"); diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java index 55730e1ca..9080fbed5 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java @@ -49,10 +49,10 @@ import 
io.cdap.wrangler.registry.DirectiveRegistry; import io.cdap.wrangler.registry.SystemDirectiveRegistry; import io.cdap.wrangler.registry.UserDirectiveRegistry; +import io.cdap.wrangler.schema.TransientStoreKeys; import io.cdap.wrangler.service.common.AbstractWranglerHandler; import io.cdap.wrangler.statistics.BasicStatistics; import io.cdap.wrangler.statistics.Statistics; -import io.cdap.wrangler.utils.TransientStoreKeys; import io.cdap.wrangler.validator.ColumnNameValidator; import io.cdap.wrangler.validator.Validator; import io.cdap.wrangler.validator.ValidatorException; @@ -64,7 +64,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -265,21 +264,6 @@ protected WorkspaceValidationResult getWorkspaceSummary(List rows) throws E return new WorkspaceValidationResult(columnValidationResults, statistics); } - /** - * Method to get the list of columns across all the given rows - * @param rows list of rows - * @return list of columns (union across columns in all rows) - */ - public static List getAllColumns(List rows) { - Set columns = new LinkedHashSet<>(); - for (Row row : rows) { - for (int i = 0; i < row.width(); i++) { - columns.add(row.getColumn(i)); - } - } - return new ArrayList<>(columns); - } - /** * Creates a uber record after iterating through all rows. * @@ -301,5 +285,4 @@ public static Row createUberRecord(List rows) { } return uber; } - } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java index afee665b2..f450ecd42 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java @@ -74,11 +74,11 @@ import io.cdap.wrangler.proto.workspace.v2.WorkspaceUpdateRequest; import io.cdap.wrangler.registry.DirectiveInfo; import io.cdap.wrangler.registry.SystemDirectiveRegistry; +import io.cdap.wrangler.schema.TransientStoreKeys; import io.cdap.wrangler.store.recipe.RecipeStore; import io.cdap.wrangler.store.workspace.WorkspaceStore; import io.cdap.wrangler.utils.ObjectSerDe; import io.cdap.wrangler.utils.StructuredToRowTransformer; -import io.cdap.wrangler.utils.TransientStoreKeys; import org.apache.commons.lang3.StringEscapeUtils; import java.net.HttpURLConnection;
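Taken together, a rough design-time flow looks like the following sketch, modeled on RecipePipelineExecutorTest; the RecipePipelineExecutor constructor arguments and the prebuilt recipeParser are assumed from the fields shown above, and the schema and row are invented:

// 1. Seed the input schema before execution (service/testing environments only).
Schema inputSchema = Schema.recordOf("input",
    Schema.Field.of("body", Schema.of(Schema.Type.STRING)));
ExecutorContext context = new TestingPipelineContext();
context.getTransientStore().set(
    TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema);

// 2. Execute the recipe. Each DirectiveOutputSchemaGenerator records the rows
//    its directive emits and falls back to value-based inference only when
//    Directive#getOutputSchema returns null.
RecipePipeline pipeline = new RecipePipelineExecutor(recipeParser, context); // recipeParser: assumed prebuilt
pipeline.execute(Collections.singletonList(new Row("body", "Larry,1481666448,01/01/2000")));

// 3. Read the generated output schema published back to the transient store.
Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA);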