Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aegis/cli/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,10 @@ def interactive_auth_service_config(
title="With Roles - + role-based access control",
value=AuthLevels.RBAC,
),
questionary.Choice(
title="With Organizations - + multi-tenant support",
value=AuthLevels.ORG,
),
]

result = questionary.select(
Expand Down
3 changes: 2 additions & 1 deletion aegis/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ class AuthLevels:

BASIC = "basic"
RBAC = "rbac"
ORG = "org"

ALL = [BASIC, RBAC]
ALL = [BASIC, RBAC, ORG]


class AnswerKeys:
Expand Down
61 changes: 61 additions & 0 deletions aegis/core/migration_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,55 @@ class ServiceMigrationSpec:
],
)

ORG_MIGRATION = ServiceMigrationSpec(
service_name="auth_org",
description="Organization and membership tables",
tables=[
TableSpec(
name="organization",
columns=[
ColumnSpec("id", "sa.Integer()", nullable=False, primary_key=True),
ColumnSpec("name", "sa.String()", nullable=False),
ColumnSpec("slug", "sa.String()", nullable=False),
ColumnSpec("description", "sa.String()", nullable=True),
ColumnSpec("is_active", "sa.Boolean()", nullable=False, default="True"),
ColumnSpec("created_at", "sa.DateTime()", nullable=False),
ColumnSpec("updated_at", "sa.DateTime()", nullable=True),
],
indexes=[IndexSpec("ix_organization_slug", ["slug"], unique=True)],
),
TableSpec(
name="organization_member",
columns=[
ColumnSpec("id", "sa.Integer()", nullable=False, primary_key=True),
ColumnSpec("organization_id", "sa.Integer()", nullable=False),
ColumnSpec("user_id", "sa.Integer()", nullable=False),
ColumnSpec("role", "sa.String()", nullable=False, default="'member'"),
ColumnSpec("joined_at", "sa.DateTime()", nullable=False),
],
indexes=[
IndexSpec(
"ix_org_member_org_user",
["organization_id", "user_id"],
unique=True,
),
IndexSpec(
"ix_org_member_organization_id",
["organization_id"],
),
IndexSpec(
"ix_org_member_user_id",
["user_id"],
),
],
foreign_keys=[
ForeignKeySpec(["organization_id"], "organization", ["id"]),
ForeignKeySpec(["user_id"], "user", ["id"]),
],
),
],
)

