Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions crates/bindings-csharp/BSATN.Codegen/Type.cs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ public MemberDeclaration(ISymbol member, ITypeSymbol type, DiagReporter diag)
public MemberDeclaration(IFieldSymbol field, DiagReporter diag)
: this(field, field.Type, diag) { }

public string Identifier => EscapeIdentifier(Name);

public static string GenerateBsatnFields(
Accessibility visibility,
IEnumerable<MemberDeclaration> members
Expand All @@ -431,7 +433,7 @@ IEnumerable<MemberDeclaration> members
return string.Join(
"\n ",
members.Select(m =>
$"{visStr} static readonly {m.Type.ToBSATNString()} {m.Name}{TypeUse.BsatnFieldSuffix} = new();"
$"{visStr} static readonly {m.Type.ToBSATNString()} {m.Identifier}{TypeUse.BsatnFieldSuffix} = new();"
)
);
}
Expand All @@ -442,7 +444,7 @@ public static string GenerateDefs(IEnumerable<MemberDeclaration> members) =>
// we can't use nameof(m.Type.BsatnFieldName) because the bsatn field name differs from the logical name
// assigned in the type.
members.Select(m =>
$"new(\"{m.Name}\", {m.Name}{TypeUse.BsatnFieldSuffix}.GetAlgebraicType(registrar))"
$"new(\"{m.Name}\", {m.Identifier}{TypeUse.BsatnFieldSuffix}.GetAlgebraicType(registrar))"
)
);
}
Expand Down Expand Up @@ -569,10 +571,10 @@ public Scope.Extensions ToExtensions()
// To avoid this, we append an underscore to the field name.
// In most cases the field name shouldn't matter anyway as you'll idiomatically use pattern matching to extract the value.
$$"""
public sealed record {{m.Name}}({{m.Type.Name}} {{m.Name}}_) : {{ShortName}}
public sealed record {{m.Identifier}}({{m.Type.Name}} {{m.Identifier}}_) : {{ShortName}}
{
public override string ToString() =>
$"{{m.Name}}({ SpacetimeDB.BSATN.StringUtil.GenericToString({{m.Name}}_) })";
$"{{m.Name}}({ SpacetimeDB.BSATN.StringUtil.GenericToString({{m.Identifier}}_) })";
}

"""
Expand All @@ -585,7 +587,7 @@ public override string ToString() =>
{{string.Join(
"\n ",
bsatnDecls.Select((m, i) =>
$"{i} => new {m.Name}({m.Name}{TypeUse.BsatnFieldSuffix}.Read(reader)),"
$"{i} => new {m.Identifier}({m.Identifier}{TypeUse.BsatnFieldSuffix}.Read(reader)),"
)
)}}
_ => throw new System.InvalidOperationException("Invalid tag value, this state should be unreachable.")
Expand All @@ -597,9 +599,9 @@ public override string ToString() =>
{{string.Join(
"\n",
bsatnDecls.Select((m, i) => $"""
case {m.Name}(var inner):
case {m.Identifier}(var inner):
writer.Write((byte){i});
{m.Name}{TypeUse.BsatnFieldSuffix}.Write(writer, inner);
{m.Identifier}{TypeUse.BsatnFieldSuffix}.Write(writer, inner);
break;
"""))}}
}
Expand All @@ -615,7 +617,7 @@ public override string ToString() =>
var hashName = $"___hash{member.Name}";

return $"""
case {member.Name}(var inner):
case {member.Identifier}(var inner):
{member.Type.GetHashCodeStatement("inner", hashName)}
return {hashName};
""";
Expand All @@ -634,14 +636,14 @@ public override string ToString() =>
public void ReadFields(System.IO.BinaryReader reader) {
{{string.Join(
"\n",
bsatnDecls.Select(m => $" {m.Name} = BSATN.{m.Name}{TypeUse.BsatnFieldSuffix}.Read(reader);")
bsatnDecls.Select(m => $" {m.Identifier} = BSATN.{m.Identifier}{TypeUse.BsatnFieldSuffix}.Read(reader);")
)}}
}

public void WriteFields(System.IO.BinaryWriter writer) {
{{string.Join(
"\n",
bsatnDecls.Select(m => $" BSATN.{m.Name}{TypeUse.BsatnFieldSuffix}.Write(writer, {m.Name});")
bsatnDecls.Select(m => $" BSATN.{m.Identifier}{TypeUse.BsatnFieldSuffix}.Write(writer, {m.Identifier});")
)}}
}

