From 37627e266e86945104d7642c633f305ab836bd2d Mon Sep 17 00:00:00 2001 From: Pavel Krymets Date: Tue, 19 Dec 2023 19:21:28 -0800 Subject: [PATCH] Support named services (#153) --- .../ContainerTests.cs | 33 + .../Mocks/ServiceImplementation.cs | 4 + .../Mocks/ServiceImplementationWithNamed.cs | 12 + .../MEDIContainerTests.cs | 43 +- src/Jab.Tests/DiagnosticsTest.cs | 118 ++- src/Jab.Tests/GeneratorAnalyzerVerifier.cs | 1 + src/Jab/ArrayServiceCallSite.cs | 4 +- src/Jab/Attributes.cs | 42 +- src/Jab/ConstructorCallSite.cs | 4 +- src/Jab/ContainerGenerator.cs | 190 +++-- src/Jab/DiagnosticDescriptors.cs | 31 +- src/Jab/ErrorCallSite.cs | 2 +- src/Jab/FactoryCallSite.cs | 2 +- src/Jab/GetServiceCallCandidate.cs | 4 +- src/Jab/KnownTypes.cs | 139 ++++ src/Jab/MemberCallSite.cs | 2 +- src/Jab/ScopeFactoryCallSite.cs | 2 +- src/Jab/ServiceCallSite.cs | 6 +- src/Jab/ServiceIdentity.cs | 18 + src/Jab/ServiceProviderBuilder.cs | 685 +++++++++--------- src/Jab/ServiceProviderCallSite.cs | 2 +- src/Jab/ServiceProviderDescription.cs | 12 +- src/Jab/ServiceProviderIsServiceCallSite.cs | 2 +- src/Jab/ServiceRegistration.cs | 11 +- 24 files changed, 950 insertions(+), 419 deletions(-) create mode 100644 src/Jab.FunctionalTests.Common/Mocks/ServiceImplementationWithNamed.cs create mode 100644 src/Jab/KnownTypes.cs create mode 100644 src/Jab/ServiceIdentity.cs diff --git a/src/Jab.FunctionalTests.Common/ContainerTests.cs b/src/Jab.FunctionalTests.Common/ContainerTests.cs index 62f472e..d95b3de 100644 --- a/src/Jab.FunctionalTests.Common/ContainerTests.cs +++ b/src/Jab.FunctionalTests.Common/ContainerTests.cs @@ -1241,6 +1241,39 @@ internal partial class SupportsInstancePropertyFactoriesOnModulesContainer { Func Instance = () => new ServiceImplementation(); } + + [Fact] + public void SupportsNamedServices() + { + SupportsNamedServicesContainer c = new(); + + var notNamed = c.GetService(); + Assert.IsType(notNamed); + + var named = c.GetService("Named"); + Assert.IsType(named); + + var onlyNamed = c.GetService("OnlyNamed"); + Assert.IsType(onlyNamed); + + var service = c.GetService>(); + Assert.IsType(service.InnerService); + Assert.Same(named, service.InnerService); + + var services = c.GetService>(); + var single = Assert.Single(services); + Assert.Same(notNamed, single); + } + + [ServiceProvider] + [Singleton(typeof(IService), typeof(ServiceImplementation))] + [Singleton(typeof(IService), typeof(ServiceImplementation), Name="Named")] + [Singleton(typeof(IService), typeof(ServiceImplementation2), Name="Named")] + [Singleton(typeof(IAnotherService), typeof(AnotherServiceImplementation), Name="OnlyNamed")] + [Singleton(typeof(ServiceImplementationWithNamed))] + internal partial class SupportsNamedServicesContainer + { + } } } diff --git a/src/Jab.FunctionalTests.Common/Mocks/ServiceImplementation.cs b/src/Jab.FunctionalTests.Common/Mocks/ServiceImplementation.cs index 837ad5a..0edd416 100644 --- a/src/Jab.FunctionalTests.Common/Mocks/ServiceImplementation.cs +++ b/src/Jab.FunctionalTests.Common/Mocks/ServiceImplementation.cs @@ -6,6 +6,10 @@ internal class ServiceImplementation : IService, IService1, IService2, IService3 { } + internal class ServiceImplementation2 : IService, IService1, IService2, IService3 + { + } + internal class ServiceImplementation : IService { public ServiceImplementation(T innerService) diff --git a/src/Jab.FunctionalTests.Common/Mocks/ServiceImplementationWithNamed.cs b/src/Jab.FunctionalTests.Common/Mocks/ServiceImplementationWithNamed.cs new file mode 100644 index 0000000..1714bc1 --- /dev/null +++ b/src/Jab.FunctionalTests.Common/Mocks/ServiceImplementationWithNamed.cs @@ -0,0 +1,12 @@ +using Jab; + +namespace JabTests; + +public class ServiceImplementationWithNamed: IService +{ + public T InnerService { get; } + public ServiceImplementationWithNamed([FromNamedServices("Named")] T innerService) + { + InnerService = innerService; + } +} \ No newline at end of file diff --git a/src/Jab.FunctionalTests.MEDI/MEDIContainerTests.cs b/src/Jab.FunctionalTests.MEDI/MEDIContainerTests.cs index f0e859e..12a9c0e 100644 --- a/src/Jab.FunctionalTests.MEDI/MEDIContainerTests.cs +++ b/src/Jab.FunctionalTests.MEDI/MEDIContainerTests.cs @@ -47,7 +47,7 @@ public void CanUseIsService() { CanUseIsServiceContainer c = new(); IServiceProviderIsService iss = c; - + Assert.True(iss.IsService(typeof(IServiceProvider))); Assert.True(iss.IsService(typeof(IServiceProviderIsService))); Assert.True(iss.IsService(typeof(IServiceScopeFactory))); @@ -65,7 +65,7 @@ internal partial class CanUseIsServiceContainer public void CanResolveIsService() { CanUseIsServiceContainer c = new(); - + Assert.True(c.GetService().IsService(typeof(IServiceProvider))); Assert.Same(c, c.CreateScope().GetService()); Assert.True(c.CreateScope().GetService().IsService(typeof(IServiceProvider))); @@ -76,5 +76,44 @@ internal partial class CanResolveIsServiceContainer { } #endif + +#if NET8_OR_GREATER + [Fact] + public void SupportsKeyedServices() + { + SupportsKeyedServicesContainer c = new(); + + Assert.IsAssignableFrom(c); + + Assert.NotNull(c.GetKeyedService("Key")); + Assert.NotNull(c.GetRequiredKeyedService("Key")); + + Assert.Null(c.GetKeyedService("Bla")); + Assert.Null(c.GetKeyedService("Bla")); + Assert.Throws(() => c.GetRequiredKeyedService("Bla")); + Assert.Throws(() => c.GetRequiredKeyedService("Bla")); + + var serviceWithKeyedParameter = c.GetService>(); + Assert.NotNull(serviceWithKeyedParameter); + Assert.NotNull(serviceWithKeyedParameter.InnerService); + } + + [ServiceProvider] + [Singleton(typeof(ServiceImplementation), Name="Key")] + [Singleton(typeof(ServiceWithKeyedParameter))] + internal partial class SupportsKeyedServicesContainer + { + } + + internal class ServiceWithKeyedParameter + { + public T InnerService { get; } + + public ServiceWithKeyedParameter([FromKeyedServices(typeof(string))] T innerService) + { + InnerService = innerService; + } + } +#endif } } diff --git a/src/Jab.Tests/DiagnosticsTest.cs b/src/Jab.Tests/DiagnosticsTest.cs index ba0501b..a113b98 100644 --- a/src/Jab.Tests/DiagnosticsTest.cs +++ b/src/Jab.Tests/DiagnosticsTest.cs @@ -110,6 +110,25 @@ await Verify.VerifyAnalyzerAsync(testCode, .WithArguments("IDependency", "Service")); } + + [Fact] + public async Task ProducesJAB0019WhenRequiredNamedDependencyNotFound() + { + string testCode = $@" +class Dependency {{ }} +class Service {{ public Service([FromNamedServices(""Named"")] Dependency dep) {{}} }} +[ServiceProvider] +[{{|#1:Transient(typeof(Service))|}}] +[Transient(typeof(Dependency))] +public partial class Container {{}} +"; + await Verify.VerifyAnalyzerAsync(testCode, + DiagnosticResult + .CompilerError("JAB0019") + .WithLocation(1) + .WithArguments("Dependency", "Named", "Service")); + } + [Fact] public async Task ProducesJAB0002WhenRequiredDependenciesNotFound() { @@ -204,9 +223,9 @@ public async Task ProducesJAB0008WhenCircularChainDetected() interface IService {{}} class FirstService {{ public FirstService(IService s) {{}} }} class Service : IService {{ public Service(AnotherService s) {{}} }} -class AnotherService {{ public AnotherService(IService s) {{}} }} +class AnotherService {{ public AnotherService({{|#1:IService|}} s) {{}} }} [ServiceProvider] -[{{|#1:Transient(typeof(FirstService))|}}] +[Transient(typeof(FirstService))] [Transient(typeof(IService), typeof(Service))] [Transient(typeof(AnotherService))] public partial class Container {{}} @@ -215,7 +234,7 @@ await Verify.VerifyAnalyzerAsync(testCode, DiagnosticResult .CompilerError("JAB0008") .WithLocation(1) - .WithArguments("FirstService", "IService", "FirstService -> IService -> Service -> AnotherService -> IService")); + .WithArguments("IService", "FirstService -> IService -> Service -> AnotherService -> IService")); } [Fact] @@ -249,21 +268,31 @@ await Verify.VerifyAnalyzerAsync(testCode, } [Fact] - public async Task ProducesJAB0010IfGetServiceCallTypeUnregistered() + public async Task ProducesJAB0010OrJAB0018IfGetServiceCallTypeUnregistered() { string testCode = $@" interface IService {{}} [ServiceProvider] public partial class Container {{ public T GetService() => default; - public static void Main() {{ var container = new Container(); {{|#1:container.GetService()|}}; }} + public T GetService(string name) => default; + public static void Main() {{ + var container = new Container(); + {{|#1:container.GetService()|}}; + {{|#2:container.GetService(""Named"")|}}; + }} }} "; await Verify.VerifyAnalyzerAsync(testCode, DiagnosticResult .CompilerError("JAB0010") .WithLocation(1) - .WithArguments("IService")); + .WithArguments("IService"), + + DiagnosticResult + .CompilerError("JAB0018") + .WithLocation(2) + .WithArguments("IService", "Named")); } [Fact] @@ -281,6 +310,81 @@ await Verify.VerifyAnalyzerAsync(testCode, .WithArguments("IService")); } + [Fact] + public async Task ProducesDiagnosticWhenServiceNameNotAlphanumeric() + { + string testCode = @" +public class Service {} +[ServiceProvider] +[{|#1:Singleton(typeof(Service), Name = """")|}] +[{|#2:Singleton(typeof(Service), Name = ""'"")|}] +[{|#3:Singleton(typeof(Service), Name = ""1a"")|}] +[Singleton(typeof(Service), Name = ""aA10"")] +public partial class Container {} +"; + await Verify.VerifyAnalyzerAsync(testCode, + DiagnosticResult + .CompilerError("JAB0015") + .WithLocation(1) + .WithArguments(""), + + DiagnosticResult + .CompilerError("JAB0015") + .WithLocation(2) + .WithArguments("'"), + + DiagnosticResult + .CompilerError("JAB0015") + .WithLocation(3) + .WithArguments("1a")); + } + + [Fact] + public async Task ProducesDiagnosticWhenBuiltInServicesRequestedAsNamed() + { + string testCode = @" +public class Service { + public Service( + [FromNamedServices(""A"")] {|#1:IServiceProvider|} sp + ) {} +} + +[ServiceProvider] +[Singleton(typeof(Service))] +public partial class Container {} +"; + await Verify.VerifyAnalyzerAsync(testCode, + DiagnosticResult + .CompilerError("JAB0016") + .WithLocation(1) + .WithArguments("System.IServiceProvider")); + } + + + [Fact] + public async Task ProducesDiagnosticWhenImplicitIEnumerableRequestedAsNamed() + { + string testCode = @" +public class Service1 {} +public class Service { + public Service( + [FromNamedServices(""A"")] {|#1:IEnumerable|} s, + IEnumerable ss + ) {} +} + +[ServiceProvider] +[Singleton(typeof(Service))] +[Singleton(typeof(Service1))] +public partial class Container {} +"; + await Verify.VerifyAnalyzerAsync(testCode, + DiagnosticResult + .CompilerError("JAB0017") + .WithLocation(1) + .WithArguments("System.Collections.Generic.IEnumerable")); + } + [Fact] public async Task ProducesJAB0013WhenNullableNonOptionalDependencyNotFound() { @@ -316,7 +420,7 @@ public partial class Container {{}} await Verify.VerifyAnalyzerAsync(testCode, DiagnosticResult .CompilerError("JAB0014") - .WithSeverity(DiagnosticSeverity.Warning) + .WithSeverity(DiagnosticSeverity.Info) .WithLocation(1) .WithArguments("IDependency?", "Service")); } diff --git a/src/Jab.Tests/GeneratorAnalyzerVerifier.cs b/src/Jab.Tests/GeneratorAnalyzerVerifier.cs index 7f322aa..a367ab4 100644 --- a/src/Jab.Tests/GeneratorAnalyzerVerifier.cs +++ b/src/Jab.Tests/GeneratorAnalyzerVerifier.cs @@ -14,6 +14,7 @@ public static Task VerifyAnalyzerAsync(string source, params DiagnosticResult[] { source = @" using System; +using System.Collections.Generic; using Jab; " + source; var test = new GeneratorAnalyzerTest diff --git a/src/Jab/ArrayServiceCallSite.cs b/src/Jab/ArrayServiceCallSite.cs index 46b7821..480e233 100644 --- a/src/Jab/ArrayServiceCallSite.cs +++ b/src/Jab/ArrayServiceCallSite.cs @@ -2,8 +2,8 @@ internal record ArrayServiceCallSite: ServiceCallSite { - public ArrayServiceCallSite(INamedTypeSymbol serviceType, INamedTypeSymbol implementationType, ITypeSymbol itemType, ServiceCallSite[] items, ServiceLifetime lifetime) - : base(serviceType, implementationType, lifetime, 0, false) + public ArrayServiceCallSite(ServiceIdentity identity, INamedTypeSymbol implementationType, ITypeSymbol itemType, ServiceCallSite[] items, ServiceLifetime lifetime) + : base(identity, implementationType, lifetime, false) { ItemType = itemType; Items = items; diff --git a/src/Jab/Attributes.cs b/src/Jab/Attributes.cs index fa831f7..d015414 100644 --- a/src/Jab/Attributes.cs +++ b/src/Jab/Attributes.cs @@ -60,6 +60,8 @@ class SingletonAttribute: Attribute { public Type ServiceType { get; } + public string? Name { get; set; } + public Type? ImplementationType { get; } public string? Instance { get; set; } @@ -88,6 +90,7 @@ public SingletonAttribute(Type serviceType, Type implementationType) class TransientAttribute : Attribute { public Type ServiceType { get; } + public string? Name { get; set; } public Type? ImplementationType { get; } @@ -115,6 +118,7 @@ public TransientAttribute(Type serviceType, Type implementationType) class ScopedAttribute : Attribute { public Type ServiceType { get; } + public string? Name { get; set; } public Type? ImplementationType { get; } @@ -132,6 +136,23 @@ public ScopedAttribute(Type serviceType, Type implementationType) } } + + [AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = true)] +#if JAB_ATTRIBUTES_PACKAGE + public +#else + internal +#endif + class FromNamedServicesAttribute : Attribute + { + public string? Name { get; set; } + + public FromNamedServicesAttribute(string name) + { + Name = name; + } + } + #if GENERIC_ATTRIBUTES [AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, AllowMultiple = true, Inherited = true)] #if JAB_ATTRIBUTES_PACKAGE @@ -250,13 +271,26 @@ interface IServiceProvider #else [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Jab", null)] internal +#endif + interface INamedServiceProvider + { + T GetService(string name); + } + +#if JAB_ATTRIBUTES_PACKAGE + public +#else + internal #endif static class JabHelpers { - public static InvalidOperationException CreateServiceNotFoundException() - { - return new InvalidOperationException($"Service Type {typeof(T)} not registered"); - } + public static InvalidOperationException CreateServiceNotFoundException(string? name = null) => + CreateServiceNotFoundException(typeof(T), name); + public static InvalidOperationException CreateServiceNotFoundException(Type type, string? name = null) => + new InvalidOperationException( + name != null ? + $"Service with type {type} and name {name} not registered" : + $"Service with type {type} not registered"); } } diff --git a/src/Jab/ConstructorCallSite.cs b/src/Jab/ConstructorCallSite.cs index aec3c3c..aa4f560 100644 --- a/src/Jab/ConstructorCallSite.cs +++ b/src/Jab/ConstructorCallSite.cs @@ -2,8 +2,8 @@ internal record ConstructorCallSite : ServiceCallSite { - public ConstructorCallSite(INamedTypeSymbol serviceType, INamedTypeSymbol implementationType, ServiceCallSite[] parameters, KeyValuePair[] optionalParameters, ServiceLifetime lifetime, int reverseIndex, bool? isDisposable) - : base(serviceType, implementationType, lifetime, reverseIndex, isDisposable) + public ConstructorCallSite(ServiceIdentity identity, INamedTypeSymbol implementationType, ServiceCallSite[] parameters, KeyValuePair[] optionalParameters, ServiceLifetime lifetime, int? reverseIndex, bool? isDisposable) + : base(identity, implementationType, lifetime, isDisposable) { Parameters = parameters; OptionalParameters = optionalParameters; diff --git a/src/Jab/ContainerGenerator.cs b/src/Jab/ContainerGenerator.cs index 693e2ff..a0657e4 100644 --- a/src/Jab/ContainerGenerator.cs +++ b/src/Jab/ContainerGenerator.cs @@ -23,7 +23,7 @@ private void GenerateCallSiteWithCache(CodeWriter codeWriter, string rootReferen if (serviceCallSite.Lifetime != ServiceLifetime.Transient) { - var cacheLocation = GetCacheLocation(serviceCallSite); + var cacheLocation = GetCacheLocation(serviceCallSite.Identity); codeWriter.Line($"if ({cacheLocation} == default)"); codeWriter.Line($"lock (this)"); using (codeWriter.Scope($"if ({cacheLocation} == default)")) @@ -55,11 +55,11 @@ private void GenerateCallSiteWithCache(CodeWriter codeWriter, string rootReferen } } - private void WriteResolutionCall(CodeWriter codeWriter, ServiceCallSite other, string reference) + private void WriteResolutionCall(CodeWriter codeWriter, ServiceIdentity other, string reference) { if (other.IsMainImplementation) { - codeWriter.Append($"{reference}.GetService<{other.ServiceType}>()"); + codeWriter.Append($"{reference}.GetService<{other.Type}>()"); } else { @@ -114,14 +114,14 @@ private void AppendParameters(CodeWriter codeWriter, ServiceCallSite[] parameter { foreach (var parameter in parameters) { - WriteResolutionCall(codeWriter, parameter, "this"); + WriteResolutionCall(codeWriter, parameter.Identity, "this"); codeWriter.AppendRaw(", "); } foreach (var pair in optionalParameters) { codeWriter.Append($"{pair.Key.Name}: "); - WriteResolutionCall(codeWriter, pair.Value, "this"); + WriteResolutionCall(codeWriter, pair.Value.Identity, "this"); codeWriter.AppendRaw(", "); } codeWriter.RemoveTrailingComma(); @@ -163,7 +163,7 @@ private void GenerateCallSite(CodeWriter codeWriter, string rootReference, Servi { foreach (var item in arrayServiceCallSite.Items) { - WriteResolutionCall(codeWriter, item, "this"); + WriteResolutionCall(codeWriter, item.Identity, "this"); w.LineRaw(", "); } } @@ -212,14 +212,14 @@ private void Execute(GeneratorContext context) foreach (var rootService in root.RootCallSites) { - var rootServiceType = rootService.ServiceType; - if (rootService.IsMainImplementation) + var rootServiceType = rootService.Identity.Type; + if (rootService.Identity.IsMainImplementation) { codeWriter.Append($"{rootServiceType} IServiceProvider<{rootServiceType}>.GetService()"); } else { - codeWriter.Append($"private {rootServiceType} {GetResolutionServiceName(rootService)}()"); + codeWriter.Append($"private {rootServiceType} {GetResolutionServiceName(rootService.Identity)}()"); } if (rootService.Lifetime == ServiceLifetime.Scoped) @@ -241,12 +241,10 @@ private void Execute(GeneratorContext context) codeWriter.Line(); } + WriteNamedServiceProvider(codeWriter, root); WriteServiceProvider(codeWriter, root); WriteDispose(codeWriter, root, isScoped: false); - - codeWriter.Line($"[DebuggerHidden]"); - codeWriter.Line($"public T GetService() => this is IServiceProvider provider ? provider.GetService() : throw CreateServiceNotFoundException();"); - codeWriter.Line(); + WritePublicGetServiceMethods(codeWriter); codeWriter.Line($"public Scope CreateScope() => new Scope(this);"); codeWriter.Line(); @@ -271,7 +269,7 @@ private void Execute(GeneratorContext context) { codeWriter.Line($" ||"); } - codeWriter.Append($"typeof({rootService.ServiceType}) == service"); + codeWriter.Append($"typeof({rootService.Identity.Type}) == service"); } if (first) { @@ -294,22 +292,20 @@ private void Execute(GeneratorContext context) } codeWriter.Line(); - codeWriter.Line($"[DebuggerHidden]"); - codeWriter.Line($"public T GetService() => this is IServiceProvider provider ? provider.GetService() : throw CreateServiceNotFoundException();"); - codeWriter.Line(); + WritePublicGetServiceMethods(codeWriter); foreach (var rootService in root.RootCallSites) { - var rootServiceType = rootService.ServiceType; + var rootServiceType = rootService.Identity.Type; - using (rootService.IsMainImplementation ? + using (rootService.Identity.IsMainImplementation ? codeWriter.Scope($"{rootServiceType} IServiceProvider<{rootServiceType}>.GetService()") : - codeWriter.Scope($"private {rootServiceType} {GetResolutionServiceName(rootService)}()")) + codeWriter.Scope($"private {rootServiceType} {GetResolutionServiceName(rootService.Identity)}()")) { if (rootService.Lifetime == ServiceLifetime.Singleton) { codeWriter.Append($"return "); - WriteResolutionCall(codeWriter, rootService, "_root"); + WriteResolutionCall(codeWriter, rootService.Identity, "_root"); codeWriter.Line($";"); } else @@ -324,6 +320,7 @@ private void Execute(GeneratorContext context) } WriteServiceProvider(codeWriter, root); + WriteNamedServiceProvider(codeWriter, root); if (root.KnownTypes.IServiceScopeType != null) { @@ -354,16 +351,45 @@ private void Execute(GeneratorContext context) } } + private IEnumerable> GroupNamedServices(ServiceProvider root) + { + return root.RootCallSites + .Where(static s => s.Identity.IsMainNamedImplementation) + .GroupBy(static s => s.Identity.Type, SymbolEqualityComparer.Default); + } + private void WriteNamedServiceProvider(CodeWriter codeWriter, ServiceProvider root) + { + foreach (var serviceGroup in GroupNamedServices(root)) + { + var groupType = serviceGroup.Key; + using (codeWriter.Scope($"{groupType} INamedServiceProvider<{groupType}>.GetService(string name)")) + { + using (codeWriter.Scope($"switch (name)")) + { + foreach (var callSite in serviceGroup) + { + codeWriter.Append($"case \"{callSite.Identity.Name}\": return "); + WriteResolutionCall(codeWriter, callSite.Identity, "this"); + codeWriter.Line($";"); + } + + codeWriter.Line($"default: throw CreateServiceNotFoundException<{groupType}>(name);"); + } + } + codeWriter.Line(); + } + } + private void WriteServiceProvider(CodeWriter codeWriter, ServiceProvider root) { using (codeWriter.Scope($"{typeof(object)}? {typeof(IServiceProvider)}.GetService({typeof(Type)} type)")) { foreach (var rootRootCallSite in root.RootCallSites) { - if (rootRootCallSite.IsMainImplementation) + if (rootRootCallSite.Identity.IsMainImplementation) { - codeWriter.Append($"if (type == typeof({rootRootCallSite.ServiceType})) return "); - WriteResolutionCall(codeWriter, rootRootCallSite, "this"); + codeWriter.Append($"if (type == typeof({rootRootCallSite.Identity.Type})) return "); + WriteResolutionCall(codeWriter, rootRootCallSite.Identity, "this"); codeWriter.Line($";"); } } @@ -372,6 +398,58 @@ private void WriteServiceProvider(CodeWriter codeWriter, ServiceProvider root) } codeWriter.Line(); + + WriteKeyedServiceProvider(codeWriter, root); + } + + + private void WriteKeyedServiceProvider(CodeWriter codeWriter, ServiceProvider root) + { + var iface = root.KnownTypes.IKeyedServiceProviderType; + if (iface == null) + { + return; + } + + using (codeWriter.Scope($"{typeof(object)}? {iface}.GetKeyedService({typeof(Type)} type, object? key)")) + { + foreach (var serviceGroup in GroupNamedServices(root)) + { + var serviceType = serviceGroup.Key; + using (codeWriter.Scope($"if (type == typeof({serviceType}))")) + { + using (codeWriter.Scope($"switch (key)")) + { + foreach (var callSite in serviceGroup) + { + codeWriter.Append($"case \"{callSite.Identity.Name}\": return "); + WriteResolutionCall(codeWriter, callSite.Identity, "this"); + codeWriter.Line($";"); + } + } + } + } + + codeWriter.Line($"return null;"); + } + + codeWriter.Line(); + + codeWriter.Line( + $"{typeof(object)} {iface}.GetRequiredKeyedService({typeof(Type)} type, object? key) => (({iface})this).GetKeyedService(type, key) ?? throw CreateServiceNotFoundException(type, key?.ToString());"); + + codeWriter.Line(); + } + + private void WritePublicGetServiceMethods(CodeWriter codeWriter) + { + codeWriter.Line($"[DebuggerHidden]"); + codeWriter.Line($"public T GetService() => this is IServiceProvider provider ? provider.GetService() : throw CreateServiceNotFoundException();"); + codeWriter.Line(); + + codeWriter.Line($"[DebuggerHidden]"); + codeWriter.Line($"public T GetService(string name) => this is INamedServiceProvider provider ? provider.GetService(name) : throw CreateServiceNotFoundException(name);"); + codeWriter.Line(); } private void WriteDispose(CodeWriter codeWriter, ServiceProvider root, bool isScoped) @@ -406,7 +484,7 @@ private void WriteDispose(CodeWriter codeWriter, ServiceProvider root, bool isSc (rootService.Lifetime == ServiceLifetime.Scoped && !isScoped) || rootService.Lifetime == ServiceLifetime.Transient) continue; - codeWriter.Line($"TryDispose({GetCacheLocation(rootService)});"); + codeWriter.Line($"TryDispose({GetCacheLocation(rootService.Identity)});"); } if (!isScoped) @@ -448,7 +526,7 @@ private void WriteDispose(CodeWriter codeWriter, ServiceProvider root, bool isSc (rootService.Lifetime == ServiceLifetime.Scoped && !isScoped) || rootService.Lifetime == ServiceLifetime.Transient) continue; - codeWriter.Line($"await TryDispose({GetCacheLocation(rootService)});"); + codeWriter.Line($"await TryDispose({GetCacheLocation(rootService.Identity)});"); } if (!isScoped) @@ -479,6 +557,11 @@ private static void WriteInterfaces(CodeWriter codeWriter, ServiceProvider root, codeWriter.Line($" {typeof(IServiceProvider)},"); + if (root.KnownTypes.IKeyedServiceProviderType != null) + { + codeWriter.Line($" {root.KnownTypes.IKeyedServiceProviderType},"); + } + if (!isScope && root.KnownTypes.IServiceScopeFactoryType != null) { codeWriter.Line($" {root.KnownTypes.IServiceScopeFactoryType},"); @@ -494,11 +577,23 @@ private static void WriteInterfaces(CodeWriter codeWriter, ServiceProvider root, codeWriter.Line($" {root.KnownTypes.IServiceScopeType},"); } + HashSet seenServices = new HashSet(SymbolEqualityComparer.Default); + HashSet seenNamedServices = new HashSet(SymbolEqualityComparer.Default); foreach (var serviceCallSite in root.RootCallSites) { - if (serviceCallSite.IsMainImplementation) + if (serviceCallSite.Identity.Name == null) { - codeWriter.Line($" IServiceProvider<{serviceCallSite.ServiceType}>,"); + if (seenServices.Add(serviceCallSite.Identity.Type)) + { + codeWriter.Line($" IServiceProvider<{serviceCallSite.Identity.Type}>,"); + } + } + else + { + if (seenNamedServices.Add(serviceCallSite.Identity.Type)) + { + codeWriter.Line($" INamedServiceProvider<{serviceCallSite.Identity.Type}>,"); + } } } @@ -514,32 +609,27 @@ private void WriteCacheLocations(ServiceProvider root, CodeWriter codeWriter, bo (rootService.Lifetime == ServiceLifetime.Scoped && !isScope) || rootService.Lifetime == ServiceLifetime.Transient) continue; - codeWriter.Line($"private {rootService.ImplementationType}? {GetCacheLocation(rootService)};"); + codeWriter.Line($"private {rootService.ImplementationType}? {GetCacheLocation(rootService.Identity)};"); } codeWriter.Line(); } - private string GetResolutionServiceName(ServiceCallSite serviceCallSite) + private string GetResolutionServiceName(ServiceIdentity identity) { - if (!serviceCallSite.IsMainImplementation) + if (!identity.IsMainImplementation) { - return $"Get{GetServiceExpandedName(serviceCallSite.ServiceType)}_{serviceCallSite.ReverseIndex}"; + return $"Get{GetServiceExpandedName(identity)}"; } throw new InvalidOperationException("Main implementation should be resolved via GetService call"); } - private string GetCacheLocation(ServiceCallSite serviceCallSite) + private string GetCacheLocation(ServiceIdentity identity) { - if (!serviceCallSite.IsMainImplementation) - { - return $"_{GetServiceExpandedName(serviceCallSite.ServiceType)}_{serviceCallSite.ReverseIndex}"; - } - - return $"_{GetServiceExpandedName(serviceCallSite.ServiceType)}"; + return $"_{GetServiceExpandedName(identity)}"; } - private string GetServiceExpandedName(ITypeSymbol serviceType) + private string GetServiceExpandedName(ServiceIdentity identity) { StringBuilder builder = new(); @@ -556,7 +646,19 @@ void Traverse(ITypeSymbol symbol) } } - Traverse(serviceType); + Traverse(identity.Type); + + if (identity.Name != null) + { + builder.Append("_"); + builder.Append(identity.Name); + } + + if (identity.ReverseIndex != null) + { + builder.Append("_"); + builder.Append(identity.ReverseIndex); + } return builder.ToString(); } @@ -593,6 +695,12 @@ public override void Initialize(AnalysisContext context) DiagnosticDescriptors.NoServiceTypeRegistered, DiagnosticDescriptors.ImplementationTypeAndFactoryNotAllowed, DiagnosticDescriptors.FactoryMemberMustBeAMethodOrHaveDelegateType, + DiagnosticDescriptors.ServiceNameMustBeAlphanumeric, + DiagnosticDescriptors.ImplicitIEnumerableNotNamed, + DiagnosticDescriptors.BuiltInServicesAreNotNamed, + DiagnosticDescriptors.NoServiceTypeAndNameRegistered, + DiagnosticDescriptors.NamedServiceRequiredToConstructNotRegistered, + DiagnosticDescriptors.OnlyStringKeysAreSupported, DiagnosticDescriptors.NullableServiceNotRegistered, DiagnosticDescriptors.NullableServiceRegistered, }.ToImmutableArray(); diff --git a/src/Jab/DiagnosticDescriptors.cs b/src/Jab/DiagnosticDescriptors.cs index 22bdd43..f32026d 100644 --- a/src/Jab/DiagnosticDescriptors.cs +++ b/src/Jab/DiagnosticDescriptors.cs @@ -32,7 +32,7 @@ internal static class DiagnosticDescriptors public static readonly DiagnosticDescriptor CyclicDependencyDetected = new("JAB0008", "A cyclic dependency detected when resolving a service", - "A cyclic dependency detected when resolving a service '{0}', cycle starts at service '{1}', dependency chain: '{2}'", "Usage", DiagnosticSeverity.Error, true); + "A cyclic dependency detected when resolving a service '{0}', dependency chain: '{1}'", "Usage", DiagnosticSeverity.Error, true); public static readonly DiagnosticDescriptor MissingServiceProviderAttribute = new("JAB0009", "A type contains service registrations but no ServiceProvider or ServiceProviderModule attribute", @@ -40,7 +40,7 @@ internal static class DiagnosticDescriptors public static readonly DiagnosticDescriptor NoServiceTypeRegistered = new("JAB0010", "The service registration not found", - "The service '{0}' is not registered", "Usage", DiagnosticSeverity.Error, true); + "The service type '{0}' is not registered", "Usage", DiagnosticSeverity.Error, true); public static readonly DiagnosticDescriptor ImplementationTypeAndFactoryNotAllowed = new("JAB0011", "Can't specify both the implementation type and factory/instance", @@ -50,12 +50,35 @@ internal static class DiagnosticDescriptors "The factory member has to be a method or have a delegate type", "The factory member '{0}' has to be a method of have a delegate type, for service '{1}'", "Usage", DiagnosticSeverity.Error, true); + public static readonly DiagnosticDescriptor ServiceNameMustBeAlphanumeric = new("JAB0015", + "Service name must be alphanumeric", + "Service name '{0}' must be non-empty, alphanumeric and start with a letter.", "Usage", DiagnosticSeverity.Error, true); + + public static readonly DiagnosticDescriptor BuiltInServicesAreNotNamed = new("JAB0016", + "Built-in provider services can not be named", + "Built-in service '{0}' can not be named", "Usage", DiagnosticSeverity.Error, true); + + public static readonly DiagnosticDescriptor ImplicitIEnumerableNotNamed = new("JAB0017", + "Implicit IEnumerable<> services can not be named", + "Implicit IEnumerable service '{0}' can not be named", "Usage", DiagnosticSeverity.Error, true); + + public static readonly DiagnosticDescriptor NoServiceTypeAndNameRegistered = new("JAB0018", + "The service registration not found", + "The service type '{0}' and name '{1}' is not registered", "Usage", DiagnosticSeverity.Error, true); + + public static readonly DiagnosticDescriptor NamedServiceRequiredToConstructNotRegistered = new("JAB0019", + "The named service registration not found", + "The service '{0}' with name '{1}' required to construct '{2}' is not registered", "Usage", DiagnosticSeverity.Error, true); + + public static readonly DiagnosticDescriptor OnlyStringKeysAreSupported = new("JAB0020", + "Only string service keys are supported", + "Service key '{0}' is not supported, only string keys are supported", "Usage", DiagnosticSeverity.Error, true); + public static readonly DiagnosticDescriptor NullableServiceNotRegistered = new("JAB0013", "Not registered nullable dependency without a default value", "The nullable service '{0}' requested to construct '{1}' is not registered. Add a default value to make the service reference optional", "Usage", DiagnosticSeverity.Error, true); public static readonly DiagnosticDescriptor NullableServiceRegistered = new("JAB0014", "Nullable dependency without a default value", - "'{0}' parameter to construct '{1}' will never be null when constructing using a service provider. Add a default value to make the service reference optional", "Usage", DiagnosticSeverity.Warning, true); - + "'{0}' parameter to construct '{1}' will never be null when constructing using a service provider. Add a default value to make the service reference optional", "Usage", DiagnosticSeverity.Info, true); } diff --git a/src/Jab/ErrorCallSite.cs b/src/Jab/ErrorCallSite.cs index faf4fd0..cd0d15b 100644 --- a/src/Jab/ErrorCallSite.cs +++ b/src/Jab/ErrorCallSite.cs @@ -2,7 +2,7 @@ namespace Jab; internal record ErrorCallSite : ServiceCallSite { - public ErrorCallSite(ITypeSymbol serviceType, params Diagnostic[] diagnostic) : base(serviceType, serviceType, ServiceLifetime.Transient, ReverseIndex: 0, IsDisposable: false) + public ErrorCallSite(ServiceIdentity identity, params Diagnostic[] diagnostic) : base(identity, identity.Type, ServiceLifetime.Transient, IsDisposable: false) { Diagnostic = diagnostic; } diff --git a/src/Jab/FactoryCallSite.cs b/src/Jab/FactoryCallSite.cs index 8a3929b..3a1411c 100644 --- a/src/Jab/FactoryCallSite.cs +++ b/src/Jab/FactoryCallSite.cs @@ -8,7 +8,7 @@ internal record FactoryCallSite : ServiceCallSite public ServiceCallSite[] Parameters { get; } public KeyValuePair[] OptionalParameters { get; } - public FactoryCallSite(INamedTypeSymbol serviceType, ISymbol member, MemberLocation memberLocation, ServiceCallSite[] parameters, KeyValuePair[] optionalParameters, ServiceLifetime lifetime, int reverseIndex, bool? isDisposable) : base(serviceType, serviceType, lifetime, reverseIndex, isDisposable) + public FactoryCallSite(ServiceIdentity identity, ISymbol member, MemberLocation memberLocation, ServiceCallSite[] parameters, KeyValuePair[] optionalParameters, ServiceLifetime lifetime, bool? isDisposable) : base(identity, identity.Type, lifetime, isDisposable) { Member = member; MemberLocation = memberLocation; diff --git a/src/Jab/GetServiceCallCandidate.cs b/src/Jab/GetServiceCallCandidate.cs index bf6a843..bde9885 100644 --- a/src/Jab/GetServiceCallCandidate.cs +++ b/src/Jab/GetServiceCallCandidate.cs @@ -4,12 +4,14 @@ internal struct GetServiceCallCandidate { public ITypeSymbol ProviderType { get; } public ITypeSymbol ServiceType { get; } + public string? ServiceName { get; } public Location? Location { get; } - public GetServiceCallCandidate(ITypeSymbol providerType, ITypeSymbol serviceType, Location? location) + public GetServiceCallCandidate(ITypeSymbol providerType, ITypeSymbol serviceType, string? serviceName, Location? location) { ProviderType = providerType; ServiceType = serviceType; + ServiceName = serviceName; Location = location; } } \ No newline at end of file diff --git a/src/Jab/KnownTypes.cs b/src/Jab/KnownTypes.cs new file mode 100644 index 0000000..9b6fce8 --- /dev/null +++ b/src/Jab/KnownTypes.cs @@ -0,0 +1,139 @@ +namespace Jab; + +internal class KnownTypes +{ + public const string JabAttributesAssemblyName = "Jab.Attributes"; + public const string TransientAttributeShortName = "Transient"; + public const string SingletonAttributeShortName = "Singleton"; + public const string ScopedAttributeShortName = "Scoped"; + public const string CompositionRootAttributeShortName = "ServiceProvider"; + public const string ServiceProviderModuleAttributeShortName = "ServiceProviderModule"; + public const string ImportAttributeShortName = "Import"; + public const string FromNamedServicesAttributeShortName = "FromNamedServices"; + + public const string TransientAttributeTypeName = $"{TransientAttributeShortName}Attribute"; + public const string SingletonAttributeTypeName = $"{SingletonAttributeShortName}Attribute"; + public const string ScopedAttributeTypeName = $"{ScopedAttributeShortName}Attribute"; + public const string CompositionRootAttributeTypeName = $"{CompositionRootAttributeShortName}Attribute"; + public const string ServiceProviderModuleAttributeTypeName = $"{ServiceProviderModuleAttributeShortName}Attribute"; + + public const string ImportAttributeTypeName = $"{ImportAttributeShortName}Attribute"; + public const string FromNamedServicesAttributeName = $"{FromNamedServicesAttributeShortName}Attribute"; + + public const string TransientAttributeMetadataName = $"Jab.{TransientAttributeTypeName}"; + public const string GenericTransientAttributeMetadataName = $"Jab.{TransientAttributeTypeName}`1"; + public const string Generic2TransientAttributeMetadataName = $"Jab.{TransientAttributeTypeName}`2"; + + public const string SingletonAttributeMetadataName = $"Jab.{SingletonAttributeTypeName}"; + public const string GenericSingletonAttributeMetadataName = $"Jab.{SingletonAttributeTypeName}`1"; + public const string Generic2SingletonAttributeMetadataName = $"Jab.{SingletonAttributeTypeName}`2"; + + + public const string ScopedAttributeMetadataName = $"Jab.{ScopedAttributeTypeName}"; + public const string GenericScopedAttributeMetadataName = $"Jab.{ScopedAttributeTypeName}`1"; + public const string Generic2ScopedAttributeMetadataName = $"Jab.{ScopedAttributeTypeName}`2"; + + public const string CompositionRootAttributeMetadataName = $"Jab.{CompositionRootAttributeTypeName}"; + public const string ServiceProviderModuleAttributeMetadataName = $"Jab.{ServiceProviderModuleAttributeTypeName}"; + + public const string ImportAttributeMetadataName = $"Jab.{ImportAttributeTypeName}"; + public const string GenericImportAttributeMetadataName = $"Jab.{ImportAttributeTypeName}`1"; + + public const string NameAttributePropertyName = "Name"; + public const string InstanceAttributePropertyName = "Instance"; + public const string FactoryAttributePropertyName = "Factory"; + public const string RootServicesAttributePropertyName = "RootServices"; + + private const string IAsyncDisposableMetadataName = "System.IAsyncDisposable"; + private const string IEnumerableMetadataName = "System.Collections.Generic.IEnumerable`1"; + private const string IServiceProviderMetadataName = "System.IServiceProvider"; + private const string IServiceScopeMetadataName = "Microsoft.Extensions.DependencyInjection.IServiceScope"; + private const string IKeyedServiceProviderMetadataName = "Microsoft.Extensions.DependencyInjection.IKeyedServiceProvider"; + private const string FromKeyedServicesAttributeMetadataName = "Microsoft.Extensions.DependencyInjection.FromKeyedServicesAttribute"; + private const string FromNamedServicesAttributeMetadataName = $"Jab.{FromNamedServicesAttributeName}"; + + private const string IServiceScopeFactoryMetadataName = + "Microsoft.Extensions.DependencyInjection.IServiceScopeFactory"; + + private const string IServiceProviderIsServiceMetadataName = + "Microsoft.Extensions.DependencyInjection.IServiceProviderIsService"; + + public INamedTypeSymbol IEnumerableType { get; } + public INamedTypeSymbol IServiceProviderType { get; } + public INamedTypeSymbol CompositionRootAttributeType { get; } + public INamedTypeSymbol TransientAttributeType { get; } + public INamedTypeSymbol? GenericTransientAttributeType { get; } + public INamedTypeSymbol? Generic2TransientAttributeType { get; } + + public INamedTypeSymbol SingletonAttribute { get; } + public INamedTypeSymbol? GenericSingletonAttribute { get; } + public INamedTypeSymbol? Generic2SingletonAttribute { get; } + + public INamedTypeSymbol ImportAttribute { get; } + public INamedTypeSymbol? GenericImportAttribute { get; } + + public INamedTypeSymbol ModuleAttribute { get; } + public INamedTypeSymbol ScopedAttribute { get; } + public INamedTypeSymbol? GenericScopedAttribute { get; } + public INamedTypeSymbol? Generic2ScopedAttribute { get; } + public INamedTypeSymbol? IAsyncDisposableType { get; } + public INamedTypeSymbol? IServiceScopeType { get; } + public INamedTypeSymbol? IServiceScopeFactoryType { get; } + public INamedTypeSymbol? IServiceProviderIsServiceType { get; } + public INamedTypeSymbol? IKeyedServiceProviderType { get; } + public INamedTypeSymbol? FromKeyedServicesAttribute { get; } + public INamedTypeSymbol? FromNamedServicesAttribute { get; } + + public KnownTypes(Compilation compilation, IModuleSymbol module, IAssemblySymbol assemblySymbol) + { + assemblySymbol = + module.ReferencedAssemblySymbols.FirstOrDefault( + s => s.Name == JabAttributesAssemblyName) + ?? assemblySymbol; + + static INamedTypeSymbol GetTypeByMetadataNameOrThrow(IAssemblySymbol assemblySymbol, + string fullyQualifiedMetadataName) => + assemblySymbol.GetTypeByMetadataName(fullyQualifiedMetadataName) + ?? throw new InvalidOperationException($"Type with metadata '{fullyQualifiedMetadataName}' not found"); + + static INamedTypeSymbol GetTypeFromCompilationByMetadataNameOrThrow(Compilation compilation, + string fullyQualifiedMetadataName) => + compilation.GetTypeByMetadataName(fullyQualifiedMetadataName) + ?? throw new InvalidOperationException($"Type with metadata '{fullyQualifiedMetadataName}' not found"); + + IEnumerableType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IEnumerableMetadataName); + IServiceProviderType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IServiceProviderMetadataName); + IServiceScopeType = compilation.GetTypeByMetadataName(IServiceScopeMetadataName); + IAsyncDisposableType = compilation.GetTypeByMetadataName(IAsyncDisposableMetadataName); + IServiceScopeFactoryType = compilation.GetTypeByMetadataName(IServiceScopeFactoryMetadataName); + IServiceProviderIsServiceType = compilation.GetTypeByMetadataName(IServiceProviderIsServiceMetadataName); + IKeyedServiceProviderType = compilation.GetTypeByMetadataName(IKeyedServiceProviderMetadataName); + FromKeyedServicesAttribute = compilation.GetTypeByMetadataName(FromKeyedServicesAttributeMetadataName); + + CompositionRootAttributeType = + GetTypeByMetadataNameOrThrow(assemblySymbol, CompositionRootAttributeMetadataName); + + TransientAttributeType = GetTypeByMetadataNameOrThrow(assemblySymbol, TransientAttributeMetadataName); + GenericTransientAttributeType = assemblySymbol.GetTypeByMetadataName(GenericTransientAttributeMetadataName); + Generic2TransientAttributeType = assemblySymbol.GetTypeByMetadataName(Generic2TransientAttributeMetadataName); + + SingletonAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, SingletonAttributeMetadataName); + GenericSingletonAttribute = assemblySymbol.GetTypeByMetadataName(GenericSingletonAttributeMetadataName); + Generic2SingletonAttribute = assemblySymbol.GetTypeByMetadataName(Generic2SingletonAttributeMetadataName); + + ScopedAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, ScopedAttributeMetadataName); + GenericScopedAttribute = assemblySymbol.GetTypeByMetadataName(GenericScopedAttributeMetadataName); + Generic2ScopedAttribute = assemblySymbol.GetTypeByMetadataName(Generic2ScopedAttributeMetadataName); + + ImportAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, ImportAttributeMetadataName); + GenericImportAttribute = assemblySymbol.GetTypeByMetadataName(GenericImportAttributeMetadataName); + + ModuleAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, ServiceProviderModuleAttributeMetadataName); + FromNamedServicesAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, FromNamedServicesAttributeMetadataName); + } + + public static bool HasKnownTypes(IModuleSymbol sourceModule) + { + return sourceModule.ReferencedAssemblySymbols.Any(s => s.Name == JabAttributesAssemblyName); + } +} diff --git a/src/Jab/MemberCallSite.cs b/src/Jab/MemberCallSite.cs index a0e6bfa..2270dac 100644 --- a/src/Jab/MemberCallSite.cs +++ b/src/Jab/MemberCallSite.cs @@ -5,7 +5,7 @@ internal record MemberCallSite : ServiceCallSite public ISymbol Member { get; } public MemberLocation MemberLocation { get; set; } - public MemberCallSite(INamedTypeSymbol serviceType, ISymbol member, MemberLocation memberLocation, ServiceLifetime lifetime, int reverseIndex, bool? isDisposable) : base(serviceType, serviceType, lifetime, reverseIndex, isDisposable) + public MemberCallSite(ServiceIdentity identity, ISymbol member, MemberLocation memberLocation, ServiceLifetime lifetime, bool? isDisposable) : base(identity, identity.Type, lifetime, isDisposable) { Member = member; MemberLocation = memberLocation; diff --git a/src/Jab/ScopeFactoryCallSite.cs b/src/Jab/ScopeFactoryCallSite.cs index f76d2f7..13ac64a 100644 --- a/src/Jab/ScopeFactoryCallSite.cs +++ b/src/Jab/ScopeFactoryCallSite.cs @@ -2,7 +2,7 @@ namespace Jab; internal record ScopeFactoryCallSite: ServiceCallSite { - public ScopeFactoryCallSite(ITypeSymbol serviceType) : base(serviceType, serviceType, ServiceLifetime.Transient, 0, false) + public ScopeFactoryCallSite(ITypeSymbol serviceType) : base(new ServiceIdentity(serviceType, null, null), serviceType, ServiceLifetime.Transient, false) { } } \ No newline at end of file diff --git a/src/Jab/ServiceCallSite.cs b/src/Jab/ServiceCallSite.cs index d5ce5db..7d80412 100644 --- a/src/Jab/ServiceCallSite.cs +++ b/src/Jab/ServiceCallSite.cs @@ -1,11 +1,9 @@ namespace Jab; -internal abstract record ServiceCallSite(ITypeSymbol ServiceType, ITypeSymbol ImplementationType, ServiceLifetime Lifetime, int ReverseIndex, bool? IsDisposable) +internal abstract record ServiceCallSite(ServiceIdentity Identity, ITypeSymbol ImplementationType, ServiceLifetime Lifetime, bool? IsDisposable) { - public ITypeSymbol ServiceType { get; } = ServiceType; + public ServiceIdentity Identity { get; } = Identity; public ITypeSymbol ImplementationType { get; } = ImplementationType; public ServiceLifetime Lifetime { get; } = Lifetime; - public int ReverseIndex { get; } = ReverseIndex; public bool? IsDisposable { get; } = IsDisposable; - public bool IsMainImplementation => ReverseIndex == 0; } \ No newline at end of file diff --git a/src/Jab/ServiceIdentity.cs b/src/Jab/ServiceIdentity.cs new file mode 100644 index 0000000..d9fa3fd --- /dev/null +++ b/src/Jab/ServiceIdentity.cs @@ -0,0 +1,18 @@ +namespace Jab; + +public readonly record struct ServiceIdentity +{ + public ITypeSymbol Type { get; } + public string? Name { get; } + public int? ReverseIndex { get; } + + public ServiceIdentity(ITypeSymbol serviceType, string? name, int? reverseIndex) + { + Type = serviceType; + Name = name; + ReverseIndex = reverseIndex == 0 ? null : reverseIndex; + } + + public bool IsMainImplementation => Name == null && ReverseIndex == null; + public bool IsMainNamedImplementation => Name != null && ReverseIndex == null; +} \ No newline at end of file diff --git a/src/Jab/ServiceProviderBuilder.cs b/src/Jab/ServiceProviderBuilder.cs index 0eaeebb..3c7cb24 100644 --- a/src/Jab/ServiceProviderBuilder.cs +++ b/src/Jab/ServiceProviderBuilder.cs @@ -1,140 +1,26 @@ namespace Jab; -internal class KnownTypes -{ - public const string JabAttributesAssemblyName = "Jab.Attributes"; - public const string TransientAttributeShortName = "Transient"; - public const string SingletonAttributeShortName = "Singleton"; - public const string ScopedAttributeShortName = "Scoped"; - public const string CompositionRootAttributeShortName = "ServiceProvider"; - public const string ServiceProviderModuleAttributeShortName = "ServiceProviderModule"; - public const string ImportAttributeShortName = "Import"; - - public const string TransientAttributeTypeName = $"{TransientAttributeShortName}Attribute"; - public const string SingletonAttributeTypeName = $"{SingletonAttributeShortName}Attribute"; - public const string ScopedAttributeTypeName = $"{ScopedAttributeShortName}Attribute"; - public const string CompositionRootAttributeTypeName = $"{CompositionRootAttributeShortName}Attribute"; - public const string ServiceProviderModuleAttributeTypeName = $"{ServiceProviderModuleAttributeShortName}Attribute"; - - public const string ImportAttributeTypeName = $"{ImportAttributeShortName}Attribute"; - - public const string TransientAttributeMetadataName = $"Jab.{TransientAttributeTypeName}"; - public const string GenericTransientAttributeMetadataName = $"Jab.{TransientAttributeTypeName}`1"; - public const string Generic2TransientAttributeMetadataName = $"Jab.{TransientAttributeTypeName}`2"; - - public const string SingletonAttributeMetadataName = $"Jab.{SingletonAttributeTypeName}"; - public const string GenericSingletonAttributeMetadataName = $"Jab.{SingletonAttributeTypeName}`1"; - public const string Generic2SingletonAttributeMetadataName = $"Jab.{SingletonAttributeTypeName}`2"; - - - public const string ScopedAttributeMetadataName = $"Jab.{ScopedAttributeTypeName}"; - public const string GenericScopedAttributeMetadataName = $"Jab.{ScopedAttributeTypeName}`1"; - public const string Generic2ScopedAttributeMetadataName = $"Jab.{ScopedAttributeTypeName}`2"; - - public const string CompositionRootAttributeMetadataName = $"Jab.{CompositionRootAttributeTypeName}"; - public const string ServiceProviderModuleAttributeMetadataName = $"Jab.{ServiceProviderModuleAttributeTypeName}"; - - public const string ImportAttributeMetadataName = $"Jab.{ImportAttributeTypeName}"; - public const string GenericImportAttributeMetadataName = $"Jab.{ImportAttributeTypeName}`1"; - - public const string InstanceAttributePropertyName = "Instance"; - public const string FactoryAttributePropertyName = "Factory"; - public const string RootServicesAttributePropertyName = "RootServices"; - - private const string IAsyncDisposableMetadataName = "System.IAsyncDisposable"; - private const string IEnumerableMetadataName = "System.Collections.Generic.IEnumerable`1"; - private const string IServiceProviderMetadataName = "System.IServiceProvider"; - private const string IServiceScopeMetadataName = "Microsoft.Extensions.DependencyInjection.IServiceScope"; - - private const string IServiceScopeFactoryMetadataName = - "Microsoft.Extensions.DependencyInjection.IServiceScopeFactory"; - - private const string IServiceProviderIsServiceMetadataName = - "Microsoft.Extensions.DependencyInjection.IServiceProviderIsService"; - - public INamedTypeSymbol IEnumerableType { get; } - public INamedTypeSymbol IServiceProviderType { get; } - public INamedTypeSymbol CompositionRootAttributeType { get; } - public INamedTypeSymbol TransientAttributeType { get; } - public INamedTypeSymbol? GenericTransientAttributeType { get; } - public INamedTypeSymbol? Generic2TransientAttributeType { get; } - - public INamedTypeSymbol SingletonAttribute { get; } - public INamedTypeSymbol? GenericSingletonAttribute { get; } - public INamedTypeSymbol? Generic2SingletonAttribute { get; } - - public INamedTypeSymbol ImportAttribute { get; } - public INamedTypeSymbol? GenericImportAttribute { get; } - - public INamedTypeSymbol ModuleAttribute { get; } - public INamedTypeSymbol ScopedAttribute { get; } - public INamedTypeSymbol? GenericScopedAttribute { get; } - public INamedTypeSymbol? Generic2ScopedAttribute { get; } - public INamedTypeSymbol? IAsyncDisposableType { get; } - public INamedTypeSymbol? IServiceScopeType { get; } - public INamedTypeSymbol? IServiceScopeFactoryType { get; } - public INamedTypeSymbol? IServiceProviderIsServiceType { get; } - - public KnownTypes(Compilation compilation, IAssemblySymbol assemblySymbol) - { - static INamedTypeSymbol GetTypeByMetadataNameOrThrow(IAssemblySymbol assemblySymbol, - string fullyQualifiedMetadataName) => - assemblySymbol.GetTypeByMetadataName(fullyQualifiedMetadataName) - ?? throw new InvalidOperationException($"Type with metadata '{fullyQualifiedMetadataName}' not found"); - - static INamedTypeSymbol GetTypeFromCompilationByMetadataNameOrThrow(Compilation compilation, - string fullyQualifiedMetadataName) => - compilation.GetTypeByMetadataName(fullyQualifiedMetadataName) - ?? throw new InvalidOperationException($"Type with metadata '{fullyQualifiedMetadataName}' not found"); - - IEnumerableType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IEnumerableMetadataName); - IServiceProviderType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IServiceProviderMetadataName); - IServiceScopeType = compilation.GetTypeByMetadataName(IServiceScopeMetadataName); - IAsyncDisposableType = compilation.GetTypeByMetadataName(IAsyncDisposableMetadataName); - IServiceScopeFactoryType = compilation.GetTypeByMetadataName(IServiceScopeFactoryMetadataName); - IServiceProviderIsServiceType = compilation.GetTypeByMetadataName(IServiceProviderIsServiceMetadataName); - - CompositionRootAttributeType = - GetTypeByMetadataNameOrThrow(assemblySymbol, CompositionRootAttributeMetadataName); - - TransientAttributeType = GetTypeByMetadataNameOrThrow(assemblySymbol, TransientAttributeMetadataName); - GenericTransientAttributeType = assemblySymbol.GetTypeByMetadataName(GenericTransientAttributeMetadataName); - Generic2TransientAttributeType = assemblySymbol.GetTypeByMetadataName(Generic2TransientAttributeMetadataName); - - SingletonAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, SingletonAttributeMetadataName); - GenericSingletonAttribute = assemblySymbol.GetTypeByMetadataName(GenericSingletonAttributeMetadataName); - Generic2SingletonAttribute = assemblySymbol.GetTypeByMetadataName(Generic2SingletonAttributeMetadataName); - - ScopedAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, ScopedAttributeMetadataName); - GenericScopedAttribute = assemblySymbol.GetTypeByMetadataName(GenericScopedAttributeMetadataName); - Generic2ScopedAttribute = assemblySymbol.GetTypeByMetadataName(Generic2ScopedAttributeMetadataName); - - ImportAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, ImportAttributeMetadataName); - GenericImportAttribute = assemblySymbol.GetTypeByMetadataName(GenericImportAttributeMetadataName); - - ModuleAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, ServiceProviderModuleAttributeMetadataName); - } - - public static bool HasKnownTypes(IModuleSymbol sourceModule) - { - return sourceModule.ReferencedAssemblySymbols.Any(s => s.Name == JabAttributesAssemblyName); - } -} - internal class ServiceProviderBuilder { private readonly GeneratorContext _context; private readonly KnownTypes _knownTypes; + private readonly ServiceProviderCallSite _serviceProviderCallsite; + private readonly ScopeFactoryCallSite? _scopeFactoryCallSite; + private readonly ServiceProviderIsServiceCallSite? _serviceProviderIsServiceCallSite; public ServiceProviderBuilder(GeneratorContext context) { _context = context; - - var assemblySymbol = - context.Compilation.SourceModule.ReferencedAssemblySymbols.FirstOrDefault( - s => s.Name == KnownTypes.JabAttributesAssemblyName) - ?? context.Compilation.Assembly; - _knownTypes = new KnownTypes(context.Compilation, assemblySymbol); + _knownTypes = new KnownTypes(context.Compilation, context.Compilation.SourceModule, context.Compilation.Assembly); + _serviceProviderCallsite = new ServiceProviderCallSite(_knownTypes.IServiceProviderType); + if (_knownTypes.IServiceScopeFactoryType != null) + { + _scopeFactoryCallSite = new ScopeFactoryCallSite(_knownTypes.IServiceScopeFactoryType); + } + if (_knownTypes.IServiceProviderIsServiceType != null) + { + _serviceProviderIsServiceCallSite = new ServiceProviderIsServiceCallSite(_knownTypes.IServiceProviderIsServiceType); + } } public ServiceProvider[] BuildRoots() @@ -146,14 +32,18 @@ public ServiceProvider[] BuildRoots() var semanticModel = _context.Compilation.GetSemanticModel(candidateGetServiceCallGroup.Key); foreach (var candidateGetServiceCall in candidateGetServiceCallGroup) { - if (candidateGetServiceCall.Expression is MemberAccessExpressionSyntax + if (candidateGetServiceCall is { - Name: GenericNameSyntax + Expression: MemberAccessExpressionSyntax { - IsUnboundGenericName: false, - TypeArgumentList: { Arguments: { Count: 1 } arguments } - } - } memberAccessExpression) + Name: GenericNameSyntax + { + IsUnboundGenericName: false, + TypeArgumentList: { Arguments: { Count: 1 } arguments } + } + } memberAccessExpression + } + ) { var containerTypeInfo = semanticModel.GetTypeInfo(memberAccessExpression.Expression); var serviceInfo = semanticModel.GetSymbolInfo(arguments[0]); @@ -162,7 +52,25 @@ serviceInfo.Symbol is ITypeSymbol serviceType && serviceType.TypeKind != TypeKind.TypeParameter ) { - getServiceCallCandidates.Add(new GetServiceCallCandidate(containerTypeInfo.Type, serviceType, + string? serviceName = null; + var invocationArguments = candidateGetServiceCall.ArgumentList.Arguments; + if (invocationArguments.Count == 1) + { + if (invocationArguments[0].Expression is LiteralExpressionSyntax { } literal && + literal.Token.IsKind(SyntaxKind.StringLiteralToken)) + { + serviceName = literal.Token.ValueText; + } + else + { + // Service name is dynamic, can't do anything + continue; + } + } + getServiceCallCandidates.Add(new GetServiceCallCandidate( + containerTypeInfo.Type, + serviceType, + serviceName, candidateGetServiceCall.GetLocation())); } } @@ -203,7 +111,7 @@ private bool TryCreateCompositionRoot(ITypeSymbol typeSymbol, EmitTypeDiagnostics(typeSymbol); - Dictionary callSites = new(); + CallSiteCache callSites = new(); foreach (var registration in description.ServiceRegistrations) { if (registration.ServiceType.IsUnboundGenericType) @@ -211,8 +119,10 @@ private bool TryCreateCompositionRoot(ITypeSymbol typeSymbol, continue; } - GetCallSite(registration.ServiceType, - new ServiceResolutionContext(description, callSites, registration.ServiceType, registration.Location)); + GetCallSite( + registration.ServiceType, + registration.Name, + new ServiceResolutionContext(description, callSites, registration.Location)); } List rootServices = new(description.RootServices); @@ -229,8 +139,10 @@ private bool TryCreateCompositionRoot(ITypeSymbol typeSymbol, foreach (var rootService in rootServices) { var serviceType = rootService.Service; - var callSite = GetCallSite(serviceType, - new ServiceResolutionContext(description, callSites, serviceType, description.Location)); + var callSite = GetCallSite( + serviceType, + null, + new ServiceResolutionContext(description, callSites, description.Location)); if (callSite == null) { _context.ReportDiagnostic(Diagnostic.Create( @@ -246,21 +158,34 @@ private bool TryCreateCompositionRoot(ITypeSymbol typeSymbol, if (SymbolEqualityComparer.Default.Equals(getServiceCallCandidate.ProviderType, typeSymbol)) { var serviceType = getServiceCallCandidate.ServiceType; - var callSite = GetCallSite(serviceType, - new ServiceResolutionContext(description, callSites, serviceType, - getServiceCallCandidate.Location)); + var callSite = GetCallSite( + serviceType, + getServiceCallCandidate.ServiceName, + new ServiceResolutionContext(description, callSites, getServiceCallCandidate.Location)); if (callSite == null) { - _context.ReportDiagnostic(Diagnostic.Create( - DiagnosticDescriptors.NoServiceTypeRegistered, - getServiceCallCandidate.Location, - serviceType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat) - )); + if (getServiceCallCandidate.ServiceName == null) + { + _context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.NoServiceTypeRegistered, + getServiceCallCandidate.Location, + serviceType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat) + )); + } + else + { + _context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.NoServiceTypeAndNameRegistered, + getServiceCallCandidate.Location, + serviceType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat), + getServiceCallCandidate.ServiceName + )); + } } } } - compositionRoot = new ServiceProvider(typeSymbol, callSites.Values.ToArray(), _knownTypes); + compositionRoot = new ServiceProvider(typeSymbol, callSites.GetRootCallSites(), _knownTypes); return true; } @@ -282,62 +207,77 @@ private void EmitTypeDiagnostics(ITypeSymbol typeSymbol) private ServiceCallSite? GetCallSite( ITypeSymbol serviceType, + string? name, ServiceResolutionContext context) { - if (context.CallSiteCache.TryGetValue(new CallSiteCacheKey(serviceType), out var cachedCallSite)) - { - return cachedCallSite; - } - if (!context.TryAdd(serviceType)) { var diagnostic = Diagnostic.Create(DiagnosticDescriptors.CyclicDependencyDetected, context.RequestLocation, - context.RequestService.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat), serviceType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat), context.ToString(serviceType)); _context.ReportDiagnostic( diagnostic); - return new ErrorCallSite(serviceType, diagnostic); + return new ErrorCallSite(new ServiceIdentity(serviceType, name, null), diagnostic); } try { - return TryCreateSpecial(serviceType, context) ?? - TryCreateExact(serviceType, context) ?? - TryCreateEnumerable(serviceType, context) ?? - TryCreateGeneric(serviceType, context); + return TryCreateSpecial(serviceType, name, context) ?? + TryCreateExact(serviceType, name, null, context) ?? + TryCreateEnumerable(serviceType, name, context) ?? + TryCreateGeneric(serviceType, name, context); } - catch + finally { context.Remove(serviceType); - throw; } } - private ServiceCallSite? TryCreateSpecial(ITypeSymbol serviceType, ServiceResolutionContext context) + private ServiceCallSite? TryCreateSpecial(ITypeSymbol serviceType, string? name, ServiceResolutionContext context) { - if (SymbolEqualityComparer.Default.Equals(serviceType, _knownTypes.IServiceProviderType)) + ErrorCallSite? CheckNotNamed(ServiceIdentity identity) { - var callSite = new ServiceProviderCallSite(serviceType); - context.CallSiteCache[new CallSiteCacheKey(serviceType)] = callSite; + if (name == null) return null; + + var diagnostic = Diagnostic.Create(DiagnosticDescriptors.BuiltInServicesAreNotNamed, + context.RequestLocation, + serviceType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)); + + _context.ReportDiagnostic( + diagnostic); + + return new ErrorCallSite(identity); + } + + ServiceCallSite BuiltInCallSite(ServiceCallSite callSite) + { + if (CheckNotNamed(callSite.Identity) is { } error) + { + return error; + } + if (!context.CallSiteCache.TryGet(callSite.Identity, out _)) + { + context.CallSiteCache.Add(callSite); + } return callSite; } + if (SymbolEqualityComparer.Default.Equals(serviceType, _knownTypes.IServiceProviderType)) + { + return BuiltInCallSite(_serviceProviderCallsite); + } + if (SymbolEqualityComparer.Default.Equals(serviceType, _knownTypes.IServiceScopeFactoryType)) { - var callSite = new ScopeFactoryCallSite(serviceType); - context.CallSiteCache[new CallSiteCacheKey(serviceType)] = callSite; - return callSite; + return BuiltInCallSite(_scopeFactoryCallSite!); } if (SymbolEqualityComparer.Default.Equals(serviceType, _knownTypes.IServiceProviderIsServiceType)) { - var callSite = new ServiceProviderIsServiceCallSite(serviceType); - context.CallSiteCache[new CallSiteCacheKey(serviceType)] = callSite; - return callSite; + return BuiltInCallSite(_serviceProviderIsServiceCallSite!); } return null; @@ -345,6 +285,7 @@ private void EmitTypeDiagnostics(ITypeSymbol typeSymbol) private ServiceCallSite? TryCreateGeneric( ITypeSymbol serviceType, + string? name, ServiceResolutionContext context) { if (serviceType is INamedTypeSymbol { IsGenericType: true }) @@ -353,7 +294,7 @@ private void EmitTypeDiagnostics(ITypeSymbol typeSymbol) { var registration = context.ProviderDescription.ServiceRegistrations[i]; - var callSite = TryCreateGeneric(serviceType, registration, 0, context); + var callSite = TryMatchGeneric(serviceType, name, null, registration, context); if (callSite != null) { return callSite; @@ -364,53 +305,61 @@ private void EmitTypeDiagnostics(ITypeSymbol typeSymbol) return null; } - private ServiceCallSite? TryCreateGeneric( + private ServiceCallSite? TryMatchGeneric( ITypeSymbol serviceType, + string? name, + int? reverseIndex, ServiceRegistration registration, - int reverseIndex, ServiceResolutionContext context) { - if (serviceType is INamedTypeSymbol { IsGenericType: true } genericType && + if (registration.Name == name && + serviceType is INamedTypeSymbol { IsGenericType: true } genericType && registration.ServiceType.IsUnboundGenericType && SymbolEqualityComparer.Default.Equals(registration.ServiceType.ConstructedFrom, genericType.ConstructedFrom)) { + var identity = new ServiceIdentity(serviceType, name, reverseIndex); + if (context.CallSiteCache.TryGet(identity, out var callSite)) + { + return callSite; + } + // TODO: This can use better error reporting if (registration.FactoryMember is IMethodSymbol factoryMethod) { var constructedFactoryMethod = factoryMethod.ConstructedFrom.Construct(genericType.TypeArguments, genericType.TypeArgumentNullableAnnotations); - var callSite = CreateFactoryCallSite( + callSite = CreateFactoryCallSite( + identity, genericType, - null, registration.Lifetime, registration.Location, memberLocation: registration.MemberLocation, factoryMember: constructedFactoryMethod, - reverseIndex: reverseIndex, context: context); - - context.CallSiteCache[new CallSiteCacheKey(reverseIndex, serviceType)] = callSite; - - return callSite; } else if (registration.ImplementationType != null) { var implementationType = registration.ImplementationType.ConstructedFrom.Construct(genericType.TypeArguments, genericType.TypeArgumentNullableAnnotations); - return CreateConstructorCallSite(registration, genericType, implementationType, reverseIndex, context); + + callSite = CreateConstructorCallSite(identity, registration, implementationType, context); } else { throw new InvalidOperationException($"Can't construct generic callsite for {serviceType}"); } + + context.CallSiteCache.Add(callSite); + + return callSite; } return null; } - private ServiceCallSite? TryCreateEnumerable(ITypeSymbol serviceType, ServiceResolutionContext context) + private ServiceCallSite? TryCreateEnumerable(ITypeSymbol serviceType, string? name, ServiceResolutionContext context) { static ServiceLifetime GetCommonLifetime(IEnumerable callSites) { @@ -430,6 +379,24 @@ static ServiceLifetime GetCommonLifetime(IEnumerable callSites) if (serviceType is INamedTypeSymbol { IsGenericType: true } genericType && SymbolEqualityComparer.Default.Equals(genericType.ConstructedFrom, _knownTypes.IEnumerableType)) { + var identity = new ServiceIdentity(genericType, null, null); + + if (name != null) + { + var diagnostic = Diagnostic.Create(DiagnosticDescriptors.ImplicitIEnumerableNotNamed, + context.RequestLocation, + serviceType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)); + + _context.ReportDiagnostic( + diagnostic); + return new ErrorCallSite(identity, diagnostic); + } + + if (context.CallSiteCache.TryGet(identity, out var callSite)) + { + return callSite; + } + var enumerableService = genericType.TypeArguments[0]; var items = new List(); int reverseIndex = 0; @@ -437,8 +404,8 @@ static ServiceLifetime GetCommonLifetime(IEnumerable callSites) { var registration = context.ProviderDescription.ServiceRegistrations[i]; - var itemCallSite = TryCreateGeneric(enumerableService, registration, reverseIndex, context) ?? - TryCreateExact(registration, enumerableService, reverseIndex, context); + var itemCallSite = TryMatchGeneric(enumerableService, null, reverseIndex, registration, context) ?? + TryMatchExact(enumerableService, null, reverseIndex, registration, context); if (itemCallSite != null) { reverseIndex++; @@ -448,15 +415,16 @@ static ServiceLifetime GetCommonLifetime(IEnumerable callSites) var serviceCallSites = items.ToArray(); Array.Reverse(serviceCallSites); - var callSite = new ArrayServiceCallSite( - genericType, + Debug.Assert(name == null); + callSite = new ArrayServiceCallSite( + identity, genericType, enumerableService, serviceCallSites, // Pick a most common lifetime GetCommonLifetime(items)); - context.CallSiteCache[new CallSiteCacheKey(reverseIndex, serviceType)] = callSite; + context.CallSiteCache.Add(callSite); return callSite; } @@ -464,94 +432,87 @@ static ServiceLifetime GetCommonLifetime(IEnumerable callSites) return null; } - private ServiceCallSite? TryCreateExact(ITypeSymbol serviceType, ServiceResolutionContext context) + private ServiceCallSite? TryCreateExact( + ITypeSymbol serviceType, + string? name, + int? reverseIndex, + ServiceResolutionContext context) { - if (context.ProviderDescription.ServiceRegistrationsLookup.TryGetValue(serviceType, out var registration)) + if (!context.ProviderDescription.ServiceRegistrationsLookup.TryGetValue(serviceType, out var registrations)) { - return CreateCallSite(registration, reverseIndex: 0, context: context); + return null; } - return null; - } - - private ServiceCallSite? TryCreateExact(ServiceRegistration registration, ITypeSymbol serviceType, int reverseIndex, - ServiceResolutionContext context) - { - if (SymbolEqualityComparer.Default.Equals(registration.ServiceType, serviceType)) + for (var index = registrations.Count - 1; index >= 0; index--) { - return CreateCallSite(registration, reverseIndex: reverseIndex, context: context); + var callSite = TryMatchExact(serviceType, name, reverseIndex, registrations[index], context: context); + if (callSite != null) + { + return callSite; + } } return null; } - private ServiceCallSite CreateCallSite( + private ServiceCallSite? TryMatchExact( + ITypeSymbol serviceType, + string? name, + int? reverseIndex, ServiceRegistration registration, - int reverseIndex, ServiceResolutionContext context) { - var cacheKey = new CallSiteCacheKey(reverseIndex, registration.ServiceType); - - if (context.CallSiteCache.TryGetValue(cacheKey, out ServiceCallSite callSite)) - { - return callSite; - } - - if (registration.InstanceMember is { } instanceMember) - { - callSite = CreateMemberCallSite( - registration, - instanceMember, - registration.MemberLocation, - reverseIndex); - } - else if (registration.FactoryMember is { } factoryMember) - { - callSite = CreateFactoryCallSite( - registration.ServiceType, - registration.ImplementationType, - registration.Lifetime, - registration.Location, - registration.MemberLocation, - factoryMember, - reverseIndex, - context); - } - else + if ( + registration.Name == name && + SymbolEqualityComparer.Default.Equals(registration.ServiceType, serviceType)) { - var implementationType = registration.ImplementationType ?? - registration.ServiceType; + var identity = new ServiceIdentity(registration.ServiceType, registration.Name, reverseIndex); + if (context.CallSiteCache.TryGet(identity, out ServiceCallSite callSite)) + { + return callSite; + } - callSite = CreateConstructorCallSite(registration, registration.ServiceType, implementationType, - reverseIndex, context); - } + if (registration.InstanceMember is { } instanceMember) + { + callSite = new MemberCallSite(identity, + instanceMember, + memberLocation: registration.MemberLocation, + registration.Lifetime, + false); + } + else if (registration.FactoryMember is { } factoryMember) + { + callSite = CreateFactoryCallSite( + identity, + registration.ImplementationType, + registration.Lifetime, + registration.Location, + registration.MemberLocation, + factoryMember, + context); + } + else + { + var implementationType = registration.ImplementationType ?? + registration.ServiceType; - context.CallSiteCache[cacheKey] = callSite; + callSite = CreateConstructorCallSite(identity, registration, implementationType, context); + } - return callSite; - } + context.CallSiteCache.Add(callSite); + return callSite; + } - private ServiceCallSite CreateMemberCallSite( - ServiceRegistration registration, - ISymbol instanceMember, - MemberLocation memberLocation, - int reverseIndex) - { - return new MemberCallSite(registration.ServiceType, - instanceMember, - memberLocation: memberLocation, - registration.Lifetime, - reverseIndex, - false); + return null; } - private ServiceCallSite CreateFactoryCallSite(INamedTypeSymbol serviceType, - INamedTypeSymbol? implementationType, + private ServiceCallSite CreateFactoryCallSite( + ServiceIdentity identity, + ITypeSymbol? implementationType, ServiceLifetime lifetime, Location? registrationLocation, MemberLocation memberLocation, ISymbol factoryMember, - int reverseIndex, ServiceResolutionContext context) { ImmutableArray GetDelegateParameters(ITypeSymbol type) @@ -567,14 +528,12 @@ ImmutableArray GetDelegateParameters(ITypeSymbol type) throw new InvalidOperationException($"Unable to determine parameters for {type.ToDisplayString()}"); } - var cacheKey = new CallSiteCacheKey(reverseIndex, serviceType); - - if (context.CallSiteCache.TryGetValue(cacheKey, out ServiceCallSite callSite)) + if (context.CallSiteCache.TryGet(identity, out ServiceCallSite callSite)) { return callSite; } - implementationType ??= serviceType; + implementationType ??= identity.Type; ImmutableArray factoryParameters; switch (factoryMember) @@ -593,9 +552,9 @@ ImmutableArray GetDelegateParameters(ITypeSymbol type) DiagnosticDescriptors.FactoryMemberMustBeAMethodOrHaveDelegateType, ExtractMemberTypeLocation(factoryMember), factoryMember.Name, - serviceType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)); + identity.Type.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)); _context.ReportDiagnostic(diagnostic); - return new ErrorCallSite(serviceType, diagnostic); + return new ErrorCallSite(identity, diagnostic); } var (parameters, namedParameters, diagnostics) = @@ -603,33 +562,26 @@ ImmutableArray GetDelegateParameters(ITypeSymbol type) if (diagnostics.Count > 0) { - return new ErrorCallSite(serviceType, diagnostics.ToArray()); + return new ErrorCallSite(identity, diagnostics.ToArray()); } - return new FactoryCallSite(serviceType, + var factoryCallSite = new FactoryCallSite(identity, factoryMember, memberLocation: memberLocation, parameters.ToArray(), namedParameters.ToArray(), lifetime, - reverseIndex, false); + + return factoryCallSite; } private ServiceCallSite CreateConstructorCallSite( + ServiceIdentity identity, ServiceRegistration registration, - INamedTypeSymbol serviceType, INamedTypeSymbol implementationType, - int reverseIndex, ServiceResolutionContext context) { - var cacheKey = new CallSiteCacheKey(reverseIndex, serviceType); - - if (context.CallSiteCache.TryGetValue(cacheKey, out ServiceCallSite callSite)) - { - return callSite; - } - context.TryAdd(implementationType); try { @@ -643,7 +595,7 @@ private ServiceCallSite CreateConstructorCallSite( _context.ReportDiagnostic( diagnostic); - return new ErrorCallSite(serviceType, diagnostic); + return new ErrorCallSite(identity, diagnostic); } var (parameters, namedParameters, diagnostics) = @@ -651,28 +603,23 @@ private ServiceCallSite CreateConstructorCallSite( if (diagnostics.Count > 0) { - return new ErrorCallSite(serviceType, diagnostics.ToArray()); + return new ErrorCallSite(identity, diagnostics.ToArray()); } - callSite = new ConstructorCallSite( - serviceType, + return new ConstructorCallSite( + identity, implementationType, parameters.ToArray(), namedParameters.ToArray(), registration.Lifetime, - reverseIndex, + identity.ReverseIndex, // TODO: this can be optimized to avoid check for all the types isDisposable: null ); - - context.CallSiteCache[cacheKey] = callSite; - - return callSite; } - catch + finally { context.Remove(implementationType); - throw; } } @@ -682,7 +629,7 @@ private ServiceCallSite CreateConstructorCallSite( List Diagnostics) GetParameters( ImmutableArray parameters, Location? registrationLocation, - INamedTypeSymbol implementationType, + ITypeSymbol implementationType, ServiceResolutionContext context) { var callSites = new List(); @@ -690,7 +637,37 @@ private ServiceCallSite CreateConstructorCallSite( var diagnostics = new List(); foreach (var parameterSymbol in parameters) { - var parameterCallSite = GetCallSite(parameterSymbol.Type, context); + string? registrationName = null; + foreach (var attributeData in parameterSymbol.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, + _knownTypes.FromNamedServicesAttribute)) + { + registrationName = (string?)attributeData.ConstructorArguments[0].Value; + ValidateServiceName(registrationName, attributeData); + } + + if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, + _knownTypes.FromKeyedServicesAttribute)) + { + var key = attributeData.ConstructorArguments[0].Value; + if (key is not string) + { + var diagnostic = Diagnostic.Create( + DiagnosticDescriptors.OnlyStringKeysAreSupported, + attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation(), + key); + _context.ReportDiagnostic(diagnostic); + } + registrationName = Convert.ToString(key); + ValidateServiceName(registrationName, attributeData); + } + } + + var parameterCallSite = GetCallSite( + parameterSymbol.Type, + registrationName, + context.WithRequestLocation(ExtractMemberTypeLocation(parameterSymbol))); if (parameterSymbol.IsOptional) { if (parameterCallSite != null) @@ -704,11 +681,23 @@ private ServiceCallSite CreateConstructorCallSite( bool isNullable = parameterSymbol.Type.NullableAnnotation == NullableAnnotation.Annotated; if (parameterCallSite == null) { - var diagnostic = Diagnostic.Create( - isNullable ? DiagnosticDescriptors.NullableServiceNotRegistered : DiagnosticDescriptors.ServiceRequiredToConstructNotRegistered, - registrationLocation, - parameterSymbol.Type.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat), - implementationType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)); + Diagnostic diagnostic; + if (registrationName == null) + { + diagnostic = Diagnostic.Create( + isNullable ? DiagnosticDescriptors.NullableServiceNotRegistered : DiagnosticDescriptors.ServiceRequiredToConstructNotRegistered, + registrationLocation, + parameterSymbol.Type.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat), + implementationType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)); + } + else + { + diagnostic = Diagnostic.Create(DiagnosticDescriptors.NamedServiceRequiredToConstructNotRegistered, + registrationLocation, + parameterSymbol.Type.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat), + registrationName, + implementationType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)); + } diagnostics.Add(diagnostic); _context.ReportDiagnostic(diagnostic); @@ -873,18 +862,10 @@ private void ProcessModule(ITypeSymbol serviceProviderType, List s.Name == KnownTypes.JabAttributesAssemblyName) - ?? moduleType.ContainingAssembly; - knownTypes = new KnownTypes(_context.Compilation, assemblySymbol); - } + var knownTypes = + SymbolEqualityComparer.Default.Equals(moduleType.ContainingAssembly, _context.Compilation.Assembly) + ? _knownTypes + : new KnownTypes(_context.Compilation, moduleType.ContainingModule, moduleType.ContainingAssembly); foreach (var attributeData in moduleType.GetAttributes()) { @@ -990,17 +971,23 @@ private bool TryCreateRegistration( { registration = null; + string? registrationName = null; string? instanceMemberName = null; string? factoryMemberName = null; foreach (var namedArgument in attributeData.NamedArguments) { - if (namedArgument.Key == KnownTypes.InstanceAttributePropertyName) - { - instanceMemberName = (string?)namedArgument.Value.Value; - } - else if (namedArgument.Key == KnownTypes.FactoryAttributePropertyName) + switch (namedArgument.Key) { - factoryMemberName = (string?)namedArgument.Value.Value; + case KnownTypes.NameAttributePropertyName: + registrationName = (string?)namedArgument.Value.Value; + ValidateServiceName(registrationName, attributeData); + break; + case KnownTypes.InstanceAttributePropertyName: + instanceMemberName = (string?)namedArgument.Value.Value; + break; + case KnownTypes.FactoryAttributePropertyName: + factoryMemberName = (string?)namedArgument.Value.Value; + break; } } @@ -1063,6 +1050,7 @@ private bool TryCreateRegistration( registration = new ServiceRegistration( serviceLifetime, serviceType, + registrationName, implementationType, instanceMember, factoryMember, @@ -1091,7 +1079,7 @@ private bool TryFindMember(ITypeSymbol typeSymbol, { members.AddRange(moduleType.GetMembers(memberName)); } - + if (members.Count == 0) { members.AddRange(typeSymbol.GetMembers(memberName)); @@ -1138,6 +1126,32 @@ private bool TryFindMember(ITypeSymbol typeSymbol, return true; } + private void ValidateServiceName(string? name, AttributeData attributeData) + { + if (name == null) + { + return; + } + + bool badName = name == "" || !char.IsLetter(name[0]); + + foreach (var c in name) + { + if (!char.IsLetterOrDigit(c)) + { + badName = true; + } + } + + if (badName) + { + _context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.ServiceNameMustBeAlphanumeric, + attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation(), + name)); + } + } + private INamedTypeSymbol ExtractType(TypedConstant typedConstant) { if (typedConstant.Kind != TypedConstantKind.Type) @@ -1160,6 +1174,7 @@ private INamedTypeSymbol ExtractType(TypedConstant typedConstant) var syntax = declaringSyntaxReference.GetSyntax(); return syntax switch { + ParameterSyntax { Type: {} type } => type.GetLocation(), PropertyDeclarationSyntax declarationSyntax => declarationSyntax.Type.GetLocation(), FieldDeclarationSyntax fieldDeclarationSyntax => fieldDeclarationSyntax.Declaration.Type.GetLocation(), VariableDeclaratorSyntax { Parent: VariableDeclarationSyntax { Type: {} type }} => type.GetLocation(), @@ -1186,71 +1201,58 @@ private INamedTypeSymbol ExtractType(TypedConstant typedConstant) return null; } - private readonly struct CallSiteCacheKey : IEquatable + private class CallSiteCache { - public CallSiteCacheKey(ITypeSymbol type) : this(0, type) - { - } - - public CallSiteCacheKey(int reverseIndex, ITypeSymbol type) + private Dictionary _callSites = new(); + public bool TryGet(ServiceIdentity identity, out ServiceCallSite callSite) { - ReverseIndex = reverseIndex; - Type = type; + return _callSites.TryGetValue(identity, out callSite); } - public int ReverseIndex { get; } - public ITypeSymbol Type { get; } - - public bool Equals(CallSiteCacheKey other) + public void Add(ServiceCallSite callSite) { - return ReverseIndex == other.ReverseIndex && SymbolEqualityComparer.Default.Equals(Type, other.Type); + _callSites.Add(callSite.Identity, callSite); } - public override bool Equals(object? obj) + public ServiceCallSite[] GetRootCallSites() { - return obj is CallSiteCacheKey other && Equals(other); - } - - public override int GetHashCode() - { - unchecked - { - return (ReverseIndex * 397) ^ SymbolEqualityComparer.Default.GetHashCode(Type); - } + return _callSites.Values.ToArray(); } } - private class ServiceResolutionContext { - private readonly HashSet _chain = new(); - private int _index; + private HashSet _chain = new(); - public Dictionary CallSiteCache { get; } - public ITypeSymbol RequestService { get; } + public CallSiteCache CallSiteCache { get; } public ServiceProviderDescription ProviderDescription { get; } public Location? RequestLocation { get; } public ServiceResolutionContext( ServiceProviderDescription providerDescription, - Dictionary serviceCallSites, - ITypeSymbol requestService, + CallSiteCache serviceCallSites, Location? requestLocation) { CallSiteCache = serviceCallSites; - RequestService = requestService; ProviderDescription = providerDescription; RequestLocation = requestLocation; } + public ServiceResolutionContext WithRequestLocation(Location? requestLocation) + { + return new ServiceResolutionContext(ProviderDescription, CallSiteCache, requestLocation) + { + _chain = this._chain + }; + } + public bool TryAdd(ITypeSymbol typeSymbol) { - var serviceChainItem = new ServiceChainItem(typeSymbol, _index); + var serviceChainItem = new ServiceChainItem(typeSymbol, _chain.Count); if (_chain.Contains(serviceChainItem)) { return false; } - _index++; _chain.Add(serviceChainItem); return true; } @@ -1258,7 +1260,6 @@ public bool TryAdd(ITypeSymbol typeSymbol) public void Remove(ITypeSymbol typeSymbol) { _chain.Remove(new ServiceChainItem(typeSymbol, 0)); - _index--; } public string ToString(ITypeSymbol typeSymbol) diff --git a/src/Jab/ServiceProviderCallSite.cs b/src/Jab/ServiceProviderCallSite.cs index dd2050f..2bd4a65 100644 --- a/src/Jab/ServiceProviderCallSite.cs +++ b/src/Jab/ServiceProviderCallSite.cs @@ -2,7 +2,7 @@ namespace Jab; internal record ServiceProviderCallSite: ServiceCallSite { - public ServiceProviderCallSite(ITypeSymbol serviceType) : base(serviceType, serviceType, ServiceLifetime.Transient, 0, false) + public ServiceProviderCallSite(ITypeSymbol serviceType) : base(new ServiceIdentity(serviceType, null, null), serviceType, ServiceLifetime.Transient, false) { } } \ No newline at end of file diff --git a/src/Jab/ServiceProviderDescription.cs b/src/Jab/ServiceProviderDescription.cs index 0e65c50..d887e58 100644 --- a/src/Jab/ServiceProviderDescription.cs +++ b/src/Jab/ServiceProviderDescription.cs @@ -7,15 +7,21 @@ public ServiceProviderDescription(IReadOnlyList serviceRegi Location = location; RootServices = rootServices; ServiceRegistrations = serviceRegistrations; - ServiceRegistrationsLookup = new Dictionary(SymbolEqualityComparer.Default); + ServiceRegistrationsLookup = new(SymbolEqualityComparer.Default); foreach (var registration in serviceRegistrations) { - ServiceRegistrationsLookup[registration.ServiceType] = registration; + ServiceRegistrationsLookup.TryGetValue(registration.ServiceType, out var registrations); + + if (registrations == null) + { + registrations = ServiceRegistrationsLookup[registration.ServiceType] = new(); + } + registrations.Add(registration); } } - public Dictionary ServiceRegistrationsLookup { get; } + public Dictionary> ServiceRegistrationsLookup { get; } public Location? Location { get; } public RootService[] RootServices { get; } diff --git a/src/Jab/ServiceProviderIsServiceCallSite.cs b/src/Jab/ServiceProviderIsServiceCallSite.cs index b0ded8d..cffdc27 100644 --- a/src/Jab/ServiceProviderIsServiceCallSite.cs +++ b/src/Jab/ServiceProviderIsServiceCallSite.cs @@ -2,7 +2,7 @@ namespace Jab; internal record ServiceProviderIsServiceCallSite: ServiceCallSite { - public ServiceProviderIsServiceCallSite(ITypeSymbol serviceType) : base(serviceType, serviceType, ServiceLifetime.Transient, 0, false) + public ServiceProviderIsServiceCallSite(ITypeSymbol serviceType) : base(new ServiceIdentity(serviceType, null, null), serviceType, ServiceLifetime.Transient, false) { } } \ No newline at end of file diff --git a/src/Jab/ServiceRegistration.cs b/src/Jab/ServiceRegistration.cs index dd10cf1..81b02b5 100644 --- a/src/Jab/ServiceRegistration.cs +++ b/src/Jab/ServiceRegistration.cs @@ -1,8 +1,17 @@ namespace Jab; -internal record ServiceRegistration(ServiceLifetime Lifetime, INamedTypeSymbol ServiceType, INamedTypeSymbol? ImplementationType, ISymbol? InstanceMember, ISymbol? FactoryMember, Location? Location, MemberLocation MemberLocation) +internal record ServiceRegistration( + ServiceLifetime Lifetime, + INamedTypeSymbol ServiceType, + string? Name, + INamedTypeSymbol? ImplementationType, + ISymbol? InstanceMember, + ISymbol? FactoryMember, + Location? Location, + MemberLocation MemberLocation) { public INamedTypeSymbol ServiceType { get; } = ServiceType; + public string? Name { get; } = Name; public INamedTypeSymbol? ImplementationType { get; } = ImplementationType; public ISymbol? InstanceMember { get; } = InstanceMember; public ISymbol? FactoryMember { get; } = FactoryMember;