Skip to content

Commit

Permalink
fix multi other type
Browse files Browse the repository at this point in the history
  • Loading branch information
2A5F committed Mar 9, 2025
1 parent 5495e98 commit 055ff7f
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 31 deletions.
64 changes: 42 additions & 22 deletions Coplt.Union.Analyzers/Generators/Templates/TemplateStructUnion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,14 @@ private record struct RecordTypeDefine(
string UnmanagedTypeName,
List<(string Type, int Index)> UnmanagedFields,
List<(string Type, int Index, int NthByKind)> ClassFields,
List<(string Type, int Index)> OtherFields
List<(string Type, int Index, int NthByKind)> OtherFields
);

private int UnmanagedTypeInc = 0;
private int OtherTypeInc = 0;
private int MaxClassTypes = 0;
private readonly Dictionary<string, int> UnmanagedTypes = new();
private readonly Dictionary<string, int> OtherTypes = new();
private readonly Dictionary<string, (int Index, int Max)> OtherTypes = new();
private readonly Dictionary<int, RecordTypeDefine> RecordTypeDefines = new();

private void BuildTypes()
Expand All @@ -294,12 +294,12 @@ private void BuildTypes()
MaxClassTypes = Math.Max(MaxClassTypes, 1);
break;
default:
goto OtherType;
if (!OtherTypes.TryGetValue(@case.Type, out var ot))
ot = (OtherTypeInc++, 1);
OtherTypes[@case.Type] = ot;
break;
}
continue;
OtherType:
if (!OtherTypes.ContainsKey(@case.Type)) OtherTypes.Add(@case.Type, OtherTypeInc++);
continue;
}
else
{
Expand All @@ -308,6 +308,7 @@ private void BuildTypes()
: $"__record_{i}_unmanaged_";
var def = new RecordTypeDefine(unmanaged_struct_name, new(), new(), new());
var class_inc = 0;
var other_type_max_inc = new Dictionary<string, object>();
foreach (var (item, j) in @case.Items.Select(static (a, b) => (a, b)))
{
switch (item.Kind)
Expand All @@ -316,16 +317,29 @@ private void BuildTypes()
def.UnmanagedFields.Add((item.Type, j));
break;
case UnionCaseTypeKind.Class:
{
var nth = class_inc++;
def.ClassFields.Add((item.Type, j, nth));
MaxClassTypes = Math.Max(MaxClassTypes, nth + 1);
break;
}
default: goto OtherType;
}
continue;
OtherType:
def.OtherFields.Add((item.Type, j));
if (!OtherTypes.ContainsKey(item.Type)) OtherTypes.Add(item.Type, OtherTypeInc++);
if (!other_type_max_inc.TryGetValue(item.Type, out var nth_boxed))
{
nth_boxed = 0;
other_type_max_inc.Add(item.Type, nth_boxed);
}
{
var nth = Unsafe.Unbox<int>(nth_boxed)++;
def.OtherFields.Add((item.Type, j, nth));
if (!OtherTypes.TryGetValue(item.Type, out var ot))
ot = (OtherTypeInc++, 1);
else ot.Max = Math.Max(ot.Max, nth + 1);
OtherTypes[item.Type] = ot;
}
continue;
}
if (def.UnmanagedFields.Count > 0)
Expand Down Expand Up @@ -450,15 +464,15 @@ private void GenViews(string impl_name, string tags_name)
$"{space} get => ref global::System.Runtime.CompilerServices.Unsafe.As<object?, {type}>(ref {RoImpl}._c{nth});");
sb.AppendLine($"{space}}}");
}
foreach (var (type, index) in def.OtherFields)
foreach (var (type, index, nth) in def.OtherFields)
{
var ti = OtherTypes[type];
var (ti, _) = OtherTypes[type];
var item = @case.Items[index];
sb.AppendLine($"{space}[global::System.Diagnostics.CodeAnalysis.UnscopedRef]");
sb.AppendLine($"{space}public ref{ro} {type} {item.Name}");
sb.AppendLine($"{space}{{");
sb.AppendLine($"{space} {AggressiveInlining}");
sb.AppendLine($"{space} get => ref {RoImpl}._f{ti};");
sb.AppendLine($"{space} get => ref {RoImpl}._f{ti}_{nth};");
sb.AppendLine($"{space}}}");
}

Expand Down Expand Up @@ -668,7 +682,10 @@ private void GenImpl(string name, string tags_name)

foreach (var kv in OtherTypes)
{
sb.AppendLine($" public {kv.Key} _f{kv.Value};");
for (var i = 0; i < kv.Value.Max; i++)
{
sb.AppendLine($" public {kv.Key} _f{kv.Value.Index}_{i};");
}
}

