diff --git a/mostlyai/sdk/_data/non_context.py b/mostlyai/sdk/_data/non_context.py
index 114dd12a..25baff45 100644
--- a/mostlyai/sdk/_data/non_context.py
+++ b/mostlyai/sdk/_data/non_context.py
@@ -85,6 +85,9 @@
 TOP_K = None
 TOP_P = 0.95
 QUOTA_PENALTY_FACTOR = 0.05
+FK_MATCHING_PARENT_BATCH_SIZE = 5_000
+FK_MATCHING_CHILD_BATCH_SIZE = 5_000
+
 
 # Supported Encoding Types
 FK_MODEL_ENCODING_TYPES = [
@@ -1131,7 +1134,7 @@ def initialize_remaining_capacity(
 
     # Generate children counts using engine with parent data as seed
    # The engine will predict the __CHILDREN_COUNT__ column based on parent features
-    _LOG.info(f"Generating cardinality predictions using engine for {len(parent_data)} parents")
+    _LOG.info(f"Generating cardinality predictions for {len(parent_data)} parents")
 
     engine.generate(
         seed_data=parent_data,
diff --git a/mostlyai/sdk/_local/execution/step_finalize_generation.py b/mostlyai/sdk/_local/execution/step_finalize_generation.py
index 2bd39e07..aff8ecd0 100644
--- a/mostlyai/sdk/_local/execution/step_finalize_generation.py
+++ b/mostlyai/sdk/_local/execution/step_finalize_generation.py
@@ -26,6 +26,8 @@
 from mostlyai.sdk._data.file.table.csv import CsvDataTable
 from mostlyai.sdk._data.file.table.parquet import ParquetDataTable
 from mostlyai.sdk._data.non_context import (
+    FK_MATCHING_CHILD_BATCH_SIZE,
+    FK_MATCHING_PARENT_BATCH_SIZE,
     add_context_parent_data,
     assign_non_context_fks_randomly,
     initialize_remaining_capacity,
@@ -42,10 +44,6 @@
 _LOG = logging.getLogger(__name__)
 
 
-# FK processing constants
-FK_MIN_CHILDREN_BATCH_SIZE = 10
-FK_PARENT_BATCH_SIZE = 1_000
-
 
 def execute_step_finalize_generation(
     *,
@@ -312,49 +310,15 @@ def process_table_with_random_fk_assignment(
     write_batch_outputs(processed_data, table_name, chunk_idx, pqt_path, csv_path)
 
 
-def calculate_optimal_child_batch_size_for_relation(
-    parent_key_count: int,
-    children_row_count: int,
-    parent_batch_size: int,
-    relation_name: str,
-) -> int:
-    """Calculate optimal child batch size for a specific FK relationship."""
-    num_parent_batches = max(1, math.ceil(parent_key_count / parent_batch_size))
-
-    # ideal batch size for full parent utilization
-    ideal_batch_size = children_row_count // num_parent_batches
-
-    # apply minimum batch size constraint
-    optimal_batch_size = max(ideal_batch_size, FK_MIN_CHILDREN_BATCH_SIZE)
-
-    # log utilization metrics
-    num_child_batches = children_row_count // optimal_batch_size
-    parent_utilization = min(num_child_batches / num_parent_batches * 100, 100)
-
-    _LOG.info(
-        f"[{relation_name}] Batch size optimization | "
-        f"total_children: {children_row_count} | "
-        f"parent_size: {parent_key_count} | "
-        f"parent_batch_size: {parent_batch_size} | "
-        f"parent_batches: {num_parent_batches} | "
-        f"ideal_child_batch: {ideal_batch_size} | "
-        f"optimal_child_batch: {optimal_batch_size} | "
-        f"parent_utilization: {parent_utilization:.1f}%"
-    )
-
-    return optimal_batch_size
-
-
 def process_table_with_fk_models(
     *,
     table_name: str,
     schema: Schema,
     pqt_path: Path,
     csv_path: Path | None,
-    parent_batch_size: int = FK_PARENT_BATCH_SIZE,
     job_workspace_dir: Path,
 ) -> None:
-    """Process table with ML model-based FK assignment using logical child batches."""
+    """Process table with ML model-based FK assignment using fixed batch sizes."""
     fk_models_workspace_dir = job_workspace_dir / "FKModelsStore" / table_name
     non_ctx_relations = [rel for rel in schema.non_context_relations if rel.child.table == table_name]
 
@@ -374,21 +338,6 @@ def process_table_with_fk_models(
             do_coerce_dtypes=True,
         )
 
-    # Calculate optimal batch size for each relationship
-    relation_batch_sizes = {}
-    for relation in non_ctx_relations:
-        parent_table_name = relation.parent.table
-        parent_key_count = len(parent_keys_cache[parent_table_name])
-        relation_name = f"{relation.child.table}.{relation.child.column}->{parent_table_name}"
-
-        optimal_batch_size = calculate_optimal_child_batch_size_for_relation(
-            parent_key_count=parent_key_count,
-            children_row_count=children_table.row_count,
-            parent_batch_size=parent_batch_size,
-            relation_name=relation_name,
-        )
-        relation_batch_sizes[relation] = optimal_batch_size
-
     # Initialize remaining capacity for all relations
     # At this point, both FK models and cardinality models are guaranteed to exist
     # (checked by are_fk_models_available)
@@ -401,8 +350,7 @@ def process_table_with_fk_models(
         parent_keys_df = parent_keys_cache[parent_table_name]
         parent_table = parent_tables[parent_table_name]
 
-        _LOG.info(f"Using Engine-based Cardinality Model for {relation.child.table}.{relation.child.column}")
-        # Use Engine-based Cardinality Model to predict capacities
+        _LOG.info(f"Using Cardinality Model for {relation.child.table}.{relation.child.column}")
         parent_data = parent_table.read_data(
             where={pk_col: parent_keys_df[pk_col].tolist()},
             do_coerce_dtypes=True,
@@ -421,35 +369,45 @@ def process_table_with_fk_models(
             parent_table_name = relation.parent.table
             parent_table = parent_tables[parent_table_name]
             parent_pk = relation.parent.column
-            optimal_batch_size = relation_batch_sizes[relation]
-            relation_name = f"{relation.child.table}.{relation.child.column}->{parent_table_name}"
+            parent_keys_df = parent_keys_cache[parent_table_name]
+            parent_size = len(parent_keys_df)
+            child_size = len(chunk_data)
+            parent_batch_size = min(FK_MATCHING_PARENT_BATCH_SIZE, parent_size)
+            child_batch_size = min(FK_MATCHING_CHILD_BATCH_SIZE, child_size)
+            num_child_batches = math.ceil(child_size / child_batch_size)
+            total_parent_samples_needed = num_child_batches * parent_batch_size
 
-            _LOG.info(f" Processing relationship {relation_name} with batch size {optimal_batch_size}")
+            _LOG.info(
+                f"Processing relationship {relation.child.table}.{relation.child.column}->{parent_table_name} with "
+                f"parent_batch_size={parent_batch_size}, child_batch_size={child_batch_size}, "
+                f"num_child_batches={num_child_batches}"
+            )
 
-            parent_keys_df = parent_keys_cache[parent_table_name]
+            # sample enough parent data to cover all child batches in this chunk
+            sampled_parent_keys = parent_keys_df.sample(
+                n=total_parent_samples_needed, replace=total_parent_samples_needed > parent_size
+            )[parent_pk].tolist()
+            parent_data_for_chunk = parent_table.read_data(
+                where={parent_pk: sampled_parent_keys},
+                columns=parent_table.columns,
+                do_coerce_dtypes=True,
+            )
 
             processed_batches = []
-
-            for batch_start in range(0, len(chunk_data), optimal_batch_size):
-                batch_end = min(batch_start + optimal_batch_size, len(chunk_data))
+            for batch_idx, batch_start in enumerate(range(0, child_size, child_batch_size)):
+                batch_end = min(batch_start + child_batch_size, child_size)
                 batch_data = chunk_data.iloc[batch_start:batch_end].copy()
-
-                sampled_parent_keys = parent_keys_df.sample(
-                    n=parent_batch_size, replace=len(parent_keys_df) < parent_batch_size
-                )[parent_pk].tolist()
-
-                parent_data = parent_table.read_data(
-                    where={parent_pk: sampled_parent_keys},
-                    columns=parent_table.columns,
-                    do_coerce_dtypes=True,
-                )
-
                 batch_data = add_context_parent_data(
                     tgt_data=batch_data,
                     tgt_table=children_table,
                     schema=schema,
                 )
 
+                # slice the appropriate parent batch from the pre-fetched data
+                parent_slice_start = batch_idx * parent_batch_size
+                parent_slice_end = min(parent_slice_start + parent_batch_size, len(parent_data_for_chunk))
+                parent_data = parent_data_for_chunk.iloc[parent_slice_start:parent_slice_end]
+
                 assert relation in remaining_capacity
                 processed_batch = match_non_context(
                     fk_models_workspace_dir=fk_models_workspace_dir,
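
Review note: below is a minimal standalone sketch of the batching scheme this diff introduces, using plain pandas. The helper name `iter_fk_matching_batches`, the toy column names, and the usage snippet are illustrative only and not part of the SDK; the constants mirror the values added in `non_context.py`. For each chunk, parents are sampled once (with replacement only when the pool cannot cover every child batch) and then sliced into one fixed-size batch per child batch, which bounds per-batch memory by constants and replaces one parent read per child batch with a single read per chunk.

```python
import math

import pandas as pd

# constants mirrored from the values added in non_context.py
FK_MATCHING_PARENT_BATCH_SIZE = 5_000
FK_MATCHING_CHILD_BATCH_SIZE = 5_000


def iter_fk_matching_batches(parent_keys: pd.DataFrame, chunk_data: pd.DataFrame):
    """Yield (child_batch, parent_batch) pairs for one chunk (illustrative helper).

    Assumes non-empty parent_keys and chunk_data. Parents are sampled once per
    chunk, with replacement only if the pool cannot cover every child batch,
    then sliced into one fixed-size parent batch per child batch.
    """
    parent_size = len(parent_keys)
    child_size = len(chunk_data)
    parent_batch_size = min(FK_MATCHING_PARENT_BATCH_SIZE, parent_size)
    child_batch_size = min(FK_MATCHING_CHILD_BATCH_SIZE, child_size)
    num_child_batches = math.ceil(child_size / child_batch_size)
    total_parent_samples_needed = num_child_batches * parent_batch_size

    # one sampling pass covers every child batch in this chunk
    sampled_parents = parent_keys.sample(
        n=total_parent_samples_needed,
        replace=total_parent_samples_needed > parent_size,
    ).reset_index(drop=True)

    for batch_idx, batch_start in enumerate(range(0, child_size, child_batch_size)):
        child_batch = chunk_data.iloc[batch_start : batch_start + child_batch_size]
        parent_start = batch_idx * parent_batch_size
        parent_batch = sampled_parents.iloc[parent_start : parent_start + parent_batch_size]
        yield child_batch, parent_batch


# toy usage: 1_000 parents, 12_000 children -> 3 child batches of 5_000/5_000/2_000 rows,
# each paired with its own 1_000-row parent sample (drawn with replacement here)
parents = pd.DataFrame({"parent_id": range(1_000)})
children = pd.DataFrame({"value": range(12_000)})
for child_batch, parent_batch in iter_fk_matching_batches(parents, children):
    print(len(child_batch), len(parent_batch))
```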