diff --git a/cmd/avrogo/generate.go b/cmd/avrogo/generate.go index dc0aaf0..17b79df 100644 --- a/cmd/avrogo/generate.go +++ b/cmd/avrogo/generate.go @@ -23,6 +23,25 @@ const ( const nullType = "avrotypegen.Null" +// shouldImportAvroTypeGen return true if avrotypegen is required. It checks that the definitions given are of type +// schema.RecordDefinition by looking at their match within given parsed namespace +func shouldImportAvroTypeGen(namespace *parser.Namespace, definitions []schema.QualifiedName) bool { + for _, def := range namespace.Definitions { + defToGenerateIdx := sort.Search(len(definitions), func(i int) bool { + return definitions[i].Name == def.AvroName().Name + }) + if defToGenerateIdx < len(definitions) && def.AvroName().Name == definitions[defToGenerateIdx].Name { + if _, ok := def.(*schema.RecordDefinition); ok { + return true + } + if _, ok := def.(*schema.FixedDefinition); ok { + return true + } + } + } + return false +} + func generate(w io.Writer, pkg string, ns *parser.Namespace, definitions []schema.QualifiedName) error { extTypes, err := externalTypeMap(ns) if err != nil { @@ -43,11 +62,8 @@ func generate(w io.Writer, pkg string, ns *parser.Namespace, definitions []schem extTypes: extTypes, } // Add avrotypegen package conditionally when there is a RecordDefinition in the namespace. - for _, def := range ns.Definitions { - if _, ok := def.(*schema.RecordDefinition); ok { - gc.addImport("github.com/heetch/avro/avrotypegen") - break - } + if shouldImportAvroTypeGen(ns, definitions) { + gc.addImport("github.com/heetch/avro/avrotypegen") } var body bytes.Buffer if err := bodyTemplate.Execute(&body, bodyTemplateParams{ diff --git a/cmd/avrogo/generate_test.go b/cmd/avrogo/generate_test.go new file mode 100644 index 0000000..ccf0d7f --- /dev/null +++ b/cmd/avrogo/generate_test.go @@ -0,0 +1,77 @@ +package main + +import ( + "github.com/actgardner/gogen-avro/v10/parser" + "github.com/actgardner/gogen-avro/v10/schema" + "testing" + + avro "github.com/actgardner/gogen-avro/v10/schema" + qt "github.com/frankban/quicktest" +) + +func TestShouldImportAvroTypeGen(t *testing.T) { + var eventNameQualifiedName = schema.QualifiedName{Namespace: "EventName", Name: "EventName"} + var eventNameAsRecordDefinition = schema.NewRecordDefinition(eventNameQualifiedName, []avro.QualifiedName{}, []*avro.Field{}, "", map[string]interface{}{}) + var eventNameAsFixedFieldDefinition = schema.NewFixedDefinition(eventNameQualifiedName, []avro.QualifiedName{}, 142, map[string]interface{}{}) + var modelDefinitionQualifiedName = schema.QualifiedName{Namespace: "ModelDefinition", Name: "ModelDefinition"} + var modelAsEnumDefinition = schema.NewEnumDefinition(modelDefinitionQualifiedName, []avro.QualifiedName{}, []string{"", ""}, "", "defaultValue", map[string]interface{}{}) + + var shouldImportAvroTypeGenTests = []struct { + testName string + namespace *parser.Namespace + definitions []schema.QualifiedName + shouldImportAvroTypeGen bool + }{ + { + testName: "true-definition-only-present-in-namespace-and-is-record-type", + namespace: &parser.Namespace{ + Definitions: map[schema.QualifiedName]schema.Definition{ + eventNameQualifiedName: eventNameAsRecordDefinition, + }, + }, + definitions: []schema.QualifiedName{eventNameQualifiedName}, + shouldImportAvroTypeGen: true, + }, + { + testName: "true-definition-only-present-in-namespace-and-is-fixed-type", + namespace: &parser.Namespace{ + Definitions: map[schema.QualifiedName]schema.Definition{ + eventNameQualifiedName: eventNameAsFixedFieldDefinition, + }, + }, + definitions: []schema.QualifiedName{eventNameQualifiedName}, + shouldImportAvroTypeGen: true, + }, + { + testName: "true-definition-present-in-namespace-and-is-record-type", + namespace: &parser.Namespace{ + Definitions: map[schema.QualifiedName]schema.Definition{ + eventNameQualifiedName: eventNameAsRecordDefinition, + modelDefinitionQualifiedName: modelAsEnumDefinition, + }, + }, + definitions: []schema.QualifiedName{eventNameQualifiedName}, + shouldImportAvroTypeGen: true, + }, + { + testName: "false-definition-present-in-namespace-and-not-record-type", + namespace: &parser.Namespace{ + Definitions: map[schema.QualifiedName]schema.Definition{ + eventNameQualifiedName: eventNameAsRecordDefinition, + modelDefinitionQualifiedName: modelAsEnumDefinition, + }, + }, + definitions: []schema.QualifiedName{modelDefinitionQualifiedName}, + shouldImportAvroTypeGen: false, + }, + } + + c := qt.New(t) + + for _, test := range shouldImportAvroTypeGenTests { + c.Run(test.testName, func(c *qt.C) { + value := shouldImportAvroTypeGen(test.namespace, test.definitions) + c.Assert(value, qt.Equals, test.shouldImportAvroTypeGen) + }) + } +}