From 4905640e6db8caa9c9b496da451108a727cc8ff7 Mon Sep 17 00:00:00 2001 From: Aegis Stack Date: Mon, 16 Mar 2026 22:48:10 -0400 Subject: [PATCH] RBAC - 4 --- aegis/cli/interactive.py | 4 + aegis/constants.py | 3 +- aegis/core/migration_generator.py | 61 ++++++++++++ aegis/core/services.py | 1 + aegis/core/template_generator.py | 9 +- .../app/models/org.py.jinja | 73 ++++++++++++++ .../services/auth/membership_service.py.jinja | 96 +++++++++++++++++++ .../app/services/auth/org_service.py.jinja | 66 +++++++++++++ tests/core/test_auth_service_parser.py | 10 ++ tests/core/test_migration_generator.py | 93 ++++++++++++++++++ tests/core/test_template_generator.py | 21 ++++ 11 files changed, 432 insertions(+), 5 deletions(-) create mode 100644 aegis/templates/copier-aegis-project/{{ project_slug }}/app/models/org.py.jinja create mode 100644 aegis/templates/copier-aegis-project/{{ project_slug }}/app/services/auth/membership_service.py.jinja create mode 100644 aegis/templates/copier-aegis-project/{{ project_slug }}/app/services/auth/org_service.py.jinja diff --git a/aegis/cli/interactive.py b/aegis/cli/interactive.py index f126dae5..b34c9fed 100644 --- a/aegis/cli/interactive.py +++ b/aegis/cli/interactive.py @@ -638,6 +638,10 @@ def interactive_auth_service_config( title="With Roles - + role-based access control", value=AuthLevels.RBAC, ), + questionary.Choice( + title="With Organizations - + multi-tenant support", + value=AuthLevels.ORG, + ), ] result = questionary.select( diff --git a/aegis/constants.py b/aegis/constants.py index b5ab751e..918a79fd 100644 --- a/aegis/constants.py +++ b/aegis/constants.py @@ -102,8 +102,9 @@ class AuthLevels: BASIC = "basic" RBAC = "rbac" + ORG = "org" - ALL = [BASIC, RBAC] + ALL = [BASIC, RBAC, ORG] class AnswerKeys: diff --git a/aegis/core/migration_generator.py b/aegis/core/migration_generator.py index bf962788..53e2faa6 100644 --- a/aegis/core/migration_generator.py +++ b/aegis/core/migration_generator.py @@ -101,6 +101,55 @@ class ServiceMigrationSpec: ], ) +ORG_MIGRATION = ServiceMigrationSpec( + service_name="auth_org", + description="Organization and membership tables", + tables=[ + TableSpec( + name="organization", + columns=[ + ColumnSpec("id", "sa.Integer()", nullable=False, primary_key=True), + ColumnSpec("name", "sa.String()", nullable=False), + ColumnSpec("slug", "sa.String()", nullable=False), + ColumnSpec("description", "sa.String()", nullable=True), + ColumnSpec("is_active", "sa.Boolean()", nullable=False, default="True"), + ColumnSpec("created_at", "sa.DateTime()", nullable=False), + ColumnSpec("updated_at", "sa.DateTime()", nullable=True), + ], + indexes=[IndexSpec("ix_organization_slug", ["slug"], unique=True)], + ), + TableSpec( + name="organization_member", + columns=[ + ColumnSpec("id", "sa.Integer()", nullable=False, primary_key=True), + ColumnSpec("organization_id", "sa.Integer()", nullable=False), + ColumnSpec("user_id", "sa.Integer()", nullable=False), + ColumnSpec("role", "sa.String()", nullable=False, default="'member'"), + ColumnSpec("joined_at", "sa.DateTime()", nullable=False), + ], + indexes=[ + IndexSpec( + "ix_org_member_org_user", + ["organization_id", "user_id"], + unique=True, + ), + IndexSpec( + "ix_org_member_organization_id", + ["organization_id"], + ), + IndexSpec( + "ix_org_member_user_id", + ["user_id"], + ), + ], + foreign_keys=[ + ForeignKeySpec(["organization_id"], "organization", ["id"]), + ForeignKeySpec(["user_id"], "user", ["id"]), + ], + ), + ], +) + AI_MIGRATION = ServiceMigrationSpec( service_name="ai", description="AI service tables (LLM catalog, usage tracking, conversations)", @@ -369,6 +418,7 @@ class ServiceMigrationSpec: # Registry of all service migrations MIGRATION_SPECS: dict[str, ServiceMigrationSpec] = { "auth": AUTH_MIGRATION, + "auth_org": ORG_MIGRATION, "ai": AI_MIGRATION, "ai_voice": VOICE_MIGRATION, } @@ -648,6 +698,17 @@ def get_services_needing_migrations(context: dict[str, Any]) -> list[str]: if include_auth == "yes" or include_auth is True: services.append("auth") + # Auth org tables (only with org-level auth) + include_auth_org = context.get("include_auth_org") + auth_level = context.get("auth_level") + org_enabled = ( + include_auth_org == "yes" + or include_auth_org is True + or (isinstance(auth_level, str) and auth_level.lower() == "org") + ) + if (include_auth == "yes" or include_auth is True) and org_enabled: + services.append("auth_org") + # AI service (only with persistence backend) include_ai = context.get("include_ai") ai_backend = context.get("ai_backend", "memory") diff --git a/aegis/core/services.py b/aegis/core/services.py index c11de2d3..68741678 100644 --- a/aegis/core/services.py +++ b/aegis/core/services.py @@ -58,6 +58,7 @@ class ServiceSpec: template_files=[ "app/components/backend/api/auth/", "app/models/user.py", + "app/models/org.py", "app/services/auth/", "app/core/security.py", ], diff --git a/aegis/core/template_generator.py b/aegis/core/template_generator.py index 4221c6e1..9b70d871 100644 --- a/aegis/core/template_generator.py +++ b/aegis/core/template_generator.py @@ -212,13 +212,14 @@ def get_template_context(self) -> dict[str, Any]: for s in self.selected_services ) else "no", - # Auth level selection (basic or rbac) - AnswerKeys.AUTH_LEVEL: self._get_auth_level(), + # Auth level selection (basic, rbac, or org) + AnswerKeys.AUTH_LEVEL: (auth_level := self._get_auth_level()), # Derived auth level flags for template conditionals + # Org level implies RBAC (org gets both roles and orgs) AnswerKeys.AUTH_RBAC: "yes" - if self._get_auth_level() == AuthLevels.RBAC + if auth_level in (AuthLevels.RBAC, AuthLevels.ORG) else "no", - AnswerKeys.AUTH_ORG: "no", # Reserved for future org-level auth + AnswerKeys.AUTH_ORG: "yes" if auth_level == AuthLevels.ORG else "no", AnswerKeys.AI: "yes" if any( extract_base_service_name(s) == AnswerKeys.SERVICE_AI diff --git a/aegis/templates/copier-aegis-project/{{ project_slug }}/app/models/org.py.jinja b/aegis/templates/copier-aegis-project/{{ project_slug }}/app/models/org.py.jinja new file mode 100644 index 00000000..32019c4d --- /dev/null +++ b/aegis/templates/copier-aegis-project/{{ project_slug }}/app/models/org.py.jinja @@ -0,0 +1,73 @@ +"""Organization and membership models.""" +{% if include_auth_org %} + +from datetime import UTC, datetime + +from sqlmodel import Field, SQLModel + + +# Org membership role constants +ORG_ROLE_OWNER = "owner" +ORG_ROLE_ADMIN = "admin" +ORG_ROLE_MEMBER = "member" +VALID_ORG_ROLES = {ORG_ROLE_OWNER, ORG_ROLE_ADMIN, ORG_ROLE_MEMBER} + + +class OrganizationBase(SQLModel): + """Base organization model with shared fields.""" + + name: str = Field(index=True) + slug: str = Field(unique=True, index=True) + description: str | None = None + is_active: bool = Field(default=True) + + +class Organization(OrganizationBase, table=True): + """Organization database model.""" + + id: int | None = Field(default=None, primary_key=True) + created_at: datetime = Field( + default_factory=lambda: datetime.now(UTC).replace(tzinfo=None) + ) + updated_at: datetime | None = None + + +class OrganizationMember(SQLModel, table=True): + """Organization membership database model.""" + + __tablename__ = "organization_member" + + id: int | None = Field(default=None, primary_key=True) + organization_id: int = Field(foreign_key="organization.id", index=True) + user_id: int = Field(foreign_key="user.id", index=True) + role: str = Field(default=ORG_ROLE_MEMBER) + joined_at: datetime = Field( + default_factory=lambda: datetime.now(UTC).replace(tzinfo=None) + ) + + +class OrgCreate(OrganizationBase): + """Organization creation model.""" + + pass + + +class OrgResponse(OrganizationBase): + """Organization response model.""" + + id: int + created_at: datetime + updated_at: datetime | None = None + + +class MemberResponse(SQLModel): + """Organization member response model.""" + + id: int + organization_id: int + user_id: int + role: str + joined_at: datetime +{% else %} +# Organization models not included (auth_level != org) +{% endif %} diff --git a/aegis/templates/copier-aegis-project/{{ project_slug }}/app/services/auth/membership_service.py.jinja b/aegis/templates/copier-aegis-project/{{ project_slug }}/app/services/auth/membership_service.py.jinja new file mode 100644 index 00000000..d0a9ae07 --- /dev/null +++ b/aegis/templates/copier-aegis-project/{{ project_slug }}/app/services/auth/membership_service.py.jinja @@ -0,0 +1,96 @@ +"""Membership service for managing organization members.""" +{% if include_auth_org %} + +from datetime import UTC, datetime + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.org import ( + VALID_ORG_ROLES, + Organization, + OrganizationMember, +) + + +class MembershipService: + """Service for managing organization memberships.""" + + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def add_member( + self, org_id: int, user_id: int, role: str = "member" + ) -> OrganizationMember: + """Add a user to an organization.""" + if role not in VALID_ORG_ROLES: + raise ValueError(f"Invalid org role: {role}. Valid: {VALID_ORG_ROLES}") + member = OrganizationMember( + organization_id=org_id, + user_id=user_id, + role=role, + joined_at=datetime.now(UTC).replace(tzinfo=None), + ) + self.db.add(member) + await self.db.commit() + await self.db.refresh(member) + return member + + async def remove_member(self, org_id: int, user_id: int) -> bool: + """Remove a user from an organization.""" + member = await self.get_member(org_id, user_id) + if not member: + return False + await self.db.delete(member) + await self.db.commit() + return True + + async def get_member( + self, org_id: int, user_id: int + ) -> OrganizationMember | None: + """Get a specific membership.""" + statement = select(OrganizationMember).where( + OrganizationMember.organization_id == org_id, + OrganizationMember.user_id == user_id, + ) + result = await self.db.exec(statement) + return result.first() + + async def update_member_role( + self, org_id: int, user_id: int, role: str + ) -> OrganizationMember | None: + """Update a member's role within an organization.""" + if role not in VALID_ORG_ROLES: + raise ValueError(f"Invalid org role: {role}. Valid: {VALID_ORG_ROLES}") + member = await self.get_member(org_id, user_id) + if not member: + return None + member.role = role + self.db.add(member) + await self.db.commit() + await self.db.refresh(member) + return member + + async def list_org_members(self, org_id: int) -> list[OrganizationMember]: + """List all members of an organization.""" + statement = select(OrganizationMember).where( + OrganizationMember.organization_id == org_id + ) + result = await self.db.exec(statement) + return list(result.all()) + + async def list_user_orgs(self, user_id: int) -> list[Organization]: + """List all organizations a user belongs to.""" + statement = ( + select(Organization) + .join( + OrganizationMember, + OrganizationMember.organization_id == Organization.id, + ) + .where(OrganizationMember.user_id == user_id) + ) + result = await self.db.exec(statement) + return list(result.all()) +{% else %} +# Membership service not included (auth_level != org) +{% endif %} diff --git a/aegis/templates/copier-aegis-project/{{ project_slug }}/app/services/auth/org_service.py.jinja b/aegis/templates/copier-aegis-project/{{ project_slug }}/app/services/auth/org_service.py.jinja new file mode 100644 index 00000000..517756bc --- /dev/null +++ b/aegis/templates/copier-aegis-project/{{ project_slug }}/app/services/auth/org_service.py.jinja @@ -0,0 +1,66 @@ +"""Organization service for CRUD operations.""" +{% if include_auth_org %} + +from datetime import UTC, datetime + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.org import OrgCreate, Organization + + +class OrgService: + """Service for managing organizations.""" + + def __init__(self, db: AsyncSession) -> None: + self.db = db + + async def create_org(self, org_data: OrgCreate) -> Organization: + """Create a new organization.""" + org = Organization.model_validate(org_data) + self.db.add(org) + await self.db.commit() + await self.db.refresh(org) + return org + + async def get_org_by_id(self, org_id: int) -> Organization | None: + """Get an organization by ID.""" + return await self.db.get(Organization, org_id) + + async def get_org_by_slug(self, slug: str) -> Organization | None: + """Get an organization by slug.""" + statement = select(Organization).where(Organization.slug == slug) + result = await self.db.exec(statement) + return result.first() + + async def update_org(self, org_id: int, **updates: str) -> Organization | None: + """Update an organization's fields.""" + org = await self.get_org_by_id(org_id) + if not org: + return None + for field, value in updates.items(): + if hasattr(org, field): + setattr(org, field, value) + org.updated_at = datetime.now(UTC).replace(tzinfo=None) + self.db.add(org) + await self.db.commit() + await self.db.refresh(org) + return org + + async def delete_org(self, org_id: int) -> bool: + """Delete an organization.""" + org = await self.get_org_by_id(org_id) + if not org: + return False + await self.db.delete(org) + await self.db.commit() + return True + + async def list_orgs(self) -> list[Organization]: + """List all organizations.""" + statement = select(Organization).order_by(Organization.created_at.desc()) + result = await self.db.exec(statement) + return list(result.all()) +{% else %} +# Organization service not included (auth_level != org) +{% endif %} diff --git a/tests/core/test_auth_service_parser.py b/tests/core/test_auth_service_parser.py index d0c2e33d..52ba0b34 100644 --- a/tests/core/test_auth_service_parser.py +++ b/tests/core/test_auth_service_parser.py @@ -58,6 +58,16 @@ def test_level_case_insensitive_basic_uppercase(self) -> None: result = parse_auth_service_config("auth[BASIC]") assert result.level == "basic" + def test_org_level(self) -> None: + """auth[org] → org""" + result = parse_auth_service_config("auth[org]") + assert result.level == "org" + + def test_org_level_case_insensitive(self) -> None: + """auth[ORG] → org (case insensitive)""" + result = parse_auth_service_config("auth[ORG]") + assert result.level == "org" + class TestAuthServiceParserWhitespace: """Test whitespace handling.""" diff --git a/tests/core/test_migration_generator.py b/tests/core/test_migration_generator.py index 72559480..aed3b5eb 100644 --- a/tests/core/test_migration_generator.py +++ b/tests/core/test_migration_generator.py @@ -11,6 +11,7 @@ AI_MIGRATION, AUTH_MIGRATION, MIGRATION_SPECS, + ORG_MIGRATION, VOICE_MIGRATION, ColumnSpec, IndexSpec, @@ -65,6 +66,50 @@ def test_neither_service(self) -> None: result = get_services_needing_migrations(context) assert result == [] + def test_auth_org_needs_migration(self) -> None: + """Test auth_org service needs migration when org level enabled.""" + context = { + "include_auth": "yes", + "include_auth_org": "yes", + "include_ai": False, + "ai_backend": "memory", + } + result = get_services_needing_migrations(context) + assert "auth_org" in result + + def test_auth_org_not_needed_without_org(self) -> None: + """Test auth_org service not needed when org level disabled.""" + context = { + "include_auth": "yes", + "include_auth_org": "no", + "include_ai": False, + "ai_backend": "memory", + } + result = get_services_needing_migrations(context) + assert "auth_org" not in result + + def test_auth_org_needs_migration_via_auth_level(self) -> None: + """Test auth_org detected via auth_level fallback when include_auth_org missing.""" + context = { + "include_auth": "yes", + "auth_level": "org", + "include_ai": False, + "ai_backend": "memory", + } + result = get_services_needing_migrations(context) + assert "auth_org" in result + + def test_auth_org_not_needed_without_auth(self) -> None: + """Test auth_org service not needed when auth not included.""" + context = { + "include_auth": False, + "include_auth_org": "yes", + "include_ai": False, + "ai_backend": "memory", + } + result = get_services_needing_migrations(context) + assert "auth_org" not in result + class TestGetVersionsDir: """Test getting the alembic versions directory.""" @@ -315,6 +360,54 @@ def test_ai_has_foreign_key(self) -> None: assert message_table.foreign_keys[0].ref_table == "conversation" +class TestOrgMigrationSpec: + """Test organization migration specification.""" + + def test_org_spec_exists(self) -> None: + """Test org migration spec is defined in MIGRATION_SPECS.""" + assert "auth_org" in MIGRATION_SPECS + assert ORG_MIGRATION.service_name == "auth_org" + + def test_org_has_two_tables(self) -> None: + """Organization migration should have two tables.""" + assert len(ORG_MIGRATION.tables) == 2 + + def test_organization_table_columns(self) -> None: + """Organization table should have expected columns.""" + org_table = next(t for t in ORG_MIGRATION.tables if t.name == "organization") + column_names = [col.name for col in org_table.columns] + + assert "name" in column_names + assert "slug" in column_names + assert "description" in column_names + assert "is_active" in column_names + assert "created_at" in column_names + assert "updated_at" in column_names + + def test_org_member_table_columns(self) -> None: + """Organization member table should have expected columns.""" + member_table = next( + t for t in ORG_MIGRATION.tables if t.name == "organization_member" + ) + column_names = [col.name for col in member_table.columns] + + assert "organization_id" in column_names + assert "user_id" in column_names + assert "role" in column_names + assert "joined_at" in column_names + + def test_org_member_foreign_keys(self) -> None: + """Organization member table should have foreign keys to org and user.""" + member_table = next( + t for t in ORG_MIGRATION.tables if t.name == "organization_member" + ) + assert len(member_table.foreign_keys) == 2 + + ref_tables = {fk.ref_table for fk in member_table.foreign_keys} + assert "organization" in ref_tables + assert "user" in ref_tables + + class TestDataclasses: """Test dataclass definitions.""" diff --git a/tests/core/test_template_generator.py b/tests/core/test_template_generator.py index 6ea978bb..48822f35 100644 --- a/tests/core/test_template_generator.py +++ b/tests/core/test_template_generator.py @@ -456,3 +456,24 @@ def test_auth_empty_brackets_defaults_to_basic(self) -> None: assert gen.auth_level == AuthLevels.BASIC context = gen.get_template_context() assert context["include_auth_rbac"] == "no" + + def test_auth_org_bracket_syntax(self) -> None: + """Auth service with auth[org] syntax should set org level.""" + gen = TemplateGenerator( + project_name="test", + selected_components=[], + selected_services=["auth[org]"], + ) + assert gen.auth_level == AuthLevels.ORG + + def test_context_auth_org_implies_rbac(self) -> None: + """Template context with org level should have both org and rbac flags.""" + gen = TemplateGenerator( + project_name="test", + selected_components=[], + selected_services=["auth[org]"], + ) + context = gen.get_template_context() + assert context["auth_level"] == AuthLevels.ORG + assert context["include_auth_org"] == "yes" + assert context["include_auth_rbac"] == "yes"