#endregion
Expand All @@ -693,7 +710,10 @@ private void GenImpl(string name, string tags_name)
sb.AppendLine($" global::System.Runtime.CompilerServices.Unsafe.SkipInit(out this._u);");
foreach (var kv in OtherTypes)
{
sb.AppendLine($" this._f{kv.Value} = default!;");
for (var i = 0; i < kv.Value.Max; i++)
{
sb.AppendLine($" this._f{kv.Value.Index}_{i} = default!;");
}
}
sb.AppendLine($" this._tag = _tag;");
sb.AppendLine($" }}");
Expand Down Expand Up @@ -797,19 +817,19 @@ private void GenMake(string impl_name, string tags_name)
var item = @case.Items[index];
sb.AppendLine($"{space}_impl._c{nth} = {item.Name};");
}
foreach (var (type, index) in def.OtherFields)
foreach (var (type, index, nth) in def.OtherFields)
{
var ti = OtherTypes[type];
var (ti, _) = OtherTypes[type];
var item = @case.Items[index];
sb.AppendLine($"{space}_impl._f{ti} = {item.Name};");
sb.AppendLine($"{space}_impl._f{ti}_{nth} = {item.Name};");
}
}
else if (@case.Type != "void")
{
if (@case.Kind == UnionCaseTypeKind.None)
{
var index = OtherTypes![@case.Type];
sb.AppendLine($"{space}_impl._f{index} = value;");
var (index, _) = OtherTypes![@case.Type];
sb.AppendLine($"{space}_impl._f{index}_0 = value;");
}
else if (@case.Kind == UnionCaseTypeKind.Class)
{
Expand Down Expand Up @@ -858,8 +878,8 @@ void GenField()
}
else if (@case.Kind == UnionCaseTypeKind.None)
{
var index = OtherTypes![@case.Type];
sb.Append($"this._impl._f{index}!");
var (index, _) = OtherTypes![@case.Type];
sb.Append($"this._impl._f{index}_0!");
}
else if (@case.Kind == UnionCaseTypeKind.Class)
{
Expand Down Expand Up @@ -1024,9 +1044,9 @@ private void GenToStr(string tags_name)
{
sb.Append(
$" {tags_name}.{@case.Name} => $\"{{nameof({TypeName})}}.{{nameof({tags_name}.{@case.Name})}}");
if (@case .Type != "void") sb.Append($" {{{{ {{({Self}.{@case.Name})}} }}}}");
if (@case.Type != "void") sb.Append($" {{{{ {{({Self}.{@case.Name})}} }}}}");
}

sb.AppendLine($"\",");
}
sb.AppendLine($" _ => nameof({TypeName}),");
Expand Down
4 changes: 4 additions & 0 deletions Coplt.Union.Analyzers/Properties/launchSettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
"DebugRoslynComponent for Tests": {
"commandName": "DebugRoslynComponent",
"targetProject": "..\\Tests\\Tests.csproj"
},
"DebugRoslynComponent for Utilities": {
"commandName": "DebugRoslynComponent",
"targetProject": "..\\Coplt.Union.Utilities\\Coplt.Union.Utilities.csproj"
}
}
}
2 changes: 1 addition & 1 deletion Coplt.Union.Source/Coplt.Union.Source.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<PropertyGroup>
<TargetFrameworks>netstandard2.0;netstandard2.1;net6.0;net7.0;net8.0;net9.0</TargetFrameworks>
<PackageId>Coplt.Union</PackageId>
<Version>0.13.1</Version>
<Version>0.13.2</Version>
<IsPackable>true</IsPackable>
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
<NoWarn>CS9113</NoWarn>
Expand Down
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public readonly partial struct Union1
public object? _c0; // All classes will overlap
public object? _c1;
public __unmanaged_ _u; // All unmanaged types will overlap
public (int a, string b) _f0; // Mixed types cannot overlap
public (int a, string b) _f0_0; // Mixed types cannot overlap
public readonly Tags _tag;

[StructLayout(LayoutKind.Explicit)] internal struct __unmanaged_
Expand Down Expand Up @@ -174,7 +174,7 @@ public readonly partial struct Union1
public object? _c0;
public object? _c1;
public __unmanaged_ _u;
public (int a, string b) _f0;
public (int a, string b) _f0_0;
public readonly Tags _tag;

[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
Expand All @@ -183,7 +183,7 @@ public readonly partial struct Union1
this._c0 = null;
this._c1 = null;
global::System.Runtime.CompilerServices.Unsafe.SkipInit(out this._u);
this._f0 = default!;
this._f0_0 = default!;
this._tag = _tag;
}

Expand Down Expand Up @@ -256,7 +256,7 @@ public readonly partial struct Union1
public ref readonly (int a, string b) e
{
[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
get => ref global::System.Runtime.CompilerServices.Unsafe.AsRef<Union1.__impl_>(in this._impl)._f0;
get => ref global::System.Runtime.CompilerServices.Unsafe.AsRef<Union1.__impl_>(in this._impl)._f0_0;
}

[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
Expand Down Expand Up @@ -355,7 +355,7 @@ public readonly partial struct Union1
public static Union1 MakeG((int a, string b) value)
{
var _impl = new __impl_(Tags.G);
_impl._f0 = value;
_impl._f0_0 = value;
return new Union1(_impl);
}
[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
Expand All @@ -366,7 +366,7 @@ public readonly partial struct Union1
_impl._u._3._1 = b;
_impl._c0 = c;
_impl._c1 = d;
_impl._f0 = e;
_impl._f0_0 = e;
return new Union1(_impl);
}

Expand Down Expand Up @@ -445,7 +445,7 @@ public readonly partial struct Union1
public ref readonly (int a, string b) G
{
[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
get => ref !this.IsG ? ref global::System.Runtime.CompilerServices.Unsafe.NullRef<(int a, string b)>() : ref this._impl._f0!;
get => ref !this.IsG ? ref global::System.Runtime.CompilerServices.Unsafe.NullRef<(int a, string b)>() : ref this._impl._f0_0!;
}
[global::System.Diagnostics.CodeAnalysis.UnscopedRef]
public ref readonly VariantHView H
Expand Down
3 changes: 2 additions & 1 deletion Tests/Tree.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
using System;
using Coplt.Union;
using Coplt.Union.Misc;

namespace Tests;

[Union]
[Union, UnionSymbol(IsReferenceType = MayBool.True)]
public partial class Tree
{
[UnionTemplate]
Expand Down
10 changes: 10 additions & 0 deletions Tests/Unions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,13 @@ public partial class Union10
{
public override string ToString() => "Fuck";
}

[Union]
public partial class Union11
{
[UnionTemplate]
private interface Template
{
void Foo((float, string) a, (float, string) b, (int, string) c);
}
}

0 comments on commit 055ff7f

Please sign in to comment.