AI_MIGRATION = ServiceMigrationSpec(
service_name="ai",
description="AI service tables (LLM catalog, usage tracking, conversations)",
Expand Down Expand Up @@ -369,6 +418,7 @@ class ServiceMigrationSpec:
# Registry of all service migrations
MIGRATION_SPECS: dict[str, ServiceMigrationSpec] = {
"auth": AUTH_MIGRATION,
"auth_org": ORG_MIGRATION,
"ai": AI_MIGRATION,
"ai_voice": VOICE_MIGRATION,
}
Expand Down Expand Up @@ -648,6 +698,17 @@ def get_services_needing_migrations(context: dict[str, Any]) -> list[str]:
if include_auth == "yes" or include_auth is True:
services.append("auth")

# Auth org tables (only with org-level auth)
include_auth_org = context.get("include_auth_org")
auth_level = context.get("auth_level")
org_enabled = (
include_auth_org == "yes"
or include_auth_org is True
or (isinstance(auth_level, str) and auth_level.lower() == "org")
)
if (include_auth == "yes" or include_auth is True) and org_enabled:
services.append("auth_org")

# AI service (only with persistence backend)
include_ai = context.get("include_ai")
ai_backend = context.get("ai_backend", "memory")
Expand Down
1 change: 1 addition & 0 deletions aegis/core/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class ServiceSpec:
template_files=[
"app/components/backend/api/auth/",
"app/models/user.py",
"app/models/org.py",
"app/services/auth/",
"app/core/security.py",
],
Expand Down
9 changes: 5 additions & 4 deletions aegis/core/template_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,14 @@ def get_template_context(self) -> dict[str, Any]:
for s in self.selected_services
)
else "no",
# Auth level selection (basic or rbac)
AnswerKeys.AUTH_LEVEL: self._get_auth_level(),
# Auth level selection (basic, rbac, or org)
AnswerKeys.AUTH_LEVEL: (auth_level := self._get_auth_level()),
# Derived auth level flags for template conditionals
# Org level implies RBAC (org gets both roles and orgs)
AnswerKeys.AUTH_RBAC: "yes"
if self._get_auth_level() == AuthLevels.RBAC
if auth_level in (AuthLevels.RBAC, AuthLevels.ORG)
else "no",
AnswerKeys.AUTH_ORG: "no", # Reserved for future org-level auth
AnswerKeys.AUTH_ORG: "yes" if auth_level == AuthLevels.ORG else "no",
AnswerKeys.AI: "yes"
if any(
extract_base_service_name(s) == AnswerKeys.SERVICE_AI
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Organization and membership models."""
{% if include_auth_org %}

from datetime import UTC, datetime

from sqlmodel import Field, SQLModel


# Org membership role constants
ORG_ROLE_OWNER = "owner"
ORG_ROLE_ADMIN = "admin"
ORG_ROLE_MEMBER = "member"
VALID_ORG_ROLES = {ORG_ROLE_OWNER, ORG_ROLE_ADMIN, ORG_ROLE_MEMBER}


class OrganizationBase(SQLModel):
"""Base organization model with shared fields."""

name: str = Field(index=True)
slug: str = Field(unique=True, index=True)
description: str | None = None
is_active: bool = Field(default=True)


class Organization(OrganizationBase, table=True):
"""Organization database model."""

id: int | None = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=lambda: datetime.now(UTC).replace(tzinfo=None)
)
updated_at: datetime | None = None


class OrganizationMember(SQLModel, table=True):
"""Organization membership database model."""

__tablename__ = "organization_member"

id: int | None = Field(default=None, primary_key=True)
organization_id: int = Field(foreign_key="organization.id", index=True)
user_id: int = Field(foreign_key="user.id", index=True)
role: str = Field(default=ORG_ROLE_MEMBER)
joined_at: datetime = Field(
default_factory=lambda: datetime.now(UTC).replace(tzinfo=None)
)


class OrgCreate(OrganizationBase):
"""Organization creation model."""

pass


class OrgResponse(OrganizationBase):
"""Organization response model."""

id: int
created_at: datetime
updated_at: datetime | None = None


class MemberResponse(SQLModel):
"""Organization member response model."""

id: int
organization_id: int
user_id: int
role: str
joined_at: datetime
{% else %}
# Organization models not included (auth_level != org)
{% endif %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Membership service for managing organization members."""
{% if include_auth_org %}

from datetime import UTC, datetime

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.org import (
VALID_ORG_ROLES,
Organization,
OrganizationMember,
)


class MembershipService:
"""Service for managing organization memberships."""

def __init__(self, db: AsyncSession) -> None:
self.db = db

async def add_member(
self, org_id: int, user_id: int, role: str = "member"
) -> OrganizationMember:
"""Add a user to an organization."""
if role not in VALID_ORG_ROLES:
raise ValueError(f"Invalid org role: {role}. Valid: {VALID_ORG_ROLES}")
member = OrganizationMember(
organization_id=org_id,
user_id=user_id,
role=role,
joined_at=datetime.now(UTC).replace(tzinfo=None),
)
self.db.add(member)
await self.db.commit()
await self.db.refresh(member)
return member

async def remove_member(self, org_id: int, user_id: int) -> bool:
"""Remove a user from an organization."""
member = await self.get_member(org_id, user_id)
if not member:
return False
await self.db.delete(member)
await self.db.commit()
return True

async def get_member(
self, org_id: int, user_id: int
) -> OrganizationMember | None:
"""Get a specific membership."""
statement = select(OrganizationMember).where(
OrganizationMember.organization_id == org_id,
OrganizationMember.user_id == user_id,
)
result = await self.db.exec(statement)
return result.first()

async def update_member_role(
self, org_id: int, user_id: int, role: str
) -> OrganizationMember | None:
"""Update a member's role within an organization."""
if role not in VALID_ORG_ROLES:
raise ValueError(f"Invalid org role: {role}. Valid: {VALID_ORG_ROLES}")
member = await self.get_member(org_id, user_id)
if not member:
return None
member.role = role
self.db.add(member)
await self.db.commit()
await self.db.refresh(member)
return member

async def list_org_members(self, org_id: int) -> list[OrganizationMember]:
"""List all members of an organization."""
statement = select(OrganizationMember).where(
OrganizationMember.organization_id == org_id
)
result = await self.db.exec(statement)
return list(result.all())

async def list_user_orgs(self, user_id: int) -> list[Organization]:
"""List all organizations a user belongs to."""
statement = (
select(Organization)
.join(
OrganizationMember,
OrganizationMember.organization_id == Organization.id,
)
.where(OrganizationMember.user_id == user_id)
)
result = await self.db.exec(statement)
return list(result.all())
{% else %}
# Membership service not included (auth_level != org)
{% endif %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Organization service for CRUD operations."""
{% if include_auth_org %}

from datetime import UTC, datetime

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.org import OrgCreate, Organization


class OrgService:
"""Service for managing organizations."""

def __init__(self, db: AsyncSession) -> None:
self.db = db

async def create_org(self, org_data: OrgCreate) -> Organization:
"""Create a new organization."""
org = Organization.model_validate(org_data)
self.db.add(org)
await self.db.commit()
await self.db.refresh(org)
return org

async def get_org_by_id(self, org_id: int) -> Organization | None:
"""Get an organization by ID."""
return await self.db.get(Organization, org_id)

async def get_org_by_slug(self, slug: str) -> Organization | None:
"""Get an organization by slug."""
statement = select(Organization).where(Organization.slug == slug)
result = await self.db.exec(statement)
return result.first()

async def update_org(self, org_id: int, **updates: str) -> Organization | None:
"""Update an organization's fields."""
org = await self.get_org_by_id(org_id)
if not org:
return None
for field, value in updates.items():
if hasattr(org, field):
setattr(org, field, value)
org.updated_at = datetime.now(UTC).replace(tzinfo=None)
self.db.add(org)
await self.db.commit()
await self.db.refresh(org)
return org

async def delete_org(self, org_id: int) -> bool:
"""Delete an organization."""
org = await self.get_org_by_id(org_id)
if not org:
return False
await self.db.delete(org)
await self.db.commit()
return True

async def list_orgs(self) -> list[Organization]:
"""List all organizations."""
statement = select(Organization).order_by(Organization.created_at.desc())
result = await self.db.exec(statement)
return list(result.all())
{% else %}
# Organization service not included (auth_level != org)
{% endif %}
10 changes: 10 additions & 0 deletions tests/core/test_auth_service_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def test_level_case_insensitive_basic_uppercase(self) -> None:
result = parse_auth_service_config("auth[BASIC]")
assert result.level == "basic"

def test_org_level(self) -> None:
"""auth[org] → org"""
result = parse_auth_service_config("auth[org]")
assert result.level == "org"

def test_org_level_case_insensitive(self) -> None:
"""auth[ORG] → org (case insensitive)"""
result = parse_auth_service_config("auth[ORG]")
assert result.level == "org"


class TestAuthServiceParserWhitespace:
"""Test whitespace handling."""
Expand Down
Loading
Loading