Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhanced data generating #15

Merged
merged 4 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion filling/check_constraint_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,6 @@ def regexp_like(self, value: str, pattern: str) -> bool:
# Remove outer quotes from pattern if present
if pattern.startswith("'") and pattern.endswith("'"):
pattern = pattern[1:-1]
pattern = pattern.encode('utf-8').decode('unicode_escape')
if not isinstance(value, str):
value = str(value)
try:
Expand Down
141 changes: 109 additions & 32 deletions filling/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,25 @@ def generate_composite_primary_keys(self, table: str, num_rows: int):
def generate_primary_keys(self, table: str, row: dict):
"""
Assign unique primary key values to a given row in a specified table.

Args:
table (str): The name of the table where the row resides.
row (dict): The dictionary representing the row data to be populated with primary key values.
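If the PK column has is_serial=True or its type matches INT, BIGINT, SMALLINT, DECIMAL or NUMERIC, it is filled from an auto-increment counter.
Otherwise, a string is generated with self.fake.lexify, using the length declared in the column type (default 5).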
"""
pk_columns = self.tables[table].get('primary_key', [])
for pk in pk_columns:
row[pk] = self.primary_keys[table][pk]
self.primary_keys[table][pk] += 1
for pk_col in pk_columns:
col_info = self.get_column_info(table, pk_col)
if not col_info:
continue
col_type = col_info['type'].upper()

# is_serial or numeric => auto-increment
if col_info.get("is_serial") or re.search(r'(INT|BIGINT|SMALLINT|DECIMAL|NUMERIC)', col_type):
row[pk_col] = self.primary_keys[table][pk_col]
self.primary_keys[table][pk_col] += 1
else:
# Non-numeric PK: generate a string sized to the declared length, e.g. VARCHAR(8) -> 8 characters (default 5)
length_match = re.search(r'\((\d+)\)', col_type)
length = int(length_match.group(1)) if length_match else 5
row[pk_col] = self.fake.lexify(text='?' * length)

def enforce_constraints(self):
"""
Expand All @@ -257,28 +267,77 @@ def enforce_constraints(self):

def assign_foreign_keys(self, table: str, row: dict):
"""
Automatically assign foreign key values to a table row based on established relationships.

Args:
table (str): The name of the table where the row resides.
row (dict): The dictionary representing the row data to be populated with foreign key values.
Automatically assign foreign key values to a table row based on
established relationships, including support for composite keys
and partially pre-filled columns.
"""
fks = self.tables[table].get('foreign_keys', [])
for fk in fks:
fk_columns = fk['columns']
ref_table = fk['ref_table']
ref_columns = fk['ref_columns']
if ref_table in self.generated_data and ref_columns:
# Skip assigning if FK columns are already set (e.g., in composite PKs)
if all(col in row for col in fk_columns):
fk_columns = fk['columns'] # e.g. ['row', 'seat', 'theater_id']
ref_table = fk['ref_table'] # e.g. 'Seats'
ref_columns = fk['ref_columns'] # e.g. ['row', 'seat', 'theater_id']

# We'll check child's existing FK columns to see if they're set
child_values = [row.get(fc) for fc in fk_columns]
all_set = all(v is not None for v in child_values)
partially_set = any(v is not None for v in child_values) and not all_set

# Potential parent rows
parent_data = self.generated_data[ref_table]

# ─────────────────────────────────────────
# 1) If all columns are already set, see if there's a matching parent row
# ─────────────────────────────────────────
if all_set:
matching_parents = [
p for p in parent_data
if all(p[rc] == row[fc] for rc, fc in zip(ref_columns, fk_columns))
]
if matching_parents:
# We do nothing: child's columns already match a valid parent
continue
ref_record = random.choice(self.generated_data[ref_table])
for idx, fk_col in enumerate(fk_columns):
ref_col = ref_columns[idx]
row[fk_col] = ref_record[ref_col]
else:
for fk_col in fk_columns:
row[fk_col] = None
else:
# No match found → pick a valid random parent & overwrite child's columns
chosen_parent = random.choice(parent_data)
for rc, fc in zip(ref_columns, fk_columns):
row[fc] = chosen_parent[rc]
continue

# ─────────────────────────────────────────
# 2) If *some* columns are set (partial), do a partial match
# ─────────────────────────────────────────
if partially_set:
possible_parents = []
for p in parent_data:
is_candidate = True
for rc, fc in zip(ref_columns, fk_columns):
child_val = row.get(fc)
# If child_val is set, parent must match
if child_val is not None and p[rc] != child_val:
is_candidate = False
break
if is_candidate:
possible_parents.append(p)

if not possible_parents:
# No partial match => pick random parent
chosen_parent = random.choice(parent_data)
else:
# Among partial matches, pick one at random
chosen_parent = random.choice(possible_parents)

# Fill any missing columns from the chosen parent
for rc, fc in zip(ref_columns, fk_columns):
if row.get(fc) is None:
row[fc] = chosen_parent[rc]
continue

# ─────────────────────────────────────────
# 3) If none of the columns are set, pick a random parent row
# ─────────────────────────────────────────
chosen_parent = random.choice(parent_data)
for rc, fc in zip(ref_columns, fk_columns):
row[fc] = chosen_parent[rc]

def fill_remaining_columns(self, table: str, row: dict):
"""
Expand Down Expand Up @@ -308,8 +367,15 @@ def fill_remaining_columns(self, table: str, row: dict):
if col_name in constraint:
col_constraints.append(constraint)

# Generate the column value
row[col_name] = self.generate_column_value(table, column, row, constraints=col_constraints)
# If is_serial but not a PK, handle auto-increment:
if column.get('is_serial'):
# If we haven't set up a separate counter for this col, do so now
if col_name not in self.primary_keys[table]:
self.primary_keys[table][col_name] = 1
row[col_name] = self.primary_keys[table][col_name]
self.primary_keys[table][col_name] += 1
else:
row[col_name] = self.generate_column_value(table, column, row, constraints=col_constraints)

def enforce_not_null_constraints(self, table: str, row: dict):
"""
Expand Down Expand Up @@ -398,7 +464,6 @@ def generate_column_value(
if numeric_ranges:
return generate_numeric_value(numeric_ranges, col_type)

# Default data generation based on column type
return self.generate_value_based_on_type(col_type)

def generate_value_based_on_type(self, col_type: str):
Expand All @@ -411,16 +476,28 @@ def generate_value_based_on_type(self, col_type: str):
Returns:
Any: A synthetic value appropriate for the specified data type.
"""
is_unsigned = False
if col_type.upper().startswith('U'):
is_unsigned = True
col_type = col_type[1:] # Remove the leading 'U' so the rest of the logic matches e.g. 'INT', 'BIGINT'

col_type = col_type.upper()

if re.match(r'.*\b(INT|INTEGER|SMALLINT|BIGINT)\b.*', col_type):
return random.randint(1, 10000)
min_val = 0 if is_unsigned else -10000
return random.randint(min_val, 10000)
elif re.match(r'.*\b(DECIMAL|NUMERIC)\b.*', col_type):
# Handle decimal and numeric types with precision and scale
precision, scale = 10, 2 # Default values
# DECIMAL/NUMERIC: use the declared (precision, scale) if present, otherwise default to (10, 2)
precision, scale = 10, 2
match = re.search(r'\((\d+),\s*(\d+)\)', col_type)
if match:
precision, scale = int(match.group(1)), int(match.group(2))
max_value = 10 ** (precision - scale) - 1
return round(random.uniform(0, max_value), scale)

# If it's unsigned, ensure the minimum is 0
min_dec = 0.0 if is_unsigned else -9999.0 # or 0 if you prefer all positives
return round(random.uniform(min_dec, max_value), scale)

elif re.match(r'.*\b(FLOAT|REAL|DOUBLE PRECISION|DOUBLE)\b.*', col_type):
return random.uniform(0, 10000)
elif re.match(r'.*\b(BOOLEAN|BOOL)\b.*', col_type):
Expand Down
Loading
Loading