Expand All @@ -661,7 +663,7 @@ object SpacetimeDB.BSATN.IStructuralReadWrite.GetSerializer() {
public override string ToString() =>
$"{{ShortName}} {{start}} {{string.Join(
", ",
bsatnDecls.Select(m => $$"""{{m.Name}} = {SpacetimeDB.BSATN.StringUtil.GenericToString({{m.Name}})}""")
bsatnDecls.Select(m => $$"""{{m.Name}} = {SpacetimeDB.BSATN.StringUtil.GenericToString({{m.Identifier}})}""")
)}} {{end}}";
"""
);
Expand All @@ -680,7 +682,7 @@ public override string ToString() =>
var declHashName = (MemberDeclaration decl) => $"___hash{decl.Name}";

getHashCode = $$"""
{{string.Join("\n", bsatnDecls.Select(decl => decl.Type.GetHashCodeStatement(decl.Name, declHashName(decl))))}}
{{string.Join("\n", bsatnDecls.Select(decl => decl.Type.GetHashCodeStatement(decl.Identifier, declHashName(decl))))}}
return {{JoinOrValue(
" ^\n ",
bsatnDecls.Select(declHashName),
Expand Down Expand Up @@ -735,7 +737,7 @@ public override int GetHashCode()
public bool Equals({{fullNameMaybeRef}} that)
{
{{(Scope.IsStruct ? "" : "if (((object?)that) == null) { return false; }\n ")}}
{{string.Join("\n", bsatnDecls.Select(decl => decl.Type.EqualsStatement($"this.{decl.Name}", $"that.{decl.Name}", declEqualsName(decl))))}}
{{string.Join("\n", bsatnDecls.Select(decl => decl.Type.EqualsStatement($"this.{decl.Identifier}", $"that.{decl.Identifier}", declEqualsName(decl))))}}
return {{JoinOrValue(
" &&\n ",
bsatnDecls.Select(declEqualsName),
Expand Down
13 changes: 13 additions & 0 deletions crates/bindings-csharp/BSATN.Codegen/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,26 @@ public readonly record struct EquatableArray<T>(ImmutableArray<T> Array) : IEnum
.AddMemberOptions(SymbolDisplayMemberOptions.IncludeContainingType)
.AddMiscellaneousOptions(
SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier
| SymbolDisplayMiscellaneousOptions.EscapeKeywordIdentifiers
);

public static string SymbolToName(ISymbol symbol)
{
return symbol.ToDisplayString(SymbolFormat);
}

public static string EscapeIdentifier(string name)
{
if (name.Length > 0 && name[0] == '@')
{
return name;
}

var kind = SyntaxFacts.GetKeywordKind(name);
var contextualKind = SyntaxFacts.GetContextualKeywordKind(name);
return kind != SyntaxKind.None || contextualKind != SyntaxKind.None ? $"@{name}" : name;
}

public static void RegisterSourceOutputs(
this IncrementalValuesProvider<Scope.Extensions> methods,
IncrementalGeneratorInitializationContext context
Expand Down
79 changes: 79 additions & 0 deletions crates/bindings-csharp/Codegen.Tests/Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,85 @@ public static async Task SettingsAndExplicitNames()
AssertGeneratedCodeDoesNotUseInternalBound(compilationAfterGen);
}

[Fact]
public static async Task CSharpKeywordIdentifiersAreEscapedInGeneratedCode()
{
var fixture = await Fixture.Compile("server");

const string source =
"""
using SpacetimeDB;

[SpacetimeDB.Table]
public partial struct KeywordTable
{
[SpacetimeDB.PrimaryKey]
public ulong @class;

public int @params;
}

[SpacetimeDB.Table(Accessor = "class")]
public partial struct AccessorKeywordTable
{
[SpacetimeDB.PrimaryKey]
[SpacetimeDB.Index.BTree(Accessor = "class")]
public int Id;
}

[SpacetimeDB.Table]
public partial struct @class
{
[SpacetimeDB.PrimaryKey]
public int Id;
}

public static partial class KeywordApis
{
[SpacetimeDB.Reducer]
public static void KeywordReducer(ReducerContext ctx, int @params, string @class)
{
_ = @params;
_ = @class;
}

[SpacetimeDB.Reducer]
public static void @class(ReducerContext ctx)
{
}

[SpacetimeDB.Procedure]
public static int KeywordProcedure(ProcedureContext ctx, int @params, int @class)
{
return @params + @class;
}

[SpacetimeDB.Procedure]
public static void @params(ProcedureContext ctx)
{
}
}
""";

var parseOptions = new CSharpParseOptions(fixture.SampleCompilation.LanguageVersion);
var tree = CSharpSyntaxTree.ParseText(source, parseOptions, path: "KeywordNames.cs");
var compilation = fixture.SampleCompilation.AddSyntaxTrees(tree);

var driver = CSharpGeneratorDriver.Create(
[new SpacetimeDB.Codegen.Type().AsSourceGenerator(), new SpacetimeDB.Codegen.Module().AsSourceGenerator()],
driverOptions: new(
disabledOutputs: IncrementalGeneratorOutputKind.None,
trackIncrementalGeneratorSteps: true
),
parseOptions: parseOptions
);

var runResult = driver.RunGenerators(compilation).GetRunResult();
var compilationAfterGen = compilation.AddSyntaxTrees(runResult.GeneratedTrees);

Assert.Empty(GetCompilationErrors(compilationAfterGen));
}

[Fact]
public static async Task TestDiagnostics()
{
Expand Down
Loading