diff --git a/go.mod b/go.mod index 0bb17b7..0a8060f 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/stretchr/testify v1.8.4 golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df golang.org/x/text v0.11.0 + github.com/dave/jennifer v1.7.0 ) require ( diff --git a/go.sum b/go.sum index ca63669..e0a7889 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/dave/jennifer v1.7.0 h1:uRbSBH9UTS64yXbh4FrMHfgfY762RD+C7bUPKODpSJE= +github.com/dave/jennifer v1.7.0/go.mod h1:nXbxhEmQfOZhWml3D1cDK5M1FLnMSozpbFN/m3RmGZc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/internal/codegen/enum.go b/internal/codegen/enum.go index e7fc4ef..8a4e00c 100644 --- a/internal/codegen/enum.go +++ b/internal/codegen/enum.go @@ -3,8 +3,9 @@ package codegen import ( "fmt" "path" - "strings" + "github.com/dave/jennifer/jen" + "github.com/ethanmoffat/eolib-go/internal/codegen/types" "github.com/ethanmoffat/eolib-go/internal/xml" ) @@ -16,49 +17,58 @@ func GenerateEnums(outputDir string, enums []xml.ProtocolEnum) error { return err } - output := strings.Builder{} - output.WriteString(packageName + "\n\n") - output.WriteString("import \"fmt\"\n\n") + f := jen.NewFile(packageName) for _, e := range enums { - writeTypeComment(&output, e.Name, e.Comment) + writeTypeCommentJen(f, e.Name, e.Comment) - output.WriteString(fmt.Sprintf("type %s int\n\n", e.Name)) - output.WriteString("const (\n") + f.Type().Id(e.Name).Int() + defsList := make([]jen.Code, len(e.Values)) expected := 0 for i, v := range e.Values { + var s *jen.Statement if i == 0 { - output.WriteString(fmt.Sprintf("\t%s_%s %s = iota", sanitizeTypeName(e.Name), v.Name, e.Name)) + s = jen.Id(fmt.Sprintf("%s_%s", types.SanitizeTypeName(e.Name), v.Name)).Qual("", e.Name).Op("=").Iota() + if v.Value > 0 { - output.WriteString(fmt.Sprintf(" + %d", v.Value)) expected = int(v.Value) + s.Op("+").Lit(expected) } } else { - output.WriteString(fmt.Sprintf("\t%s_%s", sanitizeTypeName(e.Name), v.Name)) - if expected != int(v.Value) { - output.WriteString(fmt.Sprintf(" = %d", v.Value)) + s = jen.Id(fmt.Sprintf("%s_%s", types.SanitizeTypeName(e.Name), v.Name)) + actual := int(v.Value) + if expected != actual { + s.Op("=").Lit(actual) } } - writeInlineComment(&output, v.Comment) - - output.WriteString("\n") + writeInlineCommentJen(s, v.Comment) expected += 1 - } - output.WriteString(")\n\n") + defsList[i] = s + } + f.Const().Defs(defsList...) - output.WriteString(fmt.Sprintf("// String converts a %s value into its string representation\n", e.Name)) - output.WriteString(fmt.Sprintf("func (e %s) String() (string, error) {\n", e.Name)) - output.WriteString("\tswitch e {\n") - for _, v := range e.Values { - output.WriteString(fmt.Sprintf("\tcase %s_%s:\n\t\treturn \"%s\", nil\n", sanitizeTypeName(e.Name), v.Name, v.Name)) + caseList := make([]jen.Code, len(e.Values)+1) + for ndx, v := range e.Values { + caseList[ndx] = jen.Case(jen.Id(fmt.Sprintf("%s_%s", types.SanitizeTypeName(e.Name), v.Name))).Block(jen.Return(jen.Lit(v.Name), jen.Nil())) } - output.WriteString(fmt.Sprintf("\tdefault:\n\t\treturn \"\", fmt.Errorf(\"could not convert value %%d of type %s to string\", e)\n", e.Name)) - output.WriteString("\t}\n}\n\n") + caseList[len(e.Values)] = jen.Default().Block().Return( + jen.Lit(""), jen.Qual("fmt", "Errorf").Call(jen.Lit(fmt.Sprintf("could not convert value %%d of type %s to string", e.Name)), jen.Id("e")), + ) + + f.Commentf("String converts a %s value into its string representation", e.Name) + f.Func().Params( + // enum receiver + jen.Id("e").Id(e.Name), + ).Id("String").Params( + // empty parameter list + ).Params(jen.String(), jen.Error()).Block( + jen.Switch(jen.Id("e").Block(caseList...)), + ) } outFileName := path.Join(outputDir, enumFileName) - return writeToFile(outFileName, output.String()) + return writeToFileJen(f, outFileName) } diff --git a/internal/codegen/packet.go b/internal/codegen/packet.go index 326266b..e88199f 100644 --- a/internal/codegen/packet.go +++ b/internal/codegen/packet.go @@ -3,41 +3,59 @@ package codegen import ( "fmt" "path" - "strings" + "github.com/dave/jennifer/jen" + "github.com/ethanmoffat/eolib-go/internal/codegen/types" "github.com/ethanmoffat/eolib-go/internal/xml" ) func GeneratePackets(outputDir string, packets []xml.ProtocolPacket, fullSpec xml.Protocol) error { + if len(packets) == 0 { + return nil + } + packageName, err := getPackageName(outputDir) if err != nil { return err } - output := strings.Builder{} - output.WriteString(packageName + "\n\n") - output.WriteString("import (\n\t\"fmt\"\n\t\"reflect\"\n\t\"github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/net\"\n)\n\n") - output.WriteString("var packetMap = map[int]reflect.Type{\n") + f := jen.NewFile(packageName) + types.AddImports(f) // collect type names to generate packet structs var typeNames []string - for _, p := range packets { - typeNames = append(typeNames, p.GetTypeName()) + f.Var().Id("packetMap").Op("=").Map(jen.Int()).Qual("reflect", "Type").BlockFunc(func(g *jen.Group) { + // Note that this block is using "BlockFunc" + // Official docs advices to use "Values" with "DictFunc". However, default sorting is alphabetical, which + // creates a nasty git diff of the existing generated code + for _, p := range packets { + typeNames = append(typeNames, p.GetTypeName()) - output.WriteString(fmt.Sprintf("\tnet.PacketId(net.PacketFamily_%s, net.PacketAction_%s): ", p.Family, p.Action)) - output.WriteString(fmt.Sprintf("reflect.TypeOf(%s{}),\n", snakeCaseToCamelCase(p.GetTypeName()))) - } + g.Qual(types.PackagePath("net"), "PacketId").Call( + jen.Qual(types.PackagePath("net"), fmt.Sprintf("PacketFamily_%s", p.Family)), + jen.Qual(types.PackagePath("net"), fmt.Sprintf("PacketAction_%s", p.Action)), + ).Op(":").Qual("reflect", "TypeOf").Call( + jen.Id(snakeCaseToCamelCase(p.GetTypeName())).Values(), + ).Op(",") + } + }) - output.WriteString("}\n") + f.Comment("PacketFromId creates a typed packet instance from a [net.PacketFamily] and [net.PacketAction].") + f.Comment("This function calls [PacketFromIntegerId] internally.") - output.WriteString(` -// PacketFromId creates a typed packet instance from a [net.PacketFamily] and [net.PacketAction]. -// This function calls [PacketFromIntegerId] internally. -func PacketFromId(family net.PacketFamily, action net.PacketAction) (net.Packet, error) { - return PacketFromIntegerId(net.PacketId(family, action)) -} + f.Func().Id("PacketFromId").Params( + jen.Id("family").Qual(types.PackagePath("net"), "PacketFamily"), + jen.Id("action").Qual(types.PackagePath("net"), "PacketAction"), + ).Params( + jen.Qual(types.PackagePath("net"), "Packet"), + jen.Error(), + ).Block( + jen.Return(jen.Id("PacketFromIntegerId").Call( + jen.Qual(types.PackagePath("net"), "PacketId").Call(jen.Id("family"), jen.Id("action")), + )), + ) -// PacketFromIntegerId creates a typed packet instance from a packet's ID. An ID may be converted from a family/action pair via the [net.PacketId] function. + f.Comment(`// PacketFromIntegerId creates a typed packet instance from a packet's ID. An ID may be converted from a family/action pair via the [net.PacketId] function. // The returned packet implements the [net.Packet] interface. It may be serialized/deserialized without further conversion, or a type assertion may be made to examine the data. The expected type of the assertion is a pointer to a packet structure. // The following example does both: an incoming CHAIR_REQUEST packet is deserialized from a reader without converting from the interface type, and the data is examined via a type assertion. // @@ -54,25 +72,37 @@ func PacketFromId(family net.PacketFamily, action net.PacketAction) (net.Packet, // } // default: // fmt.Printf("Unknown type: %s\n", reflect.TypeOf(pkt).Elem().Name()) -// } -func PacketFromIntegerId(id int) (net.Packet, error) { - packetType, idOk := packetMap[id] - if !idOk { - return nil, fmt.Errorf("could not find packet with id %d", id) - } +// }`) - packetInstance, typeOk := reflect.New(packetType).Interface().(net.Packet) - if !typeOk { - return nil, fmt.Errorf("could not create packet from id %d", id) - } + f.Func().Id("PacketFromIntegerId").Params( + jen.Id("id").Int(), // func declaration: int parameter 'id' + ).Params( + jen.Qual(types.PackagePath("net"), "Packet"), // func declaration: return types (net.Packet, error) + jen.Error(), + ).Block( + // try to get the packet type out of the map (indexed by the id) + jen.List(jen.Id("packetType"), jen.Id("idOk")).Op(":=").Id("packetMap").Index(jen.Id("id")), + // check that id is ok, return error otherwise + jen.If(jen.Op("!").Id("idOk")).Block( + jen.Return(jen.List(jen.Nil(), jen.Qual("fmt", "Errorf").Call(jen.Lit("could not find packet with id %d"), jen.Id("id")))), + ).Line(), + // type assert that creating the packet type results in an interface that satisfies net.Packet + jen.List(jen.Id("packetInstance"), jen.Id("typeOk").Op(":=").Qual("reflect", "New").Call( + jen.Id("packetType"), + ).Dot("Interface").Call().Assert( + jen.Qual(types.PackagePath("net"), "Packet"), + )), + // check that type is ok, return error otherwise + jen.If(jen.Op("!").Id("typeOk")).Block( + jen.Return(jen.List(jen.Nil(), jen.Qual("fmt", "Errorf").Call(jen.Lit("could not create packet from id %d"), jen.Id("id")))), + ).Line(), + // return packetInstance, nil + jen.Return(jen.Id("packetInstance"), jen.Nil()), + ) - return packetInstance, nil -} -`) - - if len(packets) > 0 { - const packetMapFileName = "packetmap_generated.go" - writeToFile(path.Join(outputDir, packetMapFileName), output.String()) + const packetMapFileName = "packetmap_generated.go" + if err := writeToFileJen(f, path.Join(outputDir, packetMapFileName)); err != nil { + return err } const packetFileName = "packets_generated.go" diff --git a/internal/codegen/shared.go b/internal/codegen/shared.go index a7503c7..b916cfa 100644 --- a/internal/codegen/shared.go +++ b/internal/codegen/shared.go @@ -5,14 +5,13 @@ import ( "io" "os" "path" - "strconv" "strings" "unicode" - "github.com/ethanmoffat/eolib-go/internal/xml" + "github.com/dave/jennifer/jen" ) -func getPackageName(outputDir string) (string, error) { +func getPackageStatement(outputDir string) (string, error) { packageFileName := path.Join(outputDir, "package.go") fp, err := os.Open(packageFileName) if err != nil { @@ -35,6 +34,22 @@ func getPackageName(outputDir string) (string, error) { return "", fmt.Errorf("package name not found in %s", outputDir) } +func getPackageName(outputDir string) (packageDeclaration string, err error) { + if packageDeclaration, err = getPackageStatement(outputDir); err != nil { + return + } + + split := strings.Split(packageDeclaration, " ") + if len(split) < 2 { + packageDeclaration = "" + err = fmt.Errorf("unable to determine package name from package declaration") + return + } + + packageDeclaration = split[1] + return +} + func sanitizeComment(comment string) string { split := strings.Split(comment, "\n") @@ -49,92 +64,25 @@ func sanitizeComment(comment string) string { return strings.Join(split, " ") } -func sanitizeTypeName(typeName string) string { - if strings.HasSuffix(typeName, "Type") { - return typeName[:len(typeName)-4] - } - return typeName -} - -func writeTypeComment(output *strings.Builder, typeName string, comment string) { +func writeTypeCommentJen(f *jen.File, typeName string, comment string) { if comment = sanitizeComment(comment); len(comment) > 0 { - output.WriteString(fmt.Sprintf("// %s :: %s\n", typeName, comment)) + f.Commentf("// %s :: %s", typeName, comment) } } -func writeInlineComment(output *strings.Builder, comment string) { +func writeInlineCommentJen(c jen.Code, comment string) { if comment = sanitizeComment(comment); len(comment) > 0 { - output.WriteString(fmt.Sprintf(" // %s", comment)) - } -} - -func writeToFile(outFileName string, outputText string) error { - ofp, err := os.Create(outFileName) - if err != nil { - return err - } - defer ofp.Close() - - n, err := ofp.Write([]byte(outputText)) - if err != nil { - return err - } - if n != len(outputText) { - return fmt.Errorf("wrote %d of %d bytes to file %s", n, len(outputText), outFileName) + switch v := c.(type) { + case *jen.Statement: + v.Comment(comment) + case *jen.Group: + v.Comment(comment) + } } - - return nil -} - -type importInfo struct { - Package string - Path string } -func eoTypeToGoType(eoType string, currentPackage string, fullSpec xml.Protocol) (goType string, nextImport *importInfo) { - if strings.ContainsRune(eoType, rune(':')) { - eoType = strings.Split(eoType, ":")[0] - } - - switch eoType { - case "byte": - fallthrough - case "char": - fallthrough - case "short": - fallthrough - case "three": - fallthrough - case "int": - return "int", nil - case "bool": - return "bool", nil - case "blob": - return "[]byte", nil - case "string": - fallthrough - case "encoded_string": - return "string", nil - default: - match := fullSpec.FindType(eoType) - - if structMatch, ok := match.(*xml.ProtocolStruct); ok { - if structMatch.Package != currentPackage { - goType = structMatch.Package + "." + eoType - nextImport = &importInfo{structMatch.Package, structMatch.PackagePath} - } else { - goType = eoType - } - } else if enumMatch, ok := match.(*xml.ProtocolEnum); ok { - if enumMatch.Package != currentPackage { - goType = enumMatch.Package + "." + eoType - nextImport = &importInfo{enumMatch.Package, enumMatch.PackagePath} - } else { - goType = eoType - } - } - return - } +func writeToFileJen(f *jen.File, outFileName string) error { + return f.Save(outFileName) } func snakeCaseToCamelCase(input string) string { @@ -174,98 +122,3 @@ func snakeCaseToPascalCase(input string) string { firstRune := []rune(camelCase)[0] return string(unicode.ToUpper(firstRune)) + camelCase[1:] } - -func getInstructionTypeName(inst xml.ProtocolInstruction) (typeName string, typeSize string) { - if inst.Type == nil { - return - } - - if strings.ContainsRune(*inst.Type, rune(':')) { - split := strings.Split(*inst.Type, ":") - typeName, typeSize = split[0], split[1] - } else { - typeName = *inst.Type - } - - return -} - -func calculateTypeSize(typeName string, fullSpec xml.Protocol) (res int, err error) { - var structInfo *xml.ProtocolStruct - var isStruct bool - if structInfo, isStruct = fullSpec.IsStruct(typeName); !isStruct { - return getPrimitizeTypeSize(typeName, fullSpec) - } - - var flattenedInstList []xml.ProtocolInstruction - for _, instruction := range (*structInfo).Instructions { - if instruction.XMLName.Local == "chunked" { - flattenedInstList = append(flattenedInstList, instruction.Chunked...) - } else { - flattenedInstList = append(flattenedInstList, instruction) - } - } - - for _, instruction := range flattenedInstList { - switch instruction.XMLName.Local { - case "field": - fieldTypeName, fieldTypeSize := getInstructionTypeName(instruction) - if fieldTypeSize != "" { - fieldTypeName = fieldTypeSize - } - - if instruction.Length != nil { - if length, err := strconv.ParseInt(*instruction.Length, 10, 32); err == nil { - // length is a numeric constant - res += int(length) - } else { - return 0, fmt.Errorf("instruction length %s must be a fixed size for %s (%s)", *instruction.Length, *instruction.Name, instruction.XMLName.Local) - } - } else { - if nestedSize, err := getPrimitizeTypeSize(fieldTypeName, fullSpec); err != nil { - return 0, err - } else { - res += nestedSize - } - } - case "break": - res += 1 - case "array": - case "dummy": - } - } - - return -} - -func getPrimitizeTypeSize(fieldTypeName string, fullSpec xml.Protocol) (int, error) { - switch fieldTypeName { - case "byte": - fallthrough - case "char": - return 1, nil - case "short": - return 2, nil - case "three": - return 3, nil - case "int": - return 4, nil - case "bool": - return 1, nil - case "blob": - fallthrough - case "string": - fallthrough - case "encoded_string": - return 0, fmt.Errorf("cannot get size of %s without fixed length", fieldTypeName) - default: - if _, isStruct := fullSpec.IsStruct(fieldTypeName); isStruct { - return calculateTypeSize(fieldTypeName, fullSpec) - } else if e, isEnum := fullSpec.IsEnum(fieldTypeName); isEnum { - enumTypeName := sanitizeTypeName(e.Type) - return getPrimitizeTypeSize(enumTypeName, fullSpec) - } else { - return 0, fmt.Errorf("cannot get fixed size of unrecognized type %s", fieldTypeName) - } - } -} diff --git a/internal/codegen/struct.go b/internal/codegen/struct.go index 1ca70f2..72021a7 100644 --- a/internal/codegen/struct.go +++ b/internal/codegen/struct.go @@ -4,9 +4,10 @@ import ( "fmt" "path" "strconv" - "strings" "unicode" + "github.com/dave/jennifer/jen" + "github.com/ethanmoffat/eolib-go/internal/codegen/types" "github.com/ethanmoffat/eolib-go/internal/xml" ) @@ -21,120 +22,95 @@ func GenerateStructs(outputDir string, structs []xml.ProtocolStruct, fullSpec xm } func generateStructsShared(outputDir string, outputFileName string, typeNames []string, fullSpec xml.Protocol) error { - packageDeclaration, err := getPackageName(outputDir) + packageName, err := getPackageName(outputDir) if err != nil { return err } - output := strings.Builder{} - output.WriteString(packageDeclaration + "\n\n") - var outputText string - if len(typeNames) > 0 { - output.WriteString("import (\n\t\"fmt\"\n\t\"github.com/ethanmoffat/eolib-go/pkg/eolib/data\"\n)\n\n// Ensure fmt import is referenced in generated code\nvar _ = fmt.Printf\n\n") + f := jen.NewFile(packageName) + types.AddImports(f) - var imports []importInfo + if len(typeNames) > 0 { for _, typeName := range typeNames { - if nextImports, err := writeStruct(&output, typeName, fullSpec); err != nil { + if err := writeStruct(f, typeName, fullSpec); err != nil { return err - } else { - imports = append(imports, nextImports...) - } - } - - outputText = output.String() - - var matches map[string]bool = make(map[string]bool) - var importText string - for _, imp := range imports { - if _, ok := matches[imp.Package]; !ok && strings.Split(packageDeclaration, " ")[1] != imp.Package { - importText = importText + fmt.Sprintf("\t%s \"github.com/ethanmoffat/eolib-go/pkg/eolib/protocol%s\"\n", imp.Package, imp.Path) - matches[imp.Package] = true } } - outputText = strings.ReplaceAll(outputText, "", importText) - } else { - outputText = output.String() } outFileName := path.Join(outputDir, outputFileName) - return writeToFile(outFileName, outputText) + return writeToFileJen(f, outFileName) } -func writeStruct(output *strings.Builder, typeName string, fullSpec xml.Protocol) (importPaths []importInfo, err error) { - var name string - var comment string - var instructions []xml.ProtocolInstruction - var packageName string - - var family string - var action string - - switchStructQualifier := "" - if structInfo, ok := fullSpec.IsStruct(typeName); ok { - name = structInfo.Name - comment = structInfo.Comment - instructions = structInfo.Instructions - packageName = structInfo.Package - } else if packetInfo, ok := fullSpec.IsPacket(typeName); ok { - name = packetInfo.GetTypeName() - comment = packetInfo.Comment - instructions = packetInfo.Instructions - packageName = packetInfo.Package - switchStructQualifier = packetInfo.Family + packetInfo.Action - family = packetInfo.Family - action = packetInfo.Action - } else { - return nil, fmt.Errorf("type %s is not a struct or packet in the spec", typeName) +func writeStruct(f *jen.File, typeName string, fullSpec xml.Protocol) (err error) { + var si *types.StructInfo + if si, err = types.GetStructInfo(typeName, fullSpec); err != nil { + return err } - structName := snakeCaseToPascalCase(name) - writeTypeComment(output, structName, comment) + err = writeStructShared(f, si, fullSpec) + return +} + +func writeStructShared(f *jen.File, si *types.StructInfo, fullSpec xml.Protocol) (err error) { + structName := snakeCaseToPascalCase(si.Name) + writeTypeCommentJen(f, structName, si.Comment) // write out fields - output.WriteString(fmt.Sprintf("type %s struct {\n", structName)) - switches, nextImports := writeStructFields(output, instructions, switchStructQualifier, packageName, fullSpec) - importPaths = append(importPaths, nextImports...) - output.WriteString("}\n\n") + var switches []*xml.ProtocolInstruction + f.Type().Id(structName).StructFunc(func(g *jen.Group) { + switches = writeStructFields(g, si, fullSpec) + }).Line() for _, sw := range switches { - if nextImports, err = writeSwitchStructs(output, *sw, switchStructQualifier, packageName, fullSpec); err != nil { - return nil, err + if err = writeSwitchStructs(f, *sw, si, fullSpec); err != nil { + return } - importPaths = append(importPaths, nextImports...) } - if len(family) > 0 && len(action) > 0 { + if len(si.Family) > 0 && len(si.Action) > 0 { // write out family/action methods - output.WriteString(fmt.Sprintf("func (s %s) Family() net.PacketFamily {\n\treturn net.PacketFamily_%s\n}\n\n", structName, family)) - output.WriteString(fmt.Sprintf("func (s %s) Action() net.PacketAction {\n\treturn net.PacketAction_%s\n}\n\n", structName, action)) + f.Func().Params(jen.Id("s").Id(structName)).Id("Family").Params().Qual(types.PackagePath("net"), "PacketFamily").Block( + jen.Return(jen.Qual(types.PackagePath("net"), fmt.Sprintf("PacketFamily_%s", si.Family))), + ).Line() + f.Func().Params(jen.Id("s").Id(structName)).Id("Action").Params().Qual(types.PackagePath("net"), "PacketAction").Block( + jen.Return(jen.Qual(types.PackagePath("net"), fmt.Sprintf("PacketAction_%s", si.Action))), + ).Line() } // write out serialize method - output.WriteString(fmt.Sprintf("func (s *%s) Serialize(writer *data.EoWriter) (err error) {\n", structName)) - output.WriteString("\toldSanitizeStrings := writer.SanitizeStrings\n") - output.WriteString("\tdefer func() {writer.SanitizeStrings = oldSanitizeStrings}()\n\n") - if nextImports, err = writeSerializeBody(output, instructions, switchStructQualifier, packageName, fullSpec); err != nil { - return nil, err + f.Func().Params(jen.Id("s").Op("*").Id(structName)).Id("Serialize").Params(jen.Id("writer").Op("*").Qual(types.PackagePath("data"), "EoWriter")).Params(jen.Id("err").Id("error")).BlockFunc(func(g *jen.Group) { + g.Id("oldSanitizeStrings").Op(":=").Id("writer").Dot("SanitizeStrings") + // defer here uses 'Values' instead of 'Block' so the deferred function is single-line style + g.Defer().Func().Params().Values(jen.Id("writer").Dot("SanitizeStrings").Op("=").Id("oldSanitizeStrings")).Call().Line() + + err = writeSerializeBody(g, si, fullSpec, nil) + + g.Return() + }).Line() + + if err != nil { + return } - importPaths = append(importPaths, nextImports...) - output.WriteString("\treturn\n") - output.WriteString("}\n\n") // write out deserialize method - output.WriteString(fmt.Sprintf("func (s *%s) Deserialize(reader *data.EoReader) (err error) {\n", structName)) - output.WriteString("\toldIsChunked := reader.IsChunked()\n") - output.WriteString("\tdefer func() { reader.SetIsChunked(oldIsChunked) }()\n\n") - if nextImports, err = writeDeserializeBody(output, instructions, switchStructQualifier, packageName, fullSpec); err != nil { - return nil, err - } - importPaths = append(importPaths, nextImports...) - output.WriteString("\n\treturn\n}\n\n") + f.Func().Params(jen.Id("s").Op("*").Id(structName)).Id("Deserialize").Params(jen.Id("reader").Op("*").Qual(types.PackagePath("data"), "EoReader")).Params(jen.Id("err").Id("error")).BlockFunc(func(g *jen.Group) { + g.Id("oldIsChunked").Op(":=").Id("reader").Dot("IsChunked").Call() + // defer here uses 'Values' instead of 'Block' so the deferred function is single-line style + g.Defer().Func().Params().Values(jen.Id("reader").Dot("SetIsChunked").Call(jen.Id("oldIsChunked"))).Call().Line() + + err = writeDeserializeBody(g, si, fullSpec, nil, false) + + g.Line().Return() + }).Line() return } -func writeStructFields(output *strings.Builder, instructions []xml.ProtocolInstruction, switchStructQualifier string, packageName string, fullSpec xml.Protocol) (switches []*xml.ProtocolInstruction, imports []importInfo) { - for i, inst := range instructions { +func writeStructFields(g *jen.Group, si *types.StructInfo, fullSpec xml.Protocol) (switches []*xml.ProtocolInstruction) { + isEmpty := true + + for i, inst := range si.Instructions { var instName string if inst.Name != nil { @@ -143,71 +119,91 @@ func writeStructFields(output *strings.Builder, instructions []xml.ProtocolInstr instName = snakeCaseToPascalCase(*inst.Field) } - var typeName string + var fieldTypeInfo struct { + typeName string + nextImport *types.ImportInfo + isPointer bool + } if inst.Type != nil { - var nextImport *importInfo - if typeName, nextImport = eoTypeToGoType(*inst.Type, packageName, fullSpec); nextImport != nil { - imports = append(imports, *nextImport) - } - + fieldTypeInfo.typeName, fieldTypeInfo.nextImport = types.ProtocolSpecTypeToGoType(*inst.Type, si.PackageName, fullSpec) if inst.Optional != nil && *inst.Optional { switch inst.XMLName.Local { - // these are the only supported values where the type needs to be modified to a pointer - // arrays also support the "optional" attribute in the spec but can be nil because they're defined as slices in the structs + // these are the only supported values where the type of the rendered field needs to be modified to a pointer + // arrays also support the "optional" attribute in the spec but default to nil since they are rendered as slices case "field": fallthrough case "length": - typeName = "*" + typeName + fieldTypeInfo.isPointer = true } } } + qualifiedTypeName := func(s *jen.Statement) { + if fieldTypeInfo.isPointer { + s.Op("*") + } + + writeComment := func(ss *jen.Statement) { + if inst.Comment != nil { + writeInlineCommentJen(ss, *inst.Comment) + } + } + + if fieldTypeInfo.nextImport != nil && fieldTypeInfo.nextImport.Package != si.PackageName { + s.Qual(fieldTypeInfo.nextImport.Path, fieldTypeInfo.typeName).Do(writeComment) + } else { + s.Id(fieldTypeInfo.typeName).Do(writeComment) + } + } + switch inst.XMLName.Local { case "field": if len(instName) > 0 { - output.WriteString(fmt.Sprintf("\t%s %s", instName, typeName)) + g.Id(instName).Do(qualifiedTypeName) + } else { + g.Line() } + isEmpty = false case "array": - output.WriteString(fmt.Sprintf("\t%s []%s", instName, typeName)) + g.Id(instName).Index().Do(qualifiedTypeName) + isEmpty = false case "length": - output.WriteString(fmt.Sprintf("\t%s %s", instName, typeName)) + g.Id(instName).Do(qualifiedTypeName) + isEmpty = false case "switch": - output.WriteString(fmt.Sprintf("\t%sData %s%sData", instName, switchStructQualifier, instName)) - switches = append(switches, &instructions[i]) + g.Id(fmt.Sprintf("%sData", instName)).Id(fmt.Sprintf("%s%sData", si.SwitchStructQualifier, instName)) + switches = append(switches, &si.Instructions[i]) + isEmpty = false case "chunked": - nextSwitches, nextImports := writeStructFields(output, inst.Chunked, switchStructQualifier, packageName, fullSpec) - switches = append(switches, nextSwitches...) - imports = append(imports, nextImports...) + nestedStructInfo, _ := si.Nested(&inst) + switches = append(switches, writeStructFields(g, nestedStructInfo, fullSpec)...) case "dummy": case "break": continue // no data to write } + } - if inst.Comment != nil { - writeInlineComment(output, *inst.Comment) - } - - output.WriteString("\n") + if isEmpty { + g.Line() } return } -func writeSwitchStructs(output *strings.Builder, switchInst xml.ProtocolInstruction, switchStructQualifier string, packageName string, fullSpec xml.Protocol) (imports []importInfo, err error) { +func writeSwitchStructs(f *jen.File, switchInst xml.ProtocolInstruction, si *types.StructInfo, fullSpec xml.Protocol) (err error) { if switchInst.XMLName.Local != "switch" { return } switchInterfaceName := fmt.Sprintf("%sData", snakeCaseToPascalCase(*switchInst.Field)) - if len(switchStructQualifier) > 0 { - switchInterfaceName = switchStructQualifier + switchInterfaceName + if len(si.SwitchStructQualifier) > 0 { + switchInterfaceName = si.SwitchStructQualifier + switchInterfaceName } if switchInst.Comment != nil { - writeTypeComment(output, switchInterfaceName, *switchInst.Comment) + writeTypeCommentJen(f, switchInterfaceName, *switchInst.Comment) } - - output.WriteString(fmt.Sprintf("type %s interface {\n\tprotocol.EoData\n}\n\n", switchInterfaceName)) + f.Type().Id(switchInterfaceName).Interface(jen.Qual(types.PackagePath("protocol"), "EoData")).Line() for _, c := range switchInst.Cases { if len(c.Instructions) == 0 { @@ -222,92 +218,56 @@ func writeSwitchStructs(output *strings.Builder, switchInst xml.ProtocolInstruct } caseStructName := fmt.Sprintf("%s%s", switchInterfaceName, caseName) - writeTypeComment(output, caseStructName, c.Comment) - - output.WriteString(fmt.Sprintf("type %s struct {\n", caseStructName)) - switches, nextImports := writeStructFields(output, c.Instructions, switchStructQualifier, packageName, fullSpec) - imports = append(imports, nextImports...) - output.WriteString("}\n\n") - - for _, sw := range switches { - if nextImports, err = writeSwitchStructs(output, *sw, switchStructQualifier, packageName, fullSpec); err != nil { - return nil, err - } - imports = append(imports, nextImports...) + nestedStructInfo := &types.StructInfo{ + Name: caseStructName, + Comment: c.Comment, + Instructions: c.Instructions, + PackageName: si.PackageName, + SwitchStructQualifier: si.SwitchStructQualifier, } - - // write out serialize method - output.WriteString(fmt.Sprintf("func (s *%s) Serialize(writer *data.EoWriter) (err error) {\n", caseStructName)) - output.WriteString("\toldSanitizeStrings := writer.SanitizeStrings\n") - output.WriteString("\tdefer func() {writer.SanitizeStrings = oldSanitizeStrings}()\n\n") - if nextImports, err = writeSerializeBody(output, c.Instructions, switchStructQualifier, packageName, fullSpec); err != nil { - return nil, err - } - imports = append(imports, nextImports...) - output.WriteString("\treturn\n") - output.WriteString("}\n\n") - - // write out deserialize method - output.WriteString(fmt.Sprintf("func (s *%s) Deserialize(reader *data.EoReader) (err error) {\n", caseStructName)) - output.WriteString("\toldIsChunked := reader.IsChunked()\n") - output.WriteString("\tdefer func() { reader.SetIsChunked(oldIsChunked) }()\n\n") - if nextImports, err = writeDeserializeBody(output, c.Instructions, switchStructQualifier, packageName, fullSpec); err != nil { - return nil, err + err = writeStructShared(f, nestedStructInfo, fullSpec) + if err != nil { + return } - imports = append(imports, nextImports...) - output.WriteString("\n\treturn\n}\n\n") } return } -// used to track the 'outer' list of instructions when a instruction is encountered -// this allows any nested instructions to search both the instructions in the section at the same level -// -// as well as the outer instructions in the or when determining the type of the switch field -var outerInstructionList []xml.ProtocolInstruction - -func writeSerializeBody(output *strings.Builder, instructionList []xml.ProtocolInstruction, switchStructQualifier string, packageName string, fullSpec xml.Protocol) (imports []importInfo, err error) { - for _, instruction := range instructionList { +func writeSerializeBody(g *jen.Group, si *types.StructInfo, fullSpec xml.Protocol, outerInstructionList []xml.ProtocolInstruction) (err error) { + for _, instruction := range si.Instructions { instructionType := instruction.XMLName.Local + instructionName := getInstructionName(instruction) - if instructionType == "chunked" { - output.WriteString("\twriter.SanitizeStrings = true\n") - oldOuterInstructionList := outerInstructionList - outerInstructionList = instructionList - defer func() { outerInstructionList = oldOuterInstructionList }() + switch instructionType { + case "chunked": + g.Id("writer").Dot("SanitizeStrings").Op("=").True() - if nextImports, err := writeSerializeBody(output, instruction.Chunked, switchStructQualifier, packageName, fullSpec); err != nil { - return nil, err - } else { - imports = append(imports, nextImports...) + var nestedInfo *types.StructInfo + if nestedInfo, err = si.Nested(&instruction); err != nil { + return } - output.WriteString("\twriter.SanitizeStrings = false\n") - continue - } - - if instructionType == "break" { - output.WriteString("\twriter.AddByte(255)\n") - continue - } - - instructionName := getInstructionName(instruction) + if err = writeSerializeBody(g, nestedInfo, fullSpec, si.Instructions); err != nil { + return + } - if instructionType == "switch" { + g.Id("writer").Dot("SanitizeStrings").Op("=").False() + case "break": + g.Id("writer").Dot("AddByte").Call(jen.Lit(0xFF)) + case "switch": // get type of Value field switchFieldSanitizedType := "" switchFieldEnumType := "" - for _, tmpInst := range append(outerInstructionList, instructionList...) { + for _, tmpInst := range append(outerInstructionList, si.Instructions...) { if tmpInst.XMLName.Local == "field" && snakeCaseToPascalCase(*tmpInst.Name) == instructionName { switchFieldEnumType = *tmpInst.Type - switchFieldSanitizedType = sanitizeTypeName(switchFieldEnumType) + switchFieldSanitizedType = types.SanitizeTypeName(switchFieldEnumType) break } } - output.WriteString(fmt.Sprintf("\tswitch s.%s {\n", instructionName)) - + var switchBlock []jen.Code for _, c := range instruction.Cases { if len(c.Instructions) == 0 { continue @@ -316,198 +276,225 @@ func writeSerializeBody(output *strings.Builder, instructionList []xml.ProtocolI var switchDataType string if c.Default { switchDataType = fmt.Sprintf("%sDataDefault", instructionName) - output.WriteString("\tdefault:\n") + switchBlock = append(switchBlock, jen.Default()) } else { switchDataType = fmt.Sprintf("%sData%s", instructionName, c.Value) - if _, err := strconv.ParseInt(c.Value, 10, 32); err != nil { + if value, err := strconv.ParseInt(c.Value, 10, 32); err != nil { // case is for an enum value if enumTypeInfo, ok := fullSpec.IsEnum(switchFieldEnumType); !ok { - return nil, fmt.Errorf("type %s in switch is not an enum", switchFieldEnumType) + return fmt.Errorf("type %s in switch is not an enum", switchFieldEnumType) } else { packageQualifier := "" - if enumTypeInfo.Package != packageName { - packageQualifier = enumTypeInfo.Package + "." - imports = append(imports, importInfo{enumTypeInfo.Package, enumTypeInfo.PackagePath}) + if enumTypeInfo.Package != si.PackageName { + packageQualifier = enumTypeInfo.Package } - output.WriteString(fmt.Sprintf("\tcase %s%s_%s:\n", packageQualifier, switchFieldSanitizedType, c.Value)) + switchBlock = append( + switchBlock, + jen.CaseFunc(func(g *jen.Group) { + if packageQualifier != "" { + g.Qual(types.PackagePath(packageQualifier), fmt.Sprintf("%s_%s", switchFieldSanitizedType, c.Value)) + } else { + g.Id(fmt.Sprintf("%s_%s", switchFieldSanitizedType, c.Value)) + } + }), + ) } } else { // case is for an integer constant - output.WriteString(fmt.Sprintf("\tcase %s:\n", c.Value)) + switchBlock = append(switchBlock, jen.Case(jen.Lit(int(value)))) } } - if len(switchDataType) > 0 { - output.WriteString(fmt.Sprintf("\t\tswitch s.%sData.(type) {\n", instructionName)) - output.WriteString(fmt.Sprintf("\t\tcase *%s%s:\n\t\t", switchStructQualifier, switchDataType)) - } - output.WriteString(fmt.Sprintf("\t\t\tif err = s.%sData.Serialize(writer); err != nil {\n", instructionName)) - output.WriteString("\t\t\t\treturn\n\t\t\t}\n") + // Serialize call for the case structure + caseSerialize := jen.If( + jen.Id("err").Op("=").Id("s").Dot(fmt.Sprintf("%sData", instructionName)).Dot("Serialize").Call(jen.Id("writer")), + jen.Id("err").Op("!=").Nil(), + ).Block(jen.Return()) if len(switchDataType) > 0 { - output.WriteString(fmt.Sprintf("\t\tdefault:\n\t\t\terr = fmt.Errorf(\"invalid switch struct type for switch value %%d\", s.%s)\n\t\t\treturn\n\t\t}\n", instructionName)) + // The object to serialize needs a type assertion + // Wrap it in a type assert switch that returns an error if it does not match + switchBlock = append( + switchBlock, + jen.Switch( + jen.Id("s").Dot( + fmt.Sprintf("%sData", instructionName), + ).Assert(jen.Id("type")).Block( + jen.Case( + jen.Op("*").Id(fmt.Sprintf("%s%s", si.SwitchStructQualifier, switchDataType)), + ).Block(caseSerialize), + jen.Default().Block( + jen.Id("err").Op("=").Qual("fmt", "Errorf").Call( + jen.Lit("invalid switch struct type for switch value %d"), + jen.Id("s").Dot(instructionName), + ).Line().Return(), + ), + ), + ), + ) + } else { + // The object to serialize does not need a type assertion + switchBlock = append(switchBlock, caseSerialize) } } - output.WriteString("\t}\n") - continue - } - - typeName, typeSize := getInstructionTypeName(instruction) - - instructionNameComment := instructionName - if len(instructionNameComment) == 0 && instruction.Content != nil { - instructionNameComment = *instruction.Content - } - output.WriteString(fmt.Sprintf("\t// %s : %s : %s\n", instructionNameComment, instructionType, *instruction.Type)) - - delimited := instruction.Delimited != nil && *instruction.Delimited - trailingDelimiter := instruction.TrailingDelimiter == nil || *instruction.TrailingDelimiter - if instructionType == "array" { - var lenExpr string - if instruction.Length != nil { - lenExpr = getLengthExpression(*instruction.Length) - } else { - lenExpr = fmt.Sprintf("len(s.%s)", instructionName) - } - - output.WriteString(fmt.Sprintf("\tfor ndx := 0; ndx < %s; ndx++ {\n\t\t", lenExpr)) + g.Switch(jen.Id("s").Dot(instructionName)).Block(switchBlock...) + default: + typeName, typeSize := types.GetInstructionTypeName(instruction) - if delimited && !trailingDelimiter { - output.WriteString("\t\tif ndx > 0 {\n\t\t\twriter.AddByte(255)\n\t\t}\n\n") + if len(instructionName) == 0 && instruction.Content != nil { + instructionName = *instruction.Content } - } - - switch typeName { - case "byte": - writeAddTypeForSerialize(output, instructionName, instruction, "Byte", false) - case "char": - writeAddTypeForSerialize(output, instructionName, instruction, "Char", false) - case "short": - writeAddTypeForSerialize(output, instructionName, instruction, "Short", false) - case "three": - writeAddTypeForSerialize(output, instructionName, instruction, "Three", false) - case "int": - writeAddTypeForSerialize(output, instructionName, instruction, "Int", false) - case "bool": - if len(typeSize) > 0 { - typeName = string(unicode.ToUpper(rune(typeSize[0]))) + typeSize[1:] - } else { - typeName = "Char" - } - output.WriteString(fmt.Sprintf("\tif s.%s {\n", instructionName)) - output.WriteString(fmt.Sprintf("\t\terr = writer.Add%s(1)\n\t} else {\n\t\terr = writer.Add%s(0)\n\t}\n", typeName, typeName)) - output.WriteString("\tif err != nil {\n\t\treturn\n\t}\n\n") - case "blob": - writeAddTypeForSerialize(output, instructionName, instruction, "Bytes", false) - case "string": - if instruction.Length != nil && instructionType == "field" { - if instruction.Padded != nil && *instruction.Padded { - writeAddStringTypeForSerialize(output, instructionName, instruction, "PaddedString") + g.Commentf("// %s : %s : %s", instructionName, instructionType, *instruction.Type) + + stringType := types.String + + var serializeCodes []jen.Code + switch typeName { + case "byte": + fallthrough + case "char": + fallthrough + case "short": + fallthrough + case "three": + fallthrough + case "int": + fallthrough + case "blob": + serializeCodes = getSerializeForInstruction(instruction, types.NewEoType(typeName), false) + case "bool": + if len(typeSize) > 0 { + typeName = string(unicode.ToUpper(rune(typeSize[0]))) + typeSize[1:] } else { - writeAddStringTypeForSerialize(output, instructionName, instruction, "FixedString") + typeName = "Char" } - } else { - writeAddStringTypeForSerialize(output, instructionName, instruction, "String") - } - case "encoded_string": - if instruction.Length != nil && instructionType == "field" { - if instruction.Padded != nil && *instruction.Padded { - writeAddStringTypeForSerialize(output, instructionName, instruction, "PaddedEncodedString") + serializeCodes = []jen.Code{ + jen.If(jen.Id("s").Dot(instructionName)).Block( + jen.Id("err").Op("=").Id("writer").Dot(fmt.Sprintf("Add%s", typeName)).Call(jen.Lit(1)), + ).Else().Block( + jen.Id("err").Op("=").Id("writer").Dot(fmt.Sprintf("Add%s", typeName)).Call(jen.Lit(0)), + ).Line(), + jen.If(jen.Id("err").Op("!=").Nil()).Block(jen.Return()).Line(), + } + case "encoded_string": + stringType = types.EncodedString + fallthrough + case "string": + if instruction.Length != nil && instructionType == "field" { + if instruction.Padded != nil && *instruction.Padded { + serializeCodes = getSerializeForInstruction(instruction, stringType+types.Padded, false) + } else { + serializeCodes = getSerializeForInstruction(instruction, stringType+types.Fixed, false) + } } else { - writeAddStringTypeForSerialize(output, instructionName, instruction, "FixedEncodedString") + serializeCodes = getSerializeForInstruction(instruction, stringType, false) + } + default: + if _, ok := fullSpec.IsStruct(typeName); ok { + serializeCodes = []jen.Code{ + jen.If( + jen.Id("err").Op("=").Id("s").Dot(instructionName).Do(func(s *jen.Statement) { + if instructionType == "array" { + s.Index(jen.Id("ndx")) + } + }).Dot("Serialize").Call(jen.Id("writer")), + jen.Id("err").Op("!=").Nil(), + ).Block(jen.Return()), + } + } else if e, ok := fullSpec.IsEnum(typeName); ok { + if t := types.NewEoType(e.Type); t&types.Primitive > 0 { + serializeCodes = getSerializeForInstruction(instruction, t, true) + } + } else { + err = fmt.Errorf("unable to find type '%s' when writing serialization function (member: %s, type: %s)", typeName, instructionName, instructionType) + return } - } else { - writeAddStringTypeForSerialize(output, instructionName, instruction, "EncodedString") } - default: - if _, ok := fullSpec.IsStruct(typeName); ok { - if instructionType == "array" { - instructionName = instructionName + "[ndx]" + + if instructionType == "array" { + var lenExpr *jen.Statement + if instruction.Length != nil { + lenExpr = getLengthExpression(*instruction.Length) + } else { + lenExpr = jen.Len(jen.Id("s").Dot(instructionName)) } - output.WriteString(fmt.Sprintf("\tif err = s.%s.Serialize(writer); err != nil {\n\t\treturn\n\t}\n", instructionName)) - } else if e, ok := fullSpec.IsEnum(typeName); ok { - switch e.Type { - case "byte": - fallthrough - case "char": - fallthrough - case "short": - fallthrough - case "three": - fallthrough - case "int": - writeAddTypeForSerialize(output, instructionName, instruction, string(unicode.ToUpper(rune(e.Type[0])))+e.Type[1:], true) + + delimited := instruction.Delimited != nil && *instruction.Delimited + trailingDelimiter := instruction.TrailingDelimiter == nil || *instruction.TrailingDelimiter + + if delimited { + addByteCode := jen.Id("writer").Dot("AddByte").Call(jen.Lit(0xFF)) + if !trailingDelimiter { + delimiterCode := jen.If( + jen.Id("ndx").Op(">").Lit(0).Block(addByteCode).Line(), + ) + serializeCodes = append([]jen.Code{delimiterCode}, serializeCodes...) + } else { + serializeCodes = append(serializeCodes, addByteCode) + } } - } else { - panic("Unable to find type '" + typeName + "' when writing serialization function") - } - } - if instructionType == "array" { - if delimited && trailingDelimiter { - output.WriteString("\t\twriter.AddByte(255)\n") + g.For( + jen.Id("ndx").Op(":=").Lit(0), + jen.Id("ndx").Op("<").Add(lenExpr), + jen.Id("ndx").Op("++"), + ).Block(serializeCodes...).Line() + } else { + g.Add(serializeCodes...) } - output.WriteString("\t}\n\n") } } return } -// flag that determines whether a chunked section is active or not -// this is used to determine if the next chunk should be selected in array delimiters and break bytes -var isChunked bool - -func writeDeserializeBody(output *strings.Builder, instructionList []xml.ProtocolInstruction, switchStructQualifier string, packageName string, fullSpec xml.Protocol) (imports []importInfo, err error) { - for _, instruction := range instructionList { +func writeDeserializeBody(g *jen.Group, si *types.StructInfo, fullSpec xml.Protocol, outerInstructionList []xml.ProtocolInstruction, isChunked bool) (err error) { + for _, instruction := range si.Instructions { instructionType := instruction.XMLName.Local + instructionName := getInstructionName(instruction) - if instructionType == "chunked" { - output.WriteString("\treader.SetIsChunked(true)\n") - oldChunked := isChunked - isChunked = true - oldOuterInstructionList := outerInstructionList - outerInstructionList = instructionList - defer func() { isChunked = oldChunked; outerInstructionList = oldOuterInstructionList }() - - nextImports, err := writeDeserializeBody(output, instruction.Chunked, switchStructQualifier, packageName, fullSpec) - if err != nil { - return nil, err + switch instructionType { + case "chunked": + g.Id("reader").Dot("SetIsChunked").Call(jen.True()) + + var nestedInfo *types.StructInfo + if nestedInfo, err = si.Nested(&instruction); err != nil { + return } - imports = append(imports, nextImports...) - output.WriteString("\treader.SetIsChunked(false)\n\n") - continue - } + if err = writeDeserializeBody(g, nestedInfo, fullSpec, si.Instructions, true); err != nil { + return + } - if instructionType == "break" { + g.Id("reader").Dot("SetIsChunked").Call(jen.False()) + case "break": if isChunked { - output.WriteString("\tif err = reader.NextChunk(); err != nil {\n\t\treturn\n\t}\n") + g.If( + jen.Id("err").Op("=").Id("reader").Dot("NextChunk").Call(), + jen.Id("err").Op("!=").Nil(), + ).Block(jen.Return()) } else { - output.WriteString("\tif breakByte := reader.GetByte(); breakByte != 255 {\n") - output.WriteString("\t\treturn fmt.Errorf(\"missing expected break byte\")\n") - output.WriteString("\t}\n") + g.If( + jen.Id("breakByte").Op(":=").Id("reader").Dot("GetByte").Call(), + jen.Id("breakByte").Op("!=").Lit(0xFF), + ).Block( + jen.Return(jen.Qual("fmt", "Errorf").Call(jen.Lit("missing expected break byte"))), + ) } - continue - } - - instructionName := getInstructionName(instruction) - - if instructionType == "switch" { + case "switch": // get type of Value field switchFieldSanitizedType := "" switchFieldEnumType := "" - for _, tmpInst := range append(outerInstructionList, instructionList...) { + for _, tmpInst := range append(outerInstructionList, si.Instructions...) { if tmpInst.XMLName.Local == "field" && snakeCaseToPascalCase(*tmpInst.Name) == instructionName { switchFieldEnumType = *tmpInst.Type - switchFieldSanitizedType = sanitizeTypeName(switchFieldEnumType) + switchFieldSanitizedType = types.SanitizeTypeName(switchFieldEnumType) break } } - output.WriteString(fmt.Sprintf("\tswitch s.%s {\n", instructionName)) - + var switchBlock []jen.Code for _, c := range instruction.Cases { if len(c.Instructions) == 0 { continue @@ -516,155 +503,193 @@ func writeDeserializeBody(output *strings.Builder, instructionList []xml.Protoco var switchDataType string if c.Default { switchDataType = fmt.Sprintf("%sDataDefault", instructionName) - output.WriteString("\tdefault:\n") + switchBlock = append(switchBlock, jen.Default()) } else { switchDataType = fmt.Sprintf("%sData%s", instructionName, c.Value) - if _, err := strconv.ParseInt(c.Value, 10, 32); err != nil { + if value, err := strconv.ParseInt(c.Value, 10, 32); err != nil { // case is for an enum value if enumTypeInfo, ok := fullSpec.IsEnum(switchFieldEnumType); !ok { - return nil, fmt.Errorf("type %s in switch is not an enum", switchFieldEnumType) + return fmt.Errorf("type %s in switch is not an enum", switchFieldEnumType) } else { packageQualifier := "" - if enumTypeInfo.Package != packageName { - packageQualifier = enumTypeInfo.Package + "." - imports = append(imports, importInfo{enumTypeInfo.Package, enumTypeInfo.PackagePath}) + if enumTypeInfo.Package != si.PackageName { + packageQualifier = enumTypeInfo.Package } - output.WriteString(fmt.Sprintf("\tcase %s%s_%s:\n", packageQualifier, switchFieldSanitizedType, c.Value)) + switchBlock = append(switchBlock, jen.CaseFunc(func(g *jen.Group) { + if packageQualifier != "" { + g.Qual(types.PackagePath(packageQualifier), fmt.Sprintf("%s_%s", switchFieldSanitizedType, c.Value)) + } else { + g.Id(fmt.Sprintf("%s_%s", switchFieldSanitizedType, c.Value)) + } + })) } } else { // case is for an integer constant - output.WriteString(fmt.Sprintf("\tcase %s:\n", c.Value)) + switchBlock = append(switchBlock, jen.Case(jen.Lit(int(value)))) } } - output.WriteString(fmt.Sprintf("\t\ts.%sData = &%s%s{}\n", instructionName, switchStructQualifier, switchDataType)) - output.WriteString(fmt.Sprintf("\t\tif err = s.%sData.Deserialize(reader); err != nil {\n", instructionName)) - output.WriteString("\t\t\treturn\n\t\t}\n") - } - - output.WriteString("\t}\n") + // Deserialize call for the case structure + sDotData := jen.Id("s").Dot(fmt.Sprintf("%sData", instructionName)) + caseDeserialize := sDotData.Clone().Op("=").Op("&").Id(si.SwitchStructQualifier + switchDataType).Block().Line() + caseDeserialize = caseDeserialize.If( + jen.Id("err").Op("=").Add(sDotData).Dot("Deserialize").Call(jen.Id("reader")), + jen.Id("err").Op("!=").Nil(), + ).Block(jen.Return()) - continue - } - - typeName, typeSize := getInstructionTypeName(instruction) - - instructionNameComment := instructionName - if len(instructionNameComment) == 0 && instruction.Content != nil { - instructionNameComment = *instruction.Content - } - output.WriteString(fmt.Sprintf("\t// %s : %s : %s\n", instructionNameComment, instructionType, *instruction.Type)) - - var lenExpr string - if instructionType == "array" { - if instruction.Length != nil { - lenExpr = "ndx < " + getLengthExpression(*instruction.Length) - } else if (instruction.Delimited == nil || !*instruction.Delimited) && isChunked { - rawLen, err := calculateTypeSize(typeName, fullSpec) - if err != nil { - lenExpr = "reader.Remaining() > 0" - } else { - lenExpr = "ndx < reader.Remaining() / " + strconv.Itoa(rawLen) - } - } else { - lenExpr = "reader.Remaining() > 0" + switchBlock = append(switchBlock, caseDeserialize) } - output.WriteString(fmt.Sprintf("\tfor ndx := 0; %s; ndx++ {\n\t\t", lenExpr)) - } + g.Switch(jen.Id("s").Dot(instructionName)).Block(switchBlock...) + default: + typeName, typeSize := types.GetInstructionTypeName(instruction) - switch typeName { - case "byte": - castType := "int" - writeGetTypeForDeserialize(output, instructionName, instruction, "Byte", &castType) - case "char": - writeGetTypeForDeserialize(output, instructionName, instruction, "Char", nil) - case "short": - writeGetTypeForDeserialize(output, instructionName, instruction, "Short", nil) - case "three": - writeGetTypeForDeserialize(output, instructionName, instruction, "Three", nil) - case "int": - writeGetTypeForDeserialize(output, instructionName, instruction, "Int", nil) - case "bool": - if len(typeSize) > 0 { - typeName = string(unicode.ToUpper(rune(typeSize[0]))) + typeSize[1:] - } else { - typeName = "Char" + if len(instructionName) == 0 && instruction.Content != nil { + instructionName = *instruction.Content } - output.WriteString(fmt.Sprintf("\tif boolVal := reader.Get%s(); boolVal > 0 {\n", typeName)) - output.WriteString(fmt.Sprintf("\t\ts.%s = true\n\t} else {\n\t\ts.%s = false\n\t}\n", instructionName, instructionName)) - case "blob": - writeGetTypeForDeserialize(output, instructionName, instruction, "Bytes", nil) - case "string": - if instruction.Length != nil && instructionType == "field" { - if instruction.Padded != nil && *instruction.Padded { - writeGetStringTypeForDeserialize(output, instructionName, instruction, "PaddedString") + g.Commentf("// %s : %s : %s", instructionName, instructionType, *instruction.Type) + + stringType := types.String + + var deserializeCodes []jen.Code + switch typeName { + case "byte": + deserializeCodes = getDeserializeForInstruction(instruction, types.NewEoType(typeName), jen.Id("int")) + case "char": + fallthrough + case "short": + fallthrough + case "three": + fallthrough + case "int": + fallthrough + case "blob": + deserializeCodes = getDeserializeForInstruction(instruction, types.NewEoType(typeName), nil) + case "bool": + if len(typeSize) > 0 { + typeName = string(unicode.ToUpper(rune(typeSize[0]))) + typeSize[1:] } else { - writeGetStringTypeForDeserialize(output, instructionName, instruction, "FixedString") + typeName = "Char" } - } else { - writeGetStringTypeForDeserialize(output, instructionName, instruction, "String") - } - case "encoded_string": - if instruction.Length != nil && instructionType == "field" { - if instruction.Padded != nil && *instruction.Padded { - writeGetStringTypeForDeserialize(output, instructionName, instruction, "PaddedEncodedString") + + deserializeCodes = []jen.Code{ + jen.If( + jen.Id("boolVal").Op(":=").Id("reader").Dot("Get"+typeName).Call(), + jen.Id("boolVal").Op(">").Lit(0), + ).Block( + jen.Id("s").Dot(instructionName).Op("=").True(), + ).Else().Block( + jen.Id("s").Dot(instructionName).Op("=").False(), + ), + } + case "encoded_string": + stringType = types.EncodedString + fallthrough + case "string": + if instruction.Length != nil && instructionType == "field" { + if instruction.Padded != nil && *instruction.Padded { + deserializeCodes = getDeserializeForInstruction(instruction, stringType+types.Padded, nil) + } else { + deserializeCodes = getDeserializeForInstruction(instruction, stringType+types.Fixed, nil) + } } else { - writeGetStringTypeForDeserialize(output, instructionName, instruction, "FixedEncodedString") + deserializeCodes = getDeserializeForInstruction(instruction, stringType, nil) } - } else { - writeGetStringTypeForDeserialize(output, instructionName, instruction, "EncodedString") - } - default: - if structInfo, ok := fullSpec.IsStruct(typeName); ok { - if instructionType == "array" { - if packageName != structInfo.Package { - typeName = structInfo.Package + "." + typeName - imports = append(imports, importInfo{structInfo.Package, structInfo.PackagePath}) + default: + if s, ok := fullSpec.IsStruct(typeName); ok { + arrayCode := jen.Null() + if instructionType == "array" { + _, tp := types.ProtocolSpecTypeToGoType(s.Name, si.PackageName, fullSpec) + arrayCode = jen.Id("s").Dot(instructionName).Op("=").Append( + jen.Id("s").Dot(instructionName), + jen.Do(func(s *jen.Statement) { + if tp != nil { + s.Qual(tp.Path, typeName) + } else { + s.Id(typeName) + } + }).Block(), + ) } - output.WriteString(fmt.Sprintf("\ts.%s = append(s.%s, %s{})\n", instructionName, instructionName, typeName)) - instructionName = instructionName + "[ndx]" - } - output.WriteString(fmt.Sprintf("\tif err = s.%s.Deserialize(reader); err != nil {\n\t\treturn\n\t}\n", instructionName)) - } else if e, ok := fullSpec.IsEnum(typeName); ok { - switch e.Type { - case "byte": - fallthrough - case "char": - fallthrough - case "short": - fallthrough - case "three": - fallthrough - case "int": - if e.Package != packageName { - typeName = fmt.Sprintf("%s.%s", e.Package, typeName) + deserializeCodes = []jen.Code{ + arrayCode, + jen.If( + jen.Id("err").Op("=").Id("s").Dot(instructionName).Do(func(s *jen.Statement) { + if instructionType == "array" { + s.Index(jen.Id("ndx")) + } + }).Dot("Deserialize").Call(jen.Id("reader")), + jen.Id("err").Op("!=").Nil(), + ).Block(jen.Return()), } - writeGetTypeForDeserialize(output, instructionName, instruction, string(unicode.ToUpper(rune(e.Type[0])))+e.Type[1:], &typeName) + } else if e, ok := fullSpec.IsEnum(typeName); ok { + if eoType := types.NewEoType(e.Type); eoType&types.Primitive > 0 { + _, tp := types.ProtocolSpecTypeToGoType(e.Name, si.PackageName, fullSpec) + deserializeCodes = getDeserializeForInstruction( + instruction, + eoType, + jen.Do(func(s *jen.Statement) { + if tp != nil { + s.Qual(tp.Path, e.Name) + } else { + s.Id(e.Name) + } + }), + ) + } else { + err = fmt.Errorf("expected primitive base type for enum %s when writing deserialize function", e.Name) + } + } else { + panic("Unable to find type '" + typeName + "' when writing serialization function") } - imports = append(imports, importInfo{e.Package, e.PackagePath}) - } else { - panic("Unable to find type '" + typeName + "' when writing serialization function") } - } - delimited := instruction.Delimited != nil && *instruction.Delimited - trailingDelimiter := instruction.TrailingDelimiter == nil || *instruction.TrailingDelimiter - if instructionType == "array" { - if delimited && isChunked { - if !trailingDelimiter { - if instruction.Length == nil { - return nil, fmt.Errorf("delimited arrays with trailing-delimiter=false must have a length (array %s)", instructionName) + if instructionType == "array" { + delimited := instruction.Delimited != nil && *instruction.Delimited + + var lenExpr *jen.Statement + if instruction.Length != nil { + lenExpr = jen.Id("ndx").Op("<").Add(getLengthExpression(*instruction.Length)) + } else if !delimited && isChunked { + if rawLen, err := types.CalculateTypeSize(typeName, fullSpec); err != nil { + lenExpr = jen.Id("reader").Dot("Remaining").Call().Op(">").Lit(0) + } else { + lenExpr = jen.Id("ndx").Op("<").Id("reader").Dot("Remaining").Call().Op("/").Lit(rawLen) } - output.WriteString(fmt.Sprintf("\t\tif ndx + 1 < %s {\n", getLengthExpression(*instruction.Length))) + } else { + lenExpr = jen.Id("reader").Dot("Remaining").Call().Op(">").Lit(0) } - output.WriteString("\t\tif err = reader.NextChunk(); err != nil {\n\t\t\treturn\n\t\t}\n") - if !trailingDelimiter { - output.WriteString("\t\t}\n") + + trailingDelimiter := instruction.TrailingDelimiter == nil || *instruction.TrailingDelimiter + + if delimited && isChunked { + delimiterExpr := jen.If( + jen.Id("err").Op("=").Id("reader").Dot("NextChunk").Call(), + jen.Id("err").Op("!=").Nil(), + ).Block(jen.Return()) + + if !trailingDelimiter { + if instruction.Length == nil { + err = fmt.Errorf("delimited arrays with trailing-delimiter=false must have a length (array %s)", instructionName) + return + } + + delimiterExpr = jen.If( + jen.Id("ndx").Op("+").Lit(1).Op("<").Add(getLengthExpression(*instruction.Length))).Block(delimiterExpr) + } + + deserializeCodes = append(deserializeCodes, delimiterExpr) } + + g.For( + jen.Id("ndx").Op(":=").Lit(0), + lenExpr, + jen.Id("ndx").Op("++"), + ).Block(deserializeCodes...).Line() + } else { + g.Add(deserializeCodes...) } - output.WriteString("\t}\n\n") } } @@ -680,163 +705,180 @@ func getInstructionName(inst xml.ProtocolInstruction) (instName string) { return } -func writeAddTypeForSerialize(output *strings.Builder, instructionName string, instruction xml.ProtocolInstruction, methodType string, needsCastToInt bool) { - optional := instruction.Optional != nil && *instruction.Optional +func getSerializeForInstruction(instruction xml.ProtocolInstruction, methodType types.EoType, needsCastToInt bool) []jen.Code { + instructionName := getInstructionName(instruction) + + // the method type is a string if it has the eotype_str or eotype_str_encoded flag + isString := (methodType&types.String) > 0 || (methodType&types.EncodedString) > 0 + var instructionCode, nilCheckCode *jen.Statement if len(instructionName) == 0 && instruction.Content != nil { - instructionName = *instruction.Content + if isString { + instructionCode = jen.Lit(*instruction.Content) + } else { + instructionCode = jen.Id(*instruction.Content) + } } else { - instructionName = "s." + instructionName + instructionCode = jen.Id("s").Dot(instructionName) } + isArray := false + optional := instruction.Optional != nil && *instruction.Optional if instruction.XMLName.Local == "array" { - instructionName = instructionName + "[ndx]" + instructionCode = instructionCode.Index(jen.Id("ndx")) // optional arrays that are unset will be nil. // The length expression in the loop checks the length of the nil slice, which evaluates to 0. // This means that arrays do not need additional dereferencing when optional. optional = false - } else if optional { - output.WriteString(fmt.Sprintf("\tif %s != nil {\n", instructionName)) - instructionName = "*" + instructionName + isArray = true } - if needsCastToInt { - instructionName = "int(" + instructionName + ")" - } - - output.WriteString(fmt.Sprintf("\t\tif err = writer.Add%s(%s); err != nil {\n\t\t\treturn\n\t\t}\n", methodType, instructionName)) - if optional { - output.WriteString("\t}\n") - } -} - -func writeGetTypeForDeserialize(output *strings.Builder, instructionName string, instruction xml.ProtocolInstruction, methodType string, castType *string) { - optional := instruction.Optional != nil && *instruction.Optional - - lengthExpr := "" - if instruction.XMLName.Local != "array" { - if instruction.Length != nil { - lengthExpr = getLengthExpression(*instruction.Length) - } else if methodType == "Bytes" { - lengthExpr = "reader.Remaining()" - } - } else { - // optional arrays that are unset will be nil. - // The length expression in the loop checks the length of the nil slice, which evaluates to 0. - // This means that arrays do not need additional dereferencing when optional. - optional = false + nilCheckCode = instructionCode.Clone() + instructionCode = jen.Op("*").Add(instructionCode) } - if optional { - output.WriteString("\tif reader.Remaining() > 0 {\n") + if needsCastToInt { + instructionCode = jen.Int().Call(instructionCode) } - if len(instructionName) == 0 && instruction.Content != nil { - output.WriteString(fmt.Sprintf("\treader.Get%s(%s)\n", methodType, lengthExpr)) - } else { - if instruction.XMLName.Local == "array" { - output.WriteString(fmt.Sprintf("\t\ts.%s = append(s.%s, 0)\n", instructionName, instructionName)) - instructionName = instructionName + "[ndx]" - } - - if castType != nil { - if optional { - output.WriteString(fmt.Sprintf("\t\ts.%s = new(%s)\n\t\t*s.", instructionName, *castType)) - } else { - output.WriteString("\t\ts.") - } + serializeCode := jen.If( + jen.Id("err").Op("=").Id("writer").Dot("Add"+methodType.String()).Call( + instructionCode, + jen.Do(func(s *jen.Statement) { + // strings may have a fixed length that needs to be serialized + if !isArray && isString && instruction.Length != nil { + s.Add(getLengthExpression(*instruction.Length)) + } + }), + ), + jen.Id("err").Op("!=").Nil(), + ).Block(jen.Return()) - output.WriteString(fmt.Sprintf("%s = %s(reader.Get%s(%s))\n", instructionName, *castType, methodType, lengthExpr)) - } else { + return []jen.Code{ + jen.Do(func(s *jen.Statement) { if optional { - output.WriteString(fmt.Sprintf("\t\ts.%s = new(int)\n\t\t*s.", instructionName)) + s.If(nilCheckCode.Op("!=").Nil()).Block(serializeCode) } else { - output.WriteString("\t\ts.") + s.Add(serializeCode) } - - output.WriteString(fmt.Sprintf("%s = reader.Get%s(%s)\n", instructionName, methodType, lengthExpr)) - } - } - - if optional { - output.WriteString("\t}\n") + }), } } -func writeAddStringTypeForSerialize(output *strings.Builder, instructionName string, instruction xml.ProtocolInstruction, methodType string) { - optional := instruction.Optional != nil && *instruction.Optional - - if len(instructionName) == 0 && instruction.Content != nil { - instructionName = `"` + *instruction.Content + `"` - } else { - instructionName = "s." + instructionName - } - - if instruction.XMLName.Local == "array" { - instructionName = instructionName + "[ndx]" - optional = false - } else if instruction.Length != nil { - instructionName = instructionName + ", " + getLengthExpression(*instruction.Length) - } - - if optional { - output.WriteString(fmt.Sprintf("\tif %s != nil {\n", instructionName)) - instructionName = "*" + instructionName - } +func getDeserializeForInstruction(instruction xml.ProtocolInstruction, methodType types.EoType, castType *jen.Statement) []jen.Code { + instructionName := getInstructionName(instruction) - output.WriteString(fmt.Sprintf("\t\tif err = writer.Add%s(%s); err != nil {\n\t\t\treturn\n\t\t}\n", methodType, instructionName)) + // the method type is a string if it has the eotype_str or eotype_str_encoded flag + isString := (methodType&types.String) > 0 || (methodType&types.EncodedString) > 0 - if optional { - output.WriteString("\t}\n") - } -} - -func writeGetStringTypeForDeserialize(output *strings.Builder, instructionName string, instruction xml.ProtocolInstruction, methodType string) { + isArray := false optional := instruction.Optional != nil && *instruction.Optional - lengthExpr := "" + lengthExpr := jen.Null() if instruction.XMLName.Local != "array" { if instruction.Length != nil { lengthExpr = getLengthExpression(*instruction.Length) + } else if methodType == types.Bytes { + lengthExpr = jen.Id("reader").Dot("Remaining").Call() } } else { + // optional arrays that are unset will be nil. + // The length expression in the loop checks the length of the nil slice, which evaluates to 0. + // This means that arrays do not need additional dereferencing when optional. optional = false + isArray = true } - if optional { - output.WriteString("\tif reader.Remaining() > 0 {\n") - } + readerGetCode := jen.Id("reader").Dot("Get" + methodType.String()).Call(lengthExpr) + var retCodes []jen.Code + var assignRHS, assignLHS *jen.Statement + hasAssignTarget := false if len(instructionName) == 0 && instruction.Content != nil { - output.WriteString(fmt.Sprintf("\tif _, err = reader.Get%s(%s); err != nil {\n\t\treturn\n\t}\n", methodType, lengthExpr)) + if isString { + assignRHS = jen.Op("=").Add(readerGetCode) + assignLHS = jen.Id("_") + } else { + assignRHS = jen.Add(readerGetCode) + assignLHS = jen.Null() + } } else { - if instruction.XMLName.Local == "array" { - output.WriteString(fmt.Sprintf("\t\ts.%s = append(s.%s, \"\")\n", instructionName, instructionName)) - instructionName = instructionName + "[ndx]" + hasAssignTarget = true + + indexCode := jen.Null() + if isArray { + // pre-append an item to the array in the struct field + var defaultCode *jen.Statement + if isString { + defaultCode = jen.Lit("") + } else { + defaultCode = jen.Lit(0) + } + + retCodes = append(retCodes, jen.Id("s").Dot(instructionName).Op("=").Append(jen.Id("s").Dot(instructionName), defaultCode)) + indexCode = jen.Index(jen.Id("ndx")) } if optional { - output.WriteString(fmt.Sprintf("\t\ts.%s = new(string)\n\t\tif *s.", instructionName)) + // instantiate the optional struct field + retCodes = append(retCodes, jen.Id("s").Dot(instructionName).Op("=").New(jen.Do(func(s *jen.Statement) { + if castType != nil { + s.Add(castType) + } else if isString { + s.String() + } else { + s.Int() + } + }))) + + assignLHS = jen.Op("*").Id("s").Dot(instructionName).Add(indexCode) } else { - output.WriteString("\t\tif s.") + assignLHS = jen.Id("s").Dot(instructionName).Add(indexCode) } - output.WriteString(fmt.Sprintf("%s, err = reader.Get%s(%s); err != nil {\n\t\treturn\n\t}\n\n", instructionName, methodType, lengthExpr)) + assignRHS = jen.Op("=").Do(func(s *jen.Statement) { + if castType != nil { + s.Add(castType).Call(readerGetCode) + } else { + s.Add(readerGetCode) + } + }) + } + + var assignBlock *jen.Statement + if isString { + assignBlock = jen.If( + jen.List(assignLHS, jen.Id("err")).Add(assignRHS), + jen.Id("err").Op("!=").Nil(), + ).Block(jen.Return()).Do(func(s *jen.Statement) { + // _, err := strconv.ParseInt(*instruction.Length, 10, 32) + if hasAssignTarget { + // For compatibility: prior codegen inserted an extra newline after fixed strings that referenced a length field + s.Line() + } + }) + } else { + assignBlock = assignLHS.Add(assignRHS) } if optional { - output.WriteString("\t}\n") + retCodes = append(retCodes, assignBlock) + retCodes = []jen.Code{jen.If(jen.Id("reader").Dot("Remaining").Call().Op(">").Lit(0)).Block(retCodes...)} + } else { + retCodes = append(retCodes, assignBlock) } + + return retCodes } -func getLengthExpression(instLength string) string { - if _, err := strconv.ParseInt(instLength, 10, 32); err == nil { +func getLengthExpression(instLength string) *jen.Statement { + if parsed, err := strconv.ParseInt(instLength, 10, 32); err == nil { // string length is a numeric constant - return instLength + return jen.Lit(int(parsed)) } else { // string length is a reference to another field - return "s." + snakeCaseToPascalCase(instLength) + return jen.Id("s").Dot(snakeCaseToPascalCase(instLength)) } } diff --git a/internal/codegen/types/aliases.go b/internal/codegen/types/aliases.go new file mode 100644 index 0000000..1c64cc6 --- /dev/null +++ b/internal/codegen/types/aliases.go @@ -0,0 +1,25 @@ +package types + +import "github.com/dave/jennifer/jen" + +// packageAliases is a map of package short names to package paths. For use with Jennifer. +var packageAliases = map[string]string{ + "data": "github.com/ethanmoffat/eolib-go/pkg/eolib/data", + "net": "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/net", + "protocol": "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol", + "pub": "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/pub", +} + +func PackagePath(packageName string) string { + if v, ok := packageAliases[packageName]; ok { + return v + } + + return packageName +} + +func AddImports(f *jen.File) { + for k, v := range packageAliases { + f.ImportName(v, k) + } +} diff --git a/internal/codegen/types/eotype.go b/internal/codegen/types/eotype.go new file mode 100644 index 0000000..58101ea --- /dev/null +++ b/internal/codegen/types/eotype.go @@ -0,0 +1,93 @@ +package types + +type EoType int + +const ( + Invalid EoType = 0 +) + +const ( + Primitive EoType = iota + 0x0100 // flag indicating type is a primitive (supported for bool) + Byte + Char + Short + Three + Int + Bool +) + +const ( + Complex EoType = iota + 0x0200 // flag indicating type is complex (not supported for bool) + Bytes +) + +const ( + String EoType = iota + 0x0400 // flag indicating type is a string type + PaddedString + FixedString +) + +const ( + EncodedString EoType = iota + 0x0800 // flag indicating type is an encoded string type + PaddedEncodedString + FixedEncodedString +) + +// offsets from String or EncodedString to the other string method types +const ( + _ EoType = iota + Padded + Fixed +) + +func (t EoType) String() string { + switch t { + case Byte: + return "Byte" + case Char: + return "Char" + case Short: + return "Short" + case Three: + return "Three" + case Int: + return "Int" + case Bool: + return "Byte" + case Bytes: + return "Bytes" + case String: + return "String" + case PaddedString: + return "PaddedString" + case FixedString: + return "FixedString" + case EncodedString: + return "EncodedString" + case PaddedEncodedString: + return "PaddedEncodedString" + case FixedEncodedString: + return "FixedEncodedString" + } + + return "" +} + +func NewEoType(str string) EoType { + switch str { + case "byte": + return Byte + case "char": + return Char + case "short": + return Short + case "three": + return Three + case "int": + return Int + case "blob": + return Bytes + } + + return Invalid +} diff --git a/internal/codegen/types/structinfo.go b/internal/codegen/types/structinfo.go new file mode 100644 index 0000000..5bb7d20 --- /dev/null +++ b/internal/codegen/types/structinfo.go @@ -0,0 +1,62 @@ +package types + +import ( + "errors" + "fmt" + + "github.com/ethanmoffat/eolib-go/internal/xml" +) + +// StructInfo is a type representing the metadata about a struct that should be rendered as generated code. +// It represents the common properties of either a ProtocolPacket or a ProtocolStruct. +type StructInfo struct { + Name string // Name is the name of the type. It is not converted from protocol naming convention (snake_case). + Comment string // Comment is an optional type comment for the struct. + Instructions []xml.ProtocolInstruction // Instructions is a collection of instructions for the struct. + PackageName string // PackageName is the containing package name for the struct. + + Family string // Family is the Packet Family of the struct, if the struct is a packet struct. + Action string // Action is the Packet Action of the struct, if the struct is a packet struct. + SwitchStructQualifier string // SwitchStructQualifier is an additional qualifier prepended to structs used in switch cases in packets. +} + +// GetStructInfo generates a [StructInfo] for the specified typeName. typeName can be a structure +// or a packet in the full XML spec. +func GetStructInfo(typeName string, fullSpec xml.Protocol) (si *StructInfo, err error) { + si = &StructInfo{SwitchStructQualifier: ""} + err = nil + + if structInfo, ok := fullSpec.IsStruct(typeName); ok { + si.Name = structInfo.Name + si.Comment = structInfo.Comment + si.Instructions = structInfo.Instructions + si.PackageName = structInfo.Package + } else if packetInfo, ok := fullSpec.IsPacket(typeName); ok { + si.Name = packetInfo.GetTypeName() + si.Comment = packetInfo.Comment + si.Instructions = packetInfo.Instructions + si.PackageName = packetInfo.Package + si.SwitchStructQualifier = packetInfo.Family + packetInfo.Action + si.Family = packetInfo.Family + si.Action = packetInfo.Action + } else { + si = nil + err = fmt.Errorf("type %s is not a struct or packet in the spec", typeName) + } + + return +} + +// Nested creates a nested [StructInfo] from the specified 'chunked' [xml.ProtocolInstruction]. +// This function returns an error if the instruction is not of the 'chunked' type. +func (si *StructInfo) Nested(chunked *xml.ProtocolInstruction) (*StructInfo, error) { + if chunked.XMLName.Local != "chunked" { + return nil, errors.New("expected 'chunked' instruction creating nested StructInfo") + } + + return &StructInfo{ + Instructions: chunked.Chunked, + PackageName: si.PackageName, + SwitchStructQualifier: si.SwitchStructQualifier, + }, nil +} diff --git a/internal/codegen/types/typeconv.go b/internal/codegen/types/typeconv.go new file mode 100644 index 0000000..7969acc --- /dev/null +++ b/internal/codegen/types/typeconv.go @@ -0,0 +1,56 @@ +package types + +import ( + "strings" + + "github.com/ethanmoffat/eolib-go/internal/xml" +) + +type ImportInfo struct { + Package string + Path string +} + +func ProtocolSpecTypeToGoType(eoType string, currentPackage string, fullSpec xml.Protocol) (goType string, nextImport *ImportInfo) { + if strings.ContainsRune(eoType, rune(':')) { + eoType = strings.Split(eoType, ":")[0] + } + + switch eoType { + case "byte": + fallthrough + case "char": + fallthrough + case "short": + fallthrough + case "three": + fallthrough + case "int": + return "int", nil + case "bool": + return "bool", nil + case "blob": + return "[]byte", nil + case "string": + fallthrough + case "encoded_string": + return "string", nil + default: + match := fullSpec.FindType(eoType) + goType = eoType + + if structMatch, ok := match.(*xml.ProtocolStruct); ok && structMatch.Package != currentPackage { + nextImport = &ImportInfo{structMatch.Package, structMatch.PackagePath} + } else if enumMatch, ok := match.(*xml.ProtocolEnum); ok && enumMatch.Package != currentPackage { + nextImport = &ImportInfo{enumMatch.Package, enumMatch.PackagePath} + } + + if nextImport != nil { + if val, ok := packageAliases[nextImport.Package]; ok { + nextImport.Path = val + } + } + + return + } +} diff --git a/internal/codegen/types/typesize.go b/internal/codegen/types/typesize.go new file mode 100644 index 0000000..efbe046 --- /dev/null +++ b/internal/codegen/types/typesize.go @@ -0,0 +1,115 @@ +package types + +import ( + "fmt" + "strconv" + "strings" + + "github.com/ethanmoffat/eolib-go/internal/xml" +) + +// SanitizeTypeName sanitizes the type name for serialization. Effectively, this removes the 'Type' +// suffix if present. +func SanitizeTypeName(typeName string) string { + if strings.HasSuffix(typeName, "Type") { + return typeName[:len(typeName)-4] + } + return typeName +} + +// GetInstructionTypeName gets the type name (and byte size, if present) of an instruction. +func GetInstructionTypeName(inst xml.ProtocolInstruction) (typeName string, typeSize string) { + if inst.Type == nil { + return + } + + if strings.ContainsRune(*inst.Type, rune(':')) { + split := strings.Split(*inst.Type, ":") + typeName, typeSize = split[0], split[1] + } else { + typeName = *inst.Type + } + + return +} + +// CalculateTypeSize gets the size of the named type by recursively evaluating and summing the size +// of the named type's members +func CalculateTypeSize(typeName string, fullSpec xml.Protocol) (res int, err error) { + structInfo, isStruct := fullSpec.IsStruct(typeName) + if !isStruct { + return getPrimitiveTypeSize(typeName, fullSpec) + } + + var flattenedInstList []xml.ProtocolInstruction + for _, instruction := range (*structInfo).Instructions { + if instruction.XMLName.Local == "chunked" { + flattenedInstList = append(flattenedInstList, instruction.Chunked...) + } else { + flattenedInstList = append(flattenedInstList, instruction) + } + } + + for _, instruction := range flattenedInstList { + switch instruction.XMLName.Local { + case "field": + fieldTypeName, fieldTypeSize := GetInstructionTypeName(instruction) + if fieldTypeSize != "" { + fieldTypeName = fieldTypeSize + } + + if instruction.Length != nil { + if length, err := strconv.ParseInt(*instruction.Length, 10, 32); err == nil { + // length is a numeric constant + res += int(length) + } else { + return 0, fmt.Errorf("instruction length %s must be a fixed size for %s (%s)", *instruction.Length, *instruction.Name, instruction.XMLName.Local) + } + } else { + if nestedSize, err := getPrimitiveTypeSize(fieldTypeName, fullSpec); err != nil { + return 0, err + } else { + res += nestedSize + } + } + case "break": + res += 1 + case "array": + case "dummy": + } + } + + return +} + +func getPrimitiveTypeSize(fieldTypeName string, fullSpec xml.Protocol) (int, error) { + switch fieldTypeName { + case "byte": + fallthrough + case "char": + return 1, nil + case "short": + return 2, nil + case "three": + return 3, nil + case "int": + return 4, nil + case "bool": + return 1, nil + case "blob": + fallthrough + case "string": + fallthrough + case "encoded_string": + return 0, fmt.Errorf("cannot get size of %s without fixed length", fieldTypeName) + default: + if _, isStruct := fullSpec.IsStruct(fieldTypeName); isStruct { + return CalculateTypeSize(fieldTypeName, fullSpec) + } else if e, isEnum := fullSpec.IsEnum(fieldTypeName); isEnum { + enumTypeName := SanitizeTypeName(e.Type) + return getPrimitiveTypeSize(enumTypeName, fullSpec) + } else { + return 0, fmt.Errorf("cannot get fixed size of unrecognized type %s", fieldTypeName) + } + } +} diff --git a/pkg/eolib/data/writer.go b/pkg/eolib/data/writer.go index 6dc5a2f..26beb7a 100644 --- a/pkg/eolib/data/writer.go +++ b/pkg/eolib/data/writer.go @@ -31,7 +31,7 @@ func (w *EoWriter) Write(p []byte) (int, error) { return len(p), nil } -// AddInt adds a raw byte to the writer data. +// AddByte adds a raw byte to the writer data. func (w *EoWriter) AddByte(value int) error { if value > 0xFF { return errors.New("value is larger than maximum raw byte size") diff --git a/pkg/eolib/protocol/map/structs_generated.go b/pkg/eolib/protocol/map/structs_generated.go index 23cd96c..dec5e81 100644 --- a/pkg/eolib/protocol/map/structs_generated.go +++ b/pkg/eolib/protocol/map/structs_generated.go @@ -1,14 +1,10 @@ package eomap import ( - "fmt" "github.com/ethanmoffat/eolib-go/pkg/eolib/data" - protocol "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" + "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" ) -// Ensure fmt import is referenced in generated code -var _ = fmt.Printf - // MapNpc :: NPC spawn EMF entity. type MapNpc struct { Coords protocol.Coords diff --git a/pkg/eolib/protocol/net/client/packets_generated.go b/pkg/eolib/protocol/net/client/packets_generated.go index a7dd538..86a9a8a 100644 --- a/pkg/eolib/protocol/net/client/packets_generated.go +++ b/pkg/eolib/protocol/net/client/packets_generated.go @@ -3,13 +3,10 @@ package client import ( "fmt" "github.com/ethanmoffat/eolib-go/pkg/eolib/data" - protocol "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" - net "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/net" + "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" + "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/net" ) -// Ensure fmt import is referenced in generated code -var _ = fmt.Printf - // InitInitClientPacket :: Connection initialization request. This packet is unencrypted. type InitInitClientPacket struct { Challenge int diff --git a/pkg/eolib/protocol/net/client/structs_generated.go b/pkg/eolib/protocol/net/client/structs_generated.go index 6e3a093..a14d0ef 100644 --- a/pkg/eolib/protocol/net/client/structs_generated.go +++ b/pkg/eolib/protocol/net/client/structs_generated.go @@ -1,14 +1,10 @@ package client import ( - "fmt" "github.com/ethanmoffat/eolib-go/pkg/eolib/data" - protocol "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" + "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" ) -// Ensure fmt import is referenced in generated code -var _ = fmt.Printf - // ByteCoords :: Map coordinates with raw 1-byte values. type ByteCoords struct { X int diff --git a/pkg/eolib/protocol/net/server/packets_generated.go b/pkg/eolib/protocol/net/server/packets_generated.go index c683fde..a581cd2 100644 --- a/pkg/eolib/protocol/net/server/packets_generated.go +++ b/pkg/eolib/protocol/net/server/packets_generated.go @@ -3,14 +3,11 @@ package server import ( "fmt" "github.com/ethanmoffat/eolib-go/pkg/eolib/data" - protocol "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" - net "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/net" - pub "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/pub" + "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" + "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/net" + "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/pub" ) -// Ensure fmt import is referenced in generated code -var _ = fmt.Printf - // InitInitServerPacket :: Reply to connection initialization and requests for unencrypted data. This packet is unencrypted. type InitInitServerPacket struct { ReplyCode InitReply @@ -8104,7 +8101,6 @@ func (s *ChestSpecServerPacket) Deserialize(reader *data.EoReader) (err error) { // ChestCloseServerPacket :: Reply to trying to interact with a locked or "broken" chest. The official client assumes a broken chest if the packet is under 2 bytes in length. type ChestCloseServerPacket struct { Key *int // Sent if the player is trying to interact with a locked chest. - } func (s ChestCloseServerPacket) Family() net.PacketFamily { diff --git a/pkg/eolib/protocol/net/server/structs_generated.go b/pkg/eolib/protocol/net/server/structs_generated.go index cc8575b..818ffd2 100644 --- a/pkg/eolib/protocol/net/server/structs_generated.go +++ b/pkg/eolib/protocol/net/server/structs_generated.go @@ -3,13 +3,10 @@ package server import ( "fmt" "github.com/ethanmoffat/eolib-go/pkg/eolib/data" - protocol "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" - net "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/net" + "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol" + "github.com/ethanmoffat/eolib-go/pkg/eolib/protocol/net" ) -// Ensure fmt import is referenced in generated code -var _ = fmt.Printf - // BigCoords :: Map coordinates with 2-byte values. type BigCoords struct { X int diff --git a/pkg/eolib/protocol/net/structs_generated.go b/pkg/eolib/protocol/net/structs_generated.go index ec7e5fb..722fe81 100644 --- a/pkg/eolib/protocol/net/structs_generated.go +++ b/pkg/eolib/protocol/net/structs_generated.go @@ -1,12 +1,6 @@ package net -import ( - "fmt" - "github.com/ethanmoffat/eolib-go/pkg/eolib/data" -) - -// Ensure fmt import is referenced in generated code -var _ = fmt.Printf +import "github.com/ethanmoffat/eolib-go/pkg/eolib/data" // Version :: Client version. type Version struct { diff --git a/pkg/eolib/protocol/pub/structs_generated.go b/pkg/eolib/protocol/pub/structs_generated.go index b3049a3..0abbd94 100644 --- a/pkg/eolib/protocol/pub/structs_generated.go +++ b/pkg/eolib/protocol/pub/structs_generated.go @@ -1,12 +1,6 @@ package pub -import ( - "fmt" - "github.com/ethanmoffat/eolib-go/pkg/eolib/data" -) - -// Ensure fmt import is referenced in generated code -var _ = fmt.Printf +import "github.com/ethanmoffat/eolib-go/pkg/eolib/data" // EifRecord :: Record of Item data in an Endless Item File. type EifRecord struct { diff --git a/pkg/eolib/protocol/structs_generated.go b/pkg/eolib/protocol/structs_generated.go index 2258910..bab823a 100644 --- a/pkg/eolib/protocol/structs_generated.go +++ b/pkg/eolib/protocol/structs_generated.go @@ -1,12 +1,6 @@ package protocol -import ( - "fmt" - "github.com/ethanmoffat/eolib-go/pkg/eolib/data" -) - -// Ensure fmt import is referenced in generated code -var _ = fmt.Printf +import "github.com/ethanmoffat/eolib-go/pkg/eolib/data" // Coords :: Map coordinates. type Coords struct {