Skip to content

Commit

Permalink
MultiTenant entities - Get tenant from context only if there is any m…
Browse files Browse the repository at this point in the history
…odified multitenant entity (#280)

* MultiTenant entities - Get tenant from context only if there is any modified multitenant entity

* MultiTenant entities - Get tenant from context only if there is any modified multitenant entity

* MultiTenant entities - Get tenant from context only if there is any modified multitenant entity

* MultiTenant entities - Get tenant from context only if there is any modified multitenant entity

* MultiTenant entities - Get tenant from context only if there is any modified multitenant entity

---------

Co-authored-by: Tudor Stanciu <tstanciu@totalsoft.ro>
  • Loading branch information
tstanciu and Tudor Stanciu authored Feb 6, 2025
1 parent 77d4c9b commit 9735286
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@ public static class DbContextExtensions
{
public static void SetTenantIdFromContext(this DbContext context)
{
var tenantId = context.GetTenantIdFromContext();

var multiTenantEntities =
context.ChangeTracker.Entries()
.Where(e => e.IsMultiTenant() && e.State != EntityState.Unchanged);

if (!multiTenantEntities.Any())
{
return;
}

var tenantId = context.GetTenantIdFromContext();
foreach (var e in multiTenantEntities)
{
var attemptedTenantId = e.GetTenantId();
Expand All @@ -41,4 +45,4 @@ public static void UseMultitenancy(this DbContextOptionsBuilder options, IServic
((IDbContextOptionsBuilderInfrastructure)options).AddOrUpdateExtension(extension);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ namespace NBB.MultiTenancy.Abstractions.Context
{
public static class TenantContextExtensions
{
public static Guid GetTenantId(this TenantContext tenantContext) => tenantContext.Tenant?.TenantId ?? throw new TenantNotFoundException();
public static Guid? TryGetTenantId(this TenantContext tenantContext) => tenantContext.Tenant?.TenantId;
public static string GetTenantCode(this TenantContext tenantContext) => tenantContext.Tenant?.Code ?? throw new TenantNotFoundException();
public static Guid GetTenantId(this TenantContext tenantContext) => tenantContext?.Tenant?.TenantId ?? throw new TenantNotFoundException();
public static Guid? TryGetTenantId(this TenantContext tenantContext) => tenantContext?.Tenant?.TenantId;
public static string GetTenantCode(this TenantContext tenantContext) => tenantContext?.Tenant?.Code ?? throw new TenantNotFoundException();

public static TenantContextFlow ChangeTenantContext(this ITenantContextAccessor tenantContextAccessor, Tenant tenant)
=> tenantContextAccessor.ChangeTenantContext(new TenantContext(tenant));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
// Copyright (c) TotalSoft.
// This source code is licensed under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Configuration;
Expand All @@ -15,6 +11,10 @@
using NBB.MultiTenancy.Abstractions.Configuration;
using NBB.MultiTenancy.Abstractions.Context;
using NBB.MultiTenancy.Abstractions.Options;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Xunit;

namespace NBB.Data.EntityFramework.MultiTenancy.Tests
Expand All @@ -26,7 +26,7 @@ public async Task Should_add_tenantId()
{
// arrange
var testTenantId = Guid.NewGuid();
var sp = GetServiceProvider<TestDbContext>(true);
var sp = GetServiceProvider<TestDbContext>();
var testEntity = new TestEntity { Id = 1 };

await WithTenantScope(sp, testTenantId, async sp =>
Expand All @@ -48,7 +48,7 @@ public async Task Should_Exception_Be_Thrown_If_Different_TenantIds()
{
// arrange
var testTenantId = Guid.NewGuid();
var sp = GetServiceProvider<TestDbContext>(true);
var sp = GetServiceProvider<TestDbContext>();
var testEntity = new TestEntity { Id = 1 };
var testEntityOtherId = new TestEntity { Id = 2 };
await WithTenantScope(sp, testTenantId, async sp =>
Expand All @@ -75,7 +75,7 @@ public async Task Shoud_Apply_Filter()
var testTenantId2 = Guid.NewGuid();
var testEntity = new TestEntity { Id = 1 };
var testEntityOtherId = new TestEntity { Id = 2 };
var sp = GetServiceProvider<TestDbContext>(true);
var sp = GetServiceProvider<TestDbContext>();

await WithTenantScope(sp, testTenantId1, async sp =>
{
Expand Down Expand Up @@ -109,7 +109,7 @@ public async Task Should_add_TenantId_and_filter_for_MultiTenantContext()
{
// arrange
var testTenantId = Guid.NewGuid();
var sp = GetServiceProvider<TestDbContext>(true);
var sp = GetServiceProvider<TestDbContext>();
var testEntity = new TestEntity { Id = 1 };
var testEntity1 = new TestEntity { Id = 2 };

Expand All @@ -133,13 +133,44 @@ await WithTenantScope(sp, testTenantId, async sp =>
});
}

private IServiceProvider GetServiceProvider<TDBContext>(bool isSharedDB) where TDBContext : DbContext
[Fact]
public async Task Can_Save_MultiTenantDbContext_WO_TennatContext_When_Only_NonMultiTenant_Entities_Changed()
{
// arrange
var sp = GetServiceProvider<TestDbContext>(DbStrategy.Shared);
var testEntity = new SimpleEntity { Id = 1 };
var testEntityOtherId = new SimpleEntity { Id = 2 };

var dbContext = sp.GetRequiredService<TestDbContext>();

dbContext.SimpleEntities.Add(testEntity);
dbContext.SimpleEntities.Add(testEntityOtherId);

// act
var count = await dbContext.SaveChangesAsync();

// assert
count.Should().Be(2);
}

enum DbStrategy
{
DatabasePerTenant,
Shared,
Hybrid
}

private IServiceProvider GetServiceProvider<TDBContext>(DbStrategy dbStrategy = DbStrategy.Hybrid) where TDBContext : DbContext
{
var tenantService = Mock.Of<ITenantContextAccessor>(x => x.TenantContext == null);
var isSharedDB = dbStrategy == DbStrategy.Shared;
var isHybridDB = dbStrategy == DbStrategy.Hybrid;
var connectionStringKey = isSharedDB ? "ConnectionStrings:myDb" : "MultiTenancy:Defaults:ConnectionStrings:myDb";
var connectionStringValue = isSharedDB || isHybridDB ? "Test" : Guid.NewGuid().ToString();
IConfiguration configuration = new ConfigurationBuilder()
.AddInMemoryCollection(new Dictionary<string, string>
{
{ "MultiTenancy:Defaults:ConnectionStrings:myDb", isSharedDB ? "Test" : Guid.NewGuid().ToString()}
{ connectionStringKey, connectionStringValue }
})
.Build();

Expand All @@ -156,7 +187,9 @@ private IServiceProvider GetServiceProvider<TDBContext>(bool isSharedDB) where T
services.AddEntityFrameworkInMemoryDatabase()
.AddDbContext<TDBContext>((sp, options) =>
{
var conn = sp.GetRequiredService<ITenantConfiguration>().GetConnectionString("myDb");
var conn = isSharedDB ?
configuration.GetConnectionString("myDb") :
sp.GetRequiredService<ITenantConfiguration>().GetConnectionString("myDb");
options.UseInMemoryDatabase(conn).UseInternalServiceProvider(sp);
});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) TotalSoft.
// This source code is licensed under the MIT license.

namespace NBB.Data.EntityFramework.MultiTenancy.Tests
{
public class SimpleEntity
{
public int Id { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) TotalSoft.
// This source code is licensed under the MIT license.

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Builders;

namespace NBB.Data.EntityFramework.MultiTenancy.Tests
{
public class SimpleEntityConfiguration : IEntityTypeConfiguration<SimpleEntity>
{
public void Configure(EntityTypeBuilder<SimpleEntity> builder)
{
builder.HasKey(x => x.Id);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace NBB.Data.EntityFramework.MultiTenancy.Tests
public class TestDbContext : MultiTenantDbContext
{
public DbSet<TestEntity> TestEntities { get; set; }
public DbSet<SimpleEntity> SimpleEntities { get; set; }

public TestDbContext(DbContextOptions<TestDbContext> options) : base(options)
{
Expand All @@ -16,6 +17,7 @@ public TestDbContext(DbContextOptions<TestDbContext> options) : base(options)
protected override void OnModelCreating(ModelBuilder modelBuilder)
{
modelBuilder.ApplyConfiguration(new TestEntityConfiguration());
modelBuilder.ApplyConfiguration(new SimpleEntityConfiguration());

base.OnModelCreating(modelBuilder);
}
Expand Down

0 comments on commit 9735286

Please sign in to comment.