diff --git a/rdt_visitor.go b/rdt_visitor.go index c4d44d6..a707baa 100644 --- a/rdt_visitor.go +++ b/rdt_visitor.go @@ -44,15 +44,16 @@ func (visitor *RdtVisitor) Visit(tree antlr.ParseTree, target *UnknownShape) (Sh return nil, fmt.Errorf("unknown node type %T", tree) } -func (visitor *RdtVisitor) VisitChildren(node antlr.RuleNode, target *UnknownShape) ([]*BaseShape, error) { +func (visitor *RdtVisitor) VisitUnionMembers(node antlr.RuleNode, target *UnknownShape) ([]*BaseShape, error) { var shapes []*BaseShape - for _, n := range node.GetChildren() { - // Skip terminal nodes - if _, ok := n.(*antlr.TerminalNodeImpl); ok { - continue - } + children := node.GetChildren() + // Each union member is paired with a pipe separator so we increment by 2. + // type1 | type2 | type3 + // ^ ^ ^ ^ ^ + // 0 1 2 3 4 + for i := 0; i < len(children); i += 2 { baseResolved, implicitAnonShape, _ := visitor.raml.MakeNewShape("", "", target.Location, &target.Position) - s, err := visitor.Visit(n.(antlr.ParseTree), implicitAnonShape.(*UnknownShape)) + s, err := visitor.Visit(children[i].(antlr.ParseTree), implicitAnonShape.(*UnknownShape)) if err != nil { return nil, fmt.Errorf("visit children: %w", err) } @@ -130,7 +131,7 @@ func (visitor *RdtVisitor) VisitArray(ctx *rdt.ArrayContext, target *UnknownShap } func (visitor *RdtVisitor) VisitUnion(ctx *rdt.UnionContext, target *UnknownShape) (Shape, error) { - ss, err := visitor.VisitChildren(ctx, target) + ss, err := visitor.VisitUnionMembers(ctx, target) if err != nil { return nil, fmt.Errorf("visit children: %w", err) }