From 98171c48ce00970463e5533a8061344efd206c1d Mon Sep 17 00:00:00 2001 From: Steven Tsang <3403544+tsangste@users.noreply.github.com> Date: Thu, 23 Jan 2025 21:26:19 +0000 Subject: [PATCH] fix: resolve issue with ordering of middleware being applied (#189) When setting up the nestjs modules with MikroORM there is a chance it can throw Using global EntityManager instance methods for context specific actions is disallowed. when interacting with the EM within another middleware. This fix applies to both single and multi database set-ups. --- README.md | 2 +- package.json | 12 +-- src/mikro-orm-core.module.ts | 14 +++- src/mikro-orm-middleware.module.ts | 9 +-- src/mikro-orm.module.ts | 23 ++---- tests/entities/foo.entity.ts | 3 +- tests/mikro-orm.middleware.test.ts | 84 ++++++++++++++------ tests/mikro-orm.module-middleware.test.ts | 94 +++++++++++++++++++++++ yarn.lock | 44 +++++------ 9 files changed, 208 insertions(+), 77 deletions(-) create mode 100644 tests/mikro-orm.module-middleware.test.ts diff --git a/README.md b/README.md index c46bf8f..a0d3d63 100644 --- a/README.md +++ b/README.md @@ -342,7 +342,7 @@ More information about [enableShutdownHooks](https://docs.nestjs.com/fundamental ## Multiple Database Connections -You can define multiple database connections by registering multiple `MikroOrmModule` and setting their `contextName`. If you want to use middleware request context you must disable automatic middleware and register `MikroOrmModule` with `forMiddleware()` or use NestJS `Injection Scope` +You can define multiple database connections by registering multiple `MikroOrmModule`'s, each with a unique `contextName`. You will need to disable the automatic request context middleware by setting `registerRequestContext` to `false`, as it wouldn't work with this approach - note that this needs to be part of all your `MikroOrmModule`s with non-default `contextName`. To have the same automatic request context behaviour, you must register `MikroOrmModule` with `forMiddleware()` instead: ```typescript @Module({ diff --git a/package.json b/package.json index 271d0d6..9b42011 100644 --- a/package.json +++ b/package.json @@ -41,16 +41,16 @@ }, "peerDependencies": { "@mikro-orm/core": "^6.0.0 || ^6.0.0-dev.0", - "@nestjs/common": "^10.0.0 || ^11.0.0", - "@nestjs/core": "^10.0.0 || ^11.0.0" + "@nestjs/common": "^10.0.0 || ^11.0.5", + "@nestjs/core": "^10.0.0 || ^11.0.5" }, "devDependencies": { "@mikro-orm/core": "^6.2.7", "@mikro-orm/sqlite": "^6.2.7", - "@nestjs/common": "^11.0.0", - "@nestjs/core": "^11.0.0", - "@nestjs/platform-express": "^11.0.0", - "@nestjs/testing": "^11.0.0", + "@nestjs/common": "^11.0.5", + "@nestjs/core": "^11.0.5", + "@nestjs/platform-express": "^11.0.5", + "@nestjs/testing": "^11.0.5", "@types/jest": "^29.5.12", "@types/node": "^22.0.0", "@types/supertest": "^6.0.2", diff --git a/src/mikro-orm-core.module.ts b/src/mikro-orm-core.module.ts index a59568e..71e7f2a 100644 --- a/src/mikro-orm-core.module.ts +++ b/src/mikro-orm-core.module.ts @@ -1,5 +1,15 @@ import { Configuration, ConfigurationLoader, EntityManager, MikroORM, type Dictionary } from '@mikro-orm/core'; -import { Global, Inject, Module, RequestMethod, type DynamicModule, type MiddlewareConsumer, type OnApplicationShutdown, type Type } from '@nestjs/common'; +import { + Global, + Inject, + Module, + RequestMethod, + type DynamicModule, + type MiddlewareConsumer, + type NestModule, + type OnApplicationShutdown, + type Type, +} from '@nestjs/common'; import { ModuleRef } from '@nestjs/core'; import { forRoutesPath } from './middleware.helper'; @@ -31,7 +41,7 @@ const PACKAGES = { @Global() @Module({}) -export class MikroOrmCoreModule implements OnApplicationShutdown { +export class MikroOrmCoreModule implements NestModule, OnApplicationShutdown { constructor(@Inject(MIKRO_ORM_MODULE_OPTIONS) private readonly options: MikroOrmModuleOptions, diff --git a/src/mikro-orm-middleware.module.ts b/src/mikro-orm-middleware.module.ts index 52eac73..5bedd5f 100644 --- a/src/mikro-orm-middleware.module.ts +++ b/src/mikro-orm-middleware.module.ts @@ -1,4 +1,4 @@ -import { Global, Inject, Module, RequestMethod, type MiddlewareConsumer } from '@nestjs/common'; +import { Global, Inject, Module, RequestMethod, type MiddlewareConsumer, type NestModule } from '@nestjs/common'; import type { MikroORM } from '@mikro-orm/core'; import { forRoutesPath } from './middleware.helper'; @@ -8,15 +8,12 @@ import { MikroOrmMiddlewareModuleOptions } from './typings'; @Global() @Module({}) -export class MikroOrmMiddlewareModule { +export class MikroOrmMiddlewareModule implements NestModule { constructor(@Inject(MIKRO_ORM_MODULE_OPTIONS) private readonly options: MikroOrmMiddlewareModuleOptions) { } - static forMiddleware(options?: MikroOrmMiddlewareModuleOptions) { - // Work around due to nestjs not supporting the ability to register multiple types - // https://github.com/nestjs/nest/issues/770 - // https://github.com/nestjs/nest/issues/4786#issuecomment-755032258 - workaround suggestion + static forRoot(options?: MikroOrmMiddlewareModuleOptions) { const inject = CONTEXT_NAMES.map(name => getMikroORMToken(name)); return { module: MikroOrmMiddlewareModule, diff --git a/src/mikro-orm.module.ts b/src/mikro-orm.module.ts index b5949b3..ef8607c 100644 --- a/src/mikro-orm.module.ts +++ b/src/mikro-orm.module.ts @@ -4,12 +4,12 @@ import { MikroOrmCoreModule } from './mikro-orm-core.module'; import { MikroOrmMiddlewareModule } from './mikro-orm-middleware.module'; import { MikroOrmEntitiesStorage } from './mikro-orm.entities.storage'; import { createMikroOrmRepositoryProviders } from './mikro-orm.providers'; -import type { +import { EntityName, - MikroOrmMiddlewareModuleOptions, MikroOrmModuleAsyncOptions, MikroOrmModuleFeatureOptions, MikroOrmModuleSyncOptions, + MikroOrmMiddlewareModuleOptions, } from './typings'; @Module({}) @@ -23,18 +23,12 @@ export class MikroOrmModule { MikroOrmEntitiesStorage.clear(contextName); } - static forRoot(options?: MikroOrmModuleSyncOptions): DynamicModule { - return { - module: MikroOrmModule, - imports: [MikroOrmCoreModule.forRoot(options)], - }; + static forRoot(options?: MikroOrmModuleSyncOptions): DynamicModule | Promise { + return MikroOrmCoreModule.forRoot(options); } - static forRootAsync(options: MikroOrmModuleAsyncOptions): DynamicModule { - return { - module: MikroOrmModule, - imports: [MikroOrmCoreModule.forRootAsync(options)], - }; + static forRootAsync(options: MikroOrmModuleAsyncOptions): DynamicModule | Promise { + return MikroOrmCoreModule.forRootAsync(options); } static forFeature(options: EntityName[] | MikroOrmModuleFeatureOptions, contextName?: string): DynamicModule { @@ -56,10 +50,7 @@ export class MikroOrmModule { } static forMiddleware(options?: MikroOrmMiddlewareModuleOptions): DynamicModule { - return { - module: MikroOrmModule, - imports: [MikroOrmMiddlewareModule.forMiddleware(options)], - }; + return MikroOrmMiddlewareModule.forRoot(options); } } diff --git a/tests/entities/foo.entity.ts b/tests/entities/foo.entity.ts index b5aa3da..0e7b8fa 100644 --- a/tests/entities/foo.entity.ts +++ b/tests/entities/foo.entity.ts @@ -1,6 +1,7 @@ -import { PrimaryKey, Entity } from '@mikro-orm/core'; +import { PrimaryKey, Entity, Filter } from '@mikro-orm/core'; @Entity() +@Filter({ name: 'id', cond: args => ({ id: args.id }) }) export class Foo { @PrimaryKey() diff --git a/tests/mikro-orm.middleware.test.ts b/tests/mikro-orm.middleware.test.ts index 62264b2..07cb746 100644 --- a/tests/mikro-orm.middleware.test.ts +++ b/tests/mikro-orm.middleware.test.ts @@ -1,16 +1,18 @@ -import type { Options } from '@mikro-orm/core'; -import { MikroORM } from '@mikro-orm/core'; +import { EntityManager, MikroORM, type Options } from '@mikro-orm/core'; import { SqliteDriver } from '@mikro-orm/sqlite'; -import type { INestApplication } from '@nestjs/common'; import { Controller, Get, Module, + type INestApplication, + Injectable, + type MiddlewareConsumer, + type NestMiddleware, + type NestModule, } from '@nestjs/common'; -import type { TestingModule } from '@nestjs/testing'; -import { Test } from '@nestjs/testing'; +import { Test, type TestingModule } from '@nestjs/testing'; import request from 'supertest'; -import { InjectMikroORM, MikroOrmModule } from '../src'; +import { InjectEntityManager, InjectMikroORM, MikroOrmModule } from '../src'; import { Bar } from './entities/bar.entity'; import { Foo } from './entities/foo.entity'; @@ -21,54 +23,90 @@ const testOptions: Options = { entities: ['entities'], }; -@Controller() -class TestController { +@Controller('/foo') +class FooController { - constructor( - @InjectMikroORM('database1') private database1: MikroORM, - @InjectMikroORM('database2') private database2: MikroORM, - ) {} + constructor(@InjectMikroORM('database-multi-foo') private database1: MikroORM) {} - @Get('foo') + @Get() foo() { return this.database1.em !== this.database1.em.getContext(); } - @Get('bar') +} + +@Controller('/bar') +class BarController { + + constructor(@InjectMikroORM('database-multi-bar') private database2: MikroORM) {} + + @Get() bar() { return this.database2.em !== this.database2.em.getContext(); } } +@Injectable() +export class TestMiddleware implements NestMiddleware { + + constructor(@InjectEntityManager('database-multi-foo') private readonly em: EntityManager) {} + + use(req: unknown, res: unknown, next: (...args: any[]) => void) { + this.em.setFilterParams('id', { id: '1' }); + + return next(); + } + +} + +@Module({ + imports: [MikroOrmModule.forFeature([Foo], 'database-multi-foo')], + controllers: [FooController], +}) +class FooModule implements NestModule { + + configure(consumer: MiddlewareConsumer): void { + consumer + .apply(TestMiddleware) + .forRoutes('/'); + } + +} + +@Module({ + imports: [MikroOrmModule.forFeature([Bar], 'database-multi-bar')], + controllers: [BarController], +}) +class BarModule {} + @Module({ imports: [ MikroOrmModule.forRootAsync({ - contextName: 'database1', + contextName: 'database-multi-foo', useFactory: () => ({ registerRequestContext: false, ...testOptions, }), }), MikroOrmModule.forRoot({ - contextName: 'database2', + contextName: 'database-multi-bar', registerRequestContext: false, ...testOptions, }), MikroOrmModule.forMiddleware(), - MikroOrmModule.forFeature([Foo], 'database1'), - MikroOrmModule.forFeature([Bar], 'database2'), + FooModule, + BarModule, ], - controllers: [TestController], }) -class TestModule {} +class TestMultiModule {} -describe('Middleware executes request context for all MikroORM registered', () => { +describe('Multiple Middleware executes request context for all MikroORM registered', () => { let app: INestApplication; beforeAll(async () => { const moduleFixture: TestingModule = await Test.createTestingModule({ - imports: [TestModule], + imports: [TestMultiModule], }).compile(); app = moduleFixture.createNestApplication(); @@ -81,7 +119,7 @@ describe('Middleware executes request context for all MikroORM registered', () = }); it(`forRoutes(/bar) should return 'true'`, () => { - return request(app.getHttpServer()).get('/foo').expect(200, 'true'); + return request(app.getHttpServer()).get('/bar').expect(200, 'true'); }); afterAll(async () => { diff --git a/tests/mikro-orm.module-middleware.test.ts b/tests/mikro-orm.module-middleware.test.ts new file mode 100644 index 0000000..a53b8b2 --- /dev/null +++ b/tests/mikro-orm.module-middleware.test.ts @@ -0,0 +1,94 @@ +import { EntityManager, MikroORM, type Options } from '@mikro-orm/core'; +import { SqliteDriver } from '@mikro-orm/sqlite'; +import { + Controller, + Get, + Module, + type INestApplication, + Injectable, + type MiddlewareConsumer, + type NestMiddleware, + type NestModule, +} from '@nestjs/common'; +import { Test, type TestingModule } from '@nestjs/testing'; +import request from 'supertest'; +import { MikroOrmModule } from '../src'; +import { Foo } from './entities/foo.entity'; + +const testOptions: Options = { + dbName: ':memory:', + driver: SqliteDriver, + baseDir: __dirname, + entities: ['entities'], +}; + +@Controller('/foo') +class FooController { + + constructor(private database1: MikroORM) {} + + @Get() + foo() { + return this.database1.em !== this.database1.em.getContext(); + } + +} + +@Injectable() +export class TestMiddleware implements NestMiddleware { + + constructor(private readonly em: EntityManager) {} + + use(req: unknown, res: unknown, next: (...args: any[]) => void) { + this.em.setFilterParams('id', { id: '1' }); + + return next(); + } + +} + +@Module({ + imports: [MikroOrmModule.forFeature([Foo])], + controllers: [FooController], +}) +class FooModule implements NestModule { + + configure(consumer: MiddlewareConsumer): void { + consumer + .apply(TestMiddleware) + .forRoutes('/'); + } + +} + +@Module({ + imports: [ + MikroOrmModule.forRootAsync({ + useFactory: () => testOptions, + }), + FooModule, + ], +}) +class TestModule {} + +describe('Middleware executes request context', () => { + let app: INestApplication; + + beforeAll(async () => { + const moduleFixture: TestingModule = await Test.createTestingModule({ + imports: [TestModule], + }).compile(); + + app = moduleFixture.createNestApplication(); + + await app.init(); + }); + + it(`forRoutes(/foo) should return 'true'`, () => { + return request(app.getHttpServer()).get('/foo').expect(200, 'true'); + }); + + afterAll(async () => { + await app.close(); + }); +}); diff --git a/yarn.lock b/yarn.lock index 6edd94a..ad8d474 100644 --- a/yarn.lock +++ b/yarn.lock @@ -989,10 +989,10 @@ __metadata: dependencies: "@mikro-orm/core": "npm:^6.2.7" "@mikro-orm/sqlite": "npm:^6.2.7" - "@nestjs/common": "npm:^11.0.0" - "@nestjs/core": "npm:^11.0.0" - "@nestjs/platform-express": "npm:^11.0.0" - "@nestjs/testing": "npm:^11.0.0" + "@nestjs/common": "npm:^11.0.5" + "@nestjs/core": "npm:^11.0.5" + "@nestjs/platform-express": "npm:^11.0.5" + "@nestjs/testing": "npm:^11.0.5" "@types/jest": "npm:^29.5.12" "@types/node": "npm:^22.0.0" "@types/supertest": "npm:^6.0.2" @@ -1010,8 +1010,8 @@ __metadata: typescript: "npm:5.7.3" peerDependencies: "@mikro-orm/core": ^6.0.0 || ^6.0.0-dev.0 - "@nestjs/common": ^10.0.0 || ^11.0.0 - "@nestjs/core": ^10.0.0 || ^11.0.0 + "@nestjs/common": ^10.0.0 || ^11.0.5 + "@nestjs/core": ^10.0.0 || ^11.0.5 languageName: unknown linkType: soft @@ -1029,9 +1029,9 @@ __metadata: languageName: node linkType: hard -"@nestjs/common@npm:^11.0.0": - version: 11.0.3 - resolution: "@nestjs/common@npm:11.0.3" +"@nestjs/common@npm:^11.0.5": + version: 11.0.5 + resolution: "@nestjs/common@npm:11.0.5" dependencies: iterare: "npm:1.2.1" tslib: "npm:2.8.1" @@ -1046,13 +1046,13 @@ __metadata: optional: true class-validator: optional: true - checksum: 10c0/3f4948e6ac8c5e324f6d840f27d4170541676b8a293b2f68dd51db42483e90f39cf654e50a97545cf3aa3b1e59b6b9396e92b9b3c03884dc4503dd65dc5d916e + checksum: 10c0/e65066d076e2d36dcd728ab09499e313f29c7867c975057568522538b8537f2ba689b9f488aafe7bfa6cfdd6e33bc36d4023c725bbac681bb98b537bde2a2422 languageName: node linkType: hard -"@nestjs/core@npm:^11.0.0": - version: 11.0.3 - resolution: "@nestjs/core@npm:11.0.3" +"@nestjs/core@npm:^11.0.5": + version: 11.0.5 + resolution: "@nestjs/core@npm:11.0.5" dependencies: "@nuxt/opencollective": "npm:0.4.1" fast-safe-stringify: "npm:2.1.1" @@ -1074,13 +1074,13 @@ __metadata: optional: true "@nestjs/websockets": optional: true - checksum: 10c0/5d05b4f1a85b9909d70e78ad48e1c755e65d1c7e11a6eb1cd5f9a8cada092412134a0de2c2c83da4e51e7581531b455f013b56b0f9c6c2dbcc747e33a4f76af4 + checksum: 10c0/8e94ee276dc57844747941dc102962d9844b333412bec90abddef957604b2a7e7f25c338fd58573711d67eee003d5110d3a85bf353a56bfe6fb3fec8263c5961 languageName: node linkType: hard -"@nestjs/platform-express@npm:^11.0.0": - version: 11.0.3 - resolution: "@nestjs/platform-express@npm:11.0.3" +"@nestjs/platform-express@npm:^11.0.5": + version: 11.0.5 + resolution: "@nestjs/platform-express@npm:11.0.5" dependencies: cors: "npm:2.8.5" express: "npm:5.0.1" @@ -1090,13 +1090,13 @@ __metadata: peerDependencies: "@nestjs/common": ^11.0.0 "@nestjs/core": ^11.0.0 - checksum: 10c0/76f40c89c94ea0b69d8ccf099a7291d947885cec9653203a849da38bc27d3a5c2b4441ed1c3c1112e014296bd3922b400d9d8911e796fb04bc9e118c9185e7ae + checksum: 10c0/d33a60a5f71e5cc53dfa97490d30ba6200f56f51e6898a5d651f60a666515b8eff322a88339d918e1e6c57de075867b2d71218841cb4b43dea9a6ea4e221ecc6 languageName: node linkType: hard -"@nestjs/testing@npm:^11.0.0": - version: 11.0.3 - resolution: "@nestjs/testing@npm:11.0.3" +"@nestjs/testing@npm:^11.0.5": + version: 11.0.5 + resolution: "@nestjs/testing@npm:11.0.5" dependencies: tslib: "npm:2.8.1" peerDependencies: @@ -1109,7 +1109,7 @@ __metadata: optional: true "@nestjs/platform-express": optional: true - checksum: 10c0/36e9767148594c2a0c14eacfdff263cac25280e51357d5ba0365f073cafe3f8581864bc4501dac465ff05cfc7cef46774268f216c73ef9bdf2d4628f496d90d8 + checksum: 10c0/658f4da89d39c2dd8d74edc6e82aaacd0b778e2d440161bb714d312abd250ebd07936d9aa354b6a333e1d165dad9ae327bfe8d7e811b7fed49e32a903192d754 languageName: node linkType: hard