Skip to content

Commit

Permalink
Improve srcgen formatting using SourceBuilder (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
agocke authored Jan 4, 2025
1 parent ef64b0b commit 9098a6b
Show file tree
Hide file tree
Showing 265 changed files with 4,576 additions and 4,641 deletions.
97 changes: 48 additions & 49 deletions src/generator/DeserializeImplGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Serde
{
internal class DeserializeImplGen
{
internal static (List<MemberDeclarationSyntax>, BaseListSyntax) GenDeserialize(
internal static (SourceBuilder, string BaseList) GenDeserialize(
TypeDeclContext typeDeclContext,
GeneratorExecutionContext context,
ITypeSymbol receiverType,
Expand All @@ -27,19 +27,12 @@ internal static (List<MemberDeclarationSyntax>, BaseListSyntax) GenDeserialize(

if (receiverType.IsAbstract)
{
var memberDecl = ParseMemberDeclaration(GenUnionDeserializeMethod((INamedTypeSymbol)receiverType))!;
List<BaseTypeSyntax> unionBase = [
// `Serde.IDeserialize<'typeName'>
SimpleBaseType(QualifiedName(IdentifierName("Serde"), GenericName(
Identifier("IDeserialize"),
TypeArgumentList(SeparatedList(new[] { typeSyntax }))
))),
];
return ([ memberDecl ], BaseList(SeparatedList(unionBase)));
var memberDecl = GenUnionDeserializeMethod((INamedTypeSymbol)receiverType);
return (memberDecl, $": global::Serde.IDeserialize<{typeFqn}>");
}

// Generate members for IDeserialize.Deserialize implementation
var members = new List<MemberDeclarationSyntax>();
var members = new SourceBuilder();
List<BaseTypeSyntax> bases = [
// `Serde.IDeserialize<'typeName'>
SimpleBaseType(QualifiedName(IdentifierName("Serde"), GenericName(
Expand All @@ -53,21 +46,20 @@ internal static (List<MemberDeclarationSyntax>, BaseListSyntax) GenDeserialize(
bases.Add(SimpleBaseType(ParseTypeName($"Serde.IDeserializeProvider<{typeFqn}>")));

var deserialize = GenerateEnumDeserializeMethod(receiverType, typeSyntax);
members.Add(deserialize);
members.AppendLine(deserialize);

var deserializeInstance = ParseMemberDeclaration($"""
members.AppendLine($"""
static IDeserialize<{typeFqn}> IDeserializeProvider<{typeFqn}>.DeserializeInstance
=> {typeFqn}Proxy.Instance;
""")!;
members.Add(deserializeInstance);
""");
}
else
{
var method = GenerateCustomDeserializeMethod(typeDeclContext, context, receiverType, typeSyntax, inProgress);
members.Add(method);
members.AppendLine(method);
}
var baseList = BaseList(SeparatedList(bases));
return (members, baseList);
return (members, baseList.ToFullString());
}

/// <summary>
Expand All @@ -92,7 +84,7 @@ static IDeserialize<{typeFqn}> IDeserializeProvider<{typeFqn}>.DeserializeInstan
/// }
/// </code>
/// </summary>
private static string GenUnionDeserializeMethod(INamedTypeSymbol type)
private static SourceBuilder GenUnionDeserializeMethod(INamedTypeSymbol type)
{
Debug.Assert(type.IsAbstract);

Expand All @@ -112,7 +104,7 @@ private static string GenUnionDeserializeMethod(INamedTypeSymbol type)
membersBuilder.AppendLine($"{i} => de.ReadValue<{m.ToDisplayString()}, {SerdeInfoGenerator.GetUnionProxyName(m)}>({i}),");
}

var src = $$"""
var src = new SourceBuilder($$"""
{{typeFqn}} IDeserialize<{{typeFqn}}>.Deserialize(IDeserializer deserializer)
{
var serdeInfo = global::Serde.SerdeInfoProvider.GetInfo<{{typeFqn}}>();
Expand All @@ -132,7 +124,7 @@ private static string GenUnionDeserializeMethod(INamedTypeSymbol type)
}
return _l_result;
}
""";
""");
return src;
}

Expand All @@ -158,7 +150,7 @@ private static string GenUnionDeserializeMethod(INamedTypeSymbol type)
/// }
/// </code>
/// </summary>
private static MethodDeclarationSyntax GenerateEnumDeserializeMethod(
private static SourceBuilder GenerateEnumDeserializeMethod(
ITypeSymbol type,
TypeSyntax typeSyntax)
{
Expand All @@ -173,7 +165,7 @@ private static MethodDeclarationSyntax GenerateEnumDeserializeMethod(
<= 64 => "ulong",
_ => throw new InvalidOperationException("Too many members in type")
};
var src = $$"""
var src = new SourceBuilder($$"""
{{typeFqn}} IDeserialize<{{typeFqn}}>.Deserialize(IDeserializer deserializer)
{
var serdeInfo = global::Serde.SerdeInfoProvider.GetInfo<{{typeFqn}}Proxy>();
Expand All @@ -189,8 +181,8 @@ private static MethodDeclarationSyntax GenerateEnumDeserializeMethod(
_ => throw new InvalidOperationException($"Unexpected index: {index}")
};
}
""";
return (MethodDeclarationSyntax)ParseMemberDeclaration(src)!;
""");
return src;
}

/// <summary>
Expand All @@ -214,7 +206,7 @@ private static MethodDeclarationSyntax GenerateEnumDeserializeMethod(
/// }
/// </code>
/// </summary>
private static MethodDeclarationSyntax GenerateCustomDeserializeMethod(
private static SourceBuilder GenerateCustomDeserializeMethod(
TypeDeclContext typeDeclContext,
GeneratorExecutionContext context,
ITypeSymbol type,
Expand All @@ -233,7 +225,7 @@ private static MethodDeclarationSyntax GenerateCustomDeserializeMethod(
_ => throw new InvalidOperationException("Too many members in type")
};
var (cases, locals, requiredMask) = InitCasesAndLocals();
string typeCreationExpr = GenerateTypeCreation(context, typeFqn, type, members, requiredMask);
var typeCreationExpr = GenerateTypeCreation(context, typeFqn, type, members, requiredMask);

const string typeInfoLocalName = "_l_serdeInfo";
const string indexLocalName = "_l_index_";
Expand All @@ -243,7 +235,7 @@ private static MethodDeclarationSyntax GenerateCustomDeserializeMethod(
? $"var {IndexErrorName}"
: "_";

var methodText = $$"""
var methodText = new SourceBuilder($$"""
{{typeFqn}} Serde.IDeserialize<{{typeFqn}}>.Deserialize(IDeserializer deserializer)
{
{{locals}}
Expand All @@ -256,18 +248,18 @@ private static MethodDeclarationSyntax GenerateCustomDeserializeMethod(
{
switch ({{indexLocalName}})
{
{{cases}}
{{cases}}
}
}
{{typeCreationExpr}}
return newType;
}
""";
return (MethodDeclarationSyntax)ParseMemberDeclaration(methodText)!;
""");
return methodText;

(string Cases, string Locals, string AssignedMask) InitCasesAndLocals()
{
var casesBuilder = new StringBuilder();
var casesBuilder = new SourceBuilder();
var localsBuilder = new StringBuilder();
long assignedMaskValue = 0;
var skippedIndices = new List<int>();
Expand Down Expand Up @@ -336,14 +328,18 @@ private static MethodDeclarationSyntax GenerateCustomDeserializeMethod(
case {i}:
""");
}
casesBuilder.AppendLine($"""
casesBuilder.AppendLine(
$"""
case Serde.IDeserializeType.IndexNotFound:
{unknownMemberBehavior}
""");
casesBuilder.AppendLine($"""
"""
);
casesBuilder.Append(
$"""
default:
throw new InvalidOperationException("Unexpected index: " + {indexLocalName});
""");
"""
);
return (casesBuilder.ToString(),
localsBuilder.ToString(),
"0b" + Convert.ToString(assignedMaskValue, 2));
Expand Down Expand Up @@ -380,7 +376,7 @@ static string GetReadValueCall(
/// each member in the initializer. If there is no parameterlss constructor, there
/// must be a primary constructor.
/// </summary>
private static string GenerateTypeCreation(
private static SourceBuilder GenerateTypeCreation(
GeneratorExecutionContext context,
string typeName,
ITypeSymbol type,
Expand Down Expand Up @@ -412,11 +408,10 @@ private static string GenerateTypeCreation(
if (parameterlessCtor is null && primaryCtor is null)
{
context.ReportDiagnostic(CreateDiagnostic(DiagId.ERR_MissingPrimaryCtor, type.Locations[0]));
return $"var newType = new {typeName}();";
return new SourceBuilder($"var newType = new {typeName}();");
}

var assignmentMembers = new List<DataMemberSymbol>(members);
var assignments = new StringBuilder();
var parameters = new StringBuilder();
if (primaryCtor is not null)
{
Expand All @@ -432,24 +427,28 @@ private static string GenerateTypeCreation(
}
}

var typeCreation = new SourceBuilder(
$$"""
if (({{AssignedVarName}} & {{assignedMask}}) != {{assignedMask}})
{
throw Serde.DeserializeException.UnassignedMember();
}
var newType = new {{typeName}}({{parameters}}) {
"""
);
typeCreation.Indent();
foreach (var m in assignmentMembers)
{
if (m.SkipDeserialize)
{
continue;
}
assignments.AppendLine($"{m.Name} = {GetLocalName(m)},");
typeCreation.AppendLine($"{m.Name} = {GetLocalName(m)},");
}
var mask = new string('1', members.Count);
return $$"""
if (({{AssignedVarName}} & {{assignedMask}}) != {{assignedMask}})
{
throw Serde.DeserializeException.UnassignedMember();
}
var newType = new {{typeName}}({{parameters}}) {
{{assignments}}
};
""";
typeCreation.Dedent();
typeCreation.AppendLine("};");
return typeCreation;
}

private static string GetLocalName(DataMemberSymbol m) => "_l_" + m.Name.ToLower();
Expand Down
Loading

0 comments on commit 9098a6b

Please sign in to comment.