diff --git a/backend/src/services/dataset/service.py b/backend/src/services/dataset/service.py index 9fba906..a70387c 100644 --- a/backend/src/services/dataset/service.py +++ b/backend/src/services/dataset/service.py @@ -14,6 +14,7 @@ StorageError, ValidationError, ) +from src.domain.adapters import CSVSchemaDetector from src.repo.dataset import DatasetRepo from src.schemas.db import Datasets from src.services.storage import storage @@ -271,6 +272,120 @@ async def get_dataset_files( return files_data + async def drop_zero_enrollment( + self, dataset_id: UUID, user_id: UUID + ) -> dict[str, pd.DataFrame]: + """ + Return dataset files with zero-enrollment courses removed. + + Args: + dataset_id: Dataset ID + user_id: User ID for authorization + + Returns: + Dictionary with filtered courses, enrollments, and rooms dataframes + """ + files = await self.get_dataset_files(dataset_id, user_id) + + courses_df = files["courses"] + enrollments_df = files["enrollments"] + + filtered_courses_df, allowed_crns = self._filter_nonzero_enrollment( + courses_df + ) + + # If we couldn't determine CRNs/columns, keep enrollments as-is. + filtered_enrollments_df = ( + self._filter_by_allowed_crns(enrollments_df, allowed_crns) + if allowed_crns is not None + else enrollments_df + ) + + return { + "courses": filtered_courses_df, + "enrollments": filtered_enrollments_df, + "rooms": files["rooms"], + } + + def _filter_nonzero_enrollment( + self, courses_df: pd.DataFrame + ) -> tuple[pd.DataFrame, set[str] | None]: + """ + Filter the courses DataFrame to remove rows where Total_Enrollment == 0. + + Returns: + (filtered_df, allowed_crns) + """ + try: + schema, column_mapping = CSVSchemaDetector.detect_schema_version( + courses_df, "courses" + ) + except Exception: + # If schema detection fails, don't change behavior. + return courses_df.copy(), None + + canonical_to_csv = {canonical: csv for csv, canonical in column_mapping.items()} + enrollment_col = canonical_to_csv.get("Total_Enrollment") + crn_col = canonical_to_csv.get("Course_Reference_Number") + if not enrollment_col or not crn_col: + return courses_df.copy(), None + + col_defs = {cd.canonical_name: cd for cd in schema} + enrollment_transformer = col_defs.get("Total_Enrollment").transformer if col_defs.get("Total_Enrollment") else None + crn_transformer = col_defs.get("Course_Reference_Number").transformer if col_defs.get("Course_Reference_Number") else None + + enrollment_series = courses_df[enrollment_col] + if enrollment_transformer: + enrollment_series = enrollment_series.apply(enrollment_transformer) + + # Keep only nonzero enrollments; treat None/NaN as zero for this filter. + try: + nonzero_mask = enrollment_series.fillna(0).astype(int) != 0 + except Exception: + nonzero_mask = enrollment_series.fillna(0) != 0 + + filtered_df = courses_df.loc[nonzero_mask].copy() + + crn_series = filtered_df[crn_col] + if crn_transformer: + crn_series = crn_series.apply(crn_transformer) + + allowed_crns = {crn for crn in crn_series.tolist() if crn} + return filtered_df, allowed_crns + + def _filter_by_allowed_crns( + self, enrollments_df: pd.DataFrame, allowed_crns: set[str] + ) -> pd.DataFrame: + """ + Filter enrollments to only those whose CRN is in allowed_crns. + + This keeps enrollments consistent with a temporarily filtered course list. + """ + if not allowed_crns: + return enrollments_df.copy() + + try: + schema, column_mapping = CSVSchemaDetector.detect_schema_version( + enrollments_df, "enrollments" + ) + except Exception: + return enrollments_df.copy() + + canonical_to_csv = {canonical: csv for csv, canonical in column_mapping.items()} + crn_col = canonical_to_csv.get("Course_Reference_Number") + if not crn_col: + return enrollments_df.copy() + + col_defs = {cd.canonical_name: cd for cd in schema} + crn_transformer = col_defs.get("Course_Reference_Number").transformer if col_defs.get("Course_Reference_Number") else None + + crn_series = enrollments_df[crn_col] + if crn_transformer: + crn_series = crn_series.apply(crn_transformer) + + mask = crn_series.isin(allowed_crns) + return enrollments_df.loc[mask].copy() + async def _download_and_parse(self, file_entry: dict) -> tuple[str, pd.DataFrame]: """Download one file and parse it.""" file_type = file_entry["type"] diff --git a/backend/src/services/schedule/service.py b/backend/src/services/schedule/service.py index f888f72..b09db1d 100644 --- a/backend/src/services/schedule/service.py +++ b/backend/src/services/schedule/service.py @@ -110,6 +110,9 @@ async def generate_schedule( # 3. Load course merges (if any) - synchronous call merges = self.dataset_service.get_merges(dataset_id, user_id) or {} + # 3.5 Drop zero-enrollment courses and get updated merges + merges = self.dataset_service.drop_zero_enrollment(dataset_id, user_id) + # 4. Build scheduling dataset and run algorithm scheduling_dataset = DatasetFactory.from_dataframes_to_scheduling_dataset( courses_df=files["courses"],