Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions backend/src/services/dataset/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
StorageError,
ValidationError,
)
from src.domain.adapters import CSVSchemaDetector
from src.repo.dataset import DatasetRepo
from src.schemas.db import Datasets
from src.services.storage import storage
Expand Down Expand Up @@ -271,6 +272,120 @@ async def get_dataset_files(

return files_data

async def drop_zero_enrollment(
self, dataset_id: UUID, user_id: UUID
) -> dict[str, pd.DataFrame]:
"""
Return dataset files with zero-enrollment courses removed.

Args:
dataset_id: Dataset ID
user_id: User ID for authorization

Returns:
Dictionary with filtered courses, enrollments, and rooms dataframes
"""
files = await self.get_dataset_files(dataset_id, user_id)

courses_df = files["courses"]
enrollments_df = files["enrollments"]

filtered_courses_df, allowed_crns = self._filter_nonzero_enrollment(
courses_df
)

# If we couldn't determine CRNs/columns, keep enrollments as-is.
filtered_enrollments_df = (
self._filter_by_allowed_crns(enrollments_df, allowed_crns)
if allowed_crns is not None
else enrollments_df
)

return {
"courses": filtered_courses_df,
"enrollments": filtered_enrollments_df,
"rooms": files["rooms"],
}

def _filter_nonzero_enrollment(
self, courses_df: pd.DataFrame
) -> tuple[pd.DataFrame, set[str] | None]:
"""
Filter the courses DataFrame to remove rows where Total_Enrollment == 0.

Returns:
(filtered_df, allowed_crns)
"""
try:
schema, column_mapping = CSVSchemaDetector.detect_schema_version(
courses_df, "courses"
)
except Exception:
# If schema detection fails, don't change behavior.
return courses_df.copy(), None

canonical_to_csv = {canonical: csv for csv, canonical in column_mapping.items()}
enrollment_col = canonical_to_csv.get("Total_Enrollment")
crn_col = canonical_to_csv.get("Course_Reference_Number")
if not enrollment_col or not crn_col:
return courses_df.copy(), None

col_defs = {cd.canonical_name: cd for cd in schema}
enrollment_transformer = col_defs.get("Total_Enrollment").transformer if col_defs.get("Total_Enrollment") else None
crn_transformer = col_defs.get("Course_Reference_Number").transformer if col_defs.get("Course_Reference_Number") else None

enrollment_series = courses_df[enrollment_col]
if enrollment_transformer:
enrollment_series = enrollment_series.apply(enrollment_transformer)

# Keep only nonzero enrollments; treat None/NaN as zero for this filter.
try:
nonzero_mask = enrollment_series.fillna(0).astype(int) != 0
except Exception:
nonzero_mask = enrollment_series.fillna(0) != 0

filtered_df = courses_df.loc[nonzero_mask].copy()

crn_series = filtered_df[crn_col]
if crn_transformer:
crn_series = crn_series.apply(crn_transformer)

allowed_crns = {crn for crn in crn_series.tolist() if crn}
return filtered_df, allowed_crns

def _filter_by_allowed_crns(
self, enrollments_df: pd.DataFrame, allowed_crns: set[str]
) -> pd.DataFrame:
"""
Filter enrollments to only those whose CRN is in allowed_crns.

This keeps enrollments consistent with a temporarily filtered course list.
"""
if not allowed_crns:
return enrollments_df.copy()

try:
schema, column_mapping = CSVSchemaDetector.detect_schema_version(
enrollments_df, "enrollments"
)
except Exception:
return enrollments_df.copy()

canonical_to_csv = {canonical: csv for csv, canonical in column_mapping.items()}
crn_col = canonical_to_csv.get("Course_Reference_Number")
if not crn_col:
return enrollments_df.copy()

col_defs = {cd.canonical_name: cd for cd in schema}
crn_transformer = col_defs.get("Course_Reference_Number").transformer if col_defs.get("Course_Reference_Number") else None

crn_series = enrollments_df[crn_col]
if crn_transformer:
crn_series = crn_series.apply(crn_transformer)

mask = crn_series.isin(allowed_crns)
return enrollments_df.loc[mask].copy()

async def _download_and_parse(self, file_entry: dict) -> tuple[str, pd.DataFrame]:
"""Download one file and parse it."""
file_type = file_entry["type"]
Expand Down
3 changes: 3 additions & 0 deletions backend/src/services/schedule/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ async def generate_schedule(
# 3. Load course merges (if any) - synchronous call
merges = self.dataset_service.get_merges(dataset_id, user_id) or {}

# 3.5 Drop zero-enrollment courses and get updated merges
merges = self.dataset_service.drop_zero_enrollment(dataset_id, user_id)

# 4. Build scheduling dataset and run algorithm
scheduling_dataset = DatasetFactory.from_dataframes_to_scheduling_dataset(
courses_df=files["courses"],
Expand Down