Skip to content

Commit

Permalink
feat: implement failure handling for pipeline (#1) (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
aballiet authored Dec 19, 2024
1 parent f7f586a commit 19127b5
Show file tree
Hide file tree
Showing 24 changed files with 307 additions and 60 deletions.
4 changes: 3 additions & 1 deletion bizon/destinations/bigquery/src/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def unnest_data(df_destination_records: pl.DataFrame, record_schema: list[BigQue
"""Unnest the source_data field into separate columns"""

# Check if the schema matches the expected schema
source_data_fields = pl.DataFrame(df_destination_records['source_data'].str.json_decode()).schema["source_data"].fields
source_data_fields = (
pl.DataFrame(df_destination_records["source_data"].str.json_decode()).schema["source_data"].fields
)

record_schema_fields = [col.name for col in record_schema]

Expand Down
15 changes: 10 additions & 5 deletions bizon/destinations/bigquery_streaming/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@
)


class TimePartitioning(str, Enum):
class TimePartitioningWindow(str, Enum):
DAY = "DAY"
HOUR = "HOUR"
MONTH = "MONTH"
YEAR = "YEAR"


class TimePartitioning(BaseModel):
type: TimePartitioningWindow = Field(default=TimePartitioningWindow.DAY, description="Time partitioning type")
field: Optional[str] = Field(
"_bizon_loaded_at", description="Field to partition by. You can use a transformation to create this field."
)


class BigQueryAuthentication(BaseModel):
service_account_key: str = Field(
description="Service Account Key JSON string. If empty it will be infered",
Expand All @@ -33,10 +40,8 @@ class BigQueryStreamingConfigDetails(AbstractDestinationDetailsConfig):
default=None, description="Table ID, if not provided it will be inferred from source name"
)
time_partitioning: Optional[TimePartitioning] = Field(
default=TimePartitioning.DAY, description="BigQuery Time partitioning type"
)
time_partitioning_field: Optional[str] = Field(
"_bizon_loaded_at", description="Field to partition by. You can use a transformation to create this field."
default=TimePartitioning(type=TimePartitioningWindow.DAY, field="_bizon_loaded_at"),
description="BigQuery Time partitioning type",
)
authentication: Optional[BigQueryAuthentication] = None
bq_max_rows_per_request: Optional[int] = Field(30000, description="Max rows per buffer streaming request.")
Expand Down
7 changes: 6 additions & 1 deletion bizon/destinations/bigquery_streaming/src/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def load_to_bigquery_via_streaming(self, df_destination_records: pl.DataFrame) -
schema = self.get_bigquery_schema()
table = bigquery.Table(self.table_id, schema=schema)
time_partitioning = TimePartitioning(
field=self.config.time_partitioning_field, type_=self.config.time_partitioning
field=self.config.time_partitioning.field, type_=self.config.time_partitioning.type
)
table.time_partitioning = time_partitioning

Expand All @@ -136,6 +136,11 @@ def load_to_bigquery_via_streaming(self, df_destination_records: pl.DataFrame) -
for row in df_destination_records["source_data"].str.json_decode().to_list()
]
else:
df_destination_records = df_destination_records.with_columns(
pl.col("bizon_extracted_at").dt.strftime("%Y-%m-%d %H:%M:%S").alias("bizon_extracted_at"),
pl.col("bizon_loaded_at").dt.strftime("%Y-%m-%d %H:%M:%S").alias("bizon_loaded_at"),
pl.col("source_timestamp").dt.strftime("%Y-%m-%d %H:%M:%S").alias("source_timestamp"),
)
df_destination_records = df_destination_records.rename(
{
"bizon_id": "_bizon_id",
Expand Down
2 changes: 1 addition & 1 deletion bizon/destinations/bigquery_streaming/src/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def map_bq_type_to_field_descriptor(bq_type: str) -> int:
"DATE": FieldDescriptorProto.TYPE_STRING, # DATE -> TYPE_STRING
"DATETIME": FieldDescriptorProto.TYPE_STRING, # DATETIME -> TYPE_STRING
"TIME": FieldDescriptorProto.TYPE_STRING, # TIME -> TYPE_STRING
"TIMESTAMP": FieldDescriptorProto.TYPE_INT64, # TIMESTAMP -> TYPE_INT64 (Unix epoch time)
"TIMESTAMP": FieldDescriptorProto.TYPE_STRING, # TIMESTAMP -> TYPE_INT64 (Unix epoch time)
"RECORD": FieldDescriptorProto.TYPE_MESSAGE, # RECORD -> TYPE_MESSAGE (nested message)
}

Expand Down
19 changes: 13 additions & 6 deletions bizon/destinations/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator


class DestinationTypes(str, Enum):
Expand Down Expand Up @@ -33,19 +33,26 @@ class AbstractDestinationDetailsConfig(BaseModel):
description="Maximum time in seconds for buffering after which the records will be written to the destination. Set to 0 to deactivate the timeout buffer check.", # noqa
)

unnest: Optional[bool] = Field(
default=False,
description="Unnest the data before writing to the destination. Schema should be provided in the model_config.",
)

record_schema: Optional[list[DestinationColumn]] = Field(
default=None, description="Schema for the records. Required if unnest is set to true."
)

unnest: bool = Field(
default=False,
description="Unnest the data before writing to the destination. Schema should be provided in the model_config.",
)

authentication: Optional[BaseModel] = Field(
description="Authentication configuration for the destination, if needed", default=None
)

@field_validator("unnest", mode="before")
def validate_record_schema_if_unnest(cls, value, values):
if bool(value) and values.data.get("record_schema") is None:
raise ValueError("A `record_schema` must be provided if `unnest` is set to True.")

return value


class AbstractDestinationConfig(BaseModel):
# Forbid extra keys in the model
Expand Down
15 changes: 14 additions & 1 deletion bizon/destinations/file/src/destination.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Tuple

import polars as pl
Expand All @@ -23,5 +24,17 @@ def delete_table(self) -> bool:
return True

def write_records(self, df_destination_records: pl.DataFrame) -> Tuple[bool, str]:
df_destination_records.write_ndjson(self.config.filepath)

if self.config.unnest:

schema_keys = set([column.name for column in self.config.record_schema])

with open(self.config.filepath, "a") as f:
for value in df_destination_records["source_data"].str.json_decode().to_list():
assert set(value.keys()) == schema_keys, "Keys do not match the schema"
f.write(f"{json.dumps(value)}\n")

else:
df_destination_records.write_ndjson(self.config.filepath)

return True, ""
6 changes: 5 additions & 1 deletion bizon/engine/pipeline/consumer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import multiprocessing
import threading
from abc import ABC, abstractmethod
from typing import Union

from bizon.destinations.destination import AbstractDestination
from bizon.engine.pipeline.models import PipelineReturnStatus
from bizon.engine.queue.config import AbstractQueueConfig
from bizon.transform.transform import Transform

Expand All @@ -12,5 +16,5 @@ def __init__(self, config: AbstractQueueConfig, destination: AbstractDestination
self.transform = transform

@abstractmethod
def run(self):
def run(self, stop_event: Union[multiprocessing.Event, threading.Event]) -> PipelineReturnStatus:
pass
6 changes: 5 additions & 1 deletion bizon/engine/pipeline/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from enum import Enum


class PipelineReturnStatus(Enum):
class PipelineReturnStatus(str, Enum):
"""Producer error types"""

SUCCESS = "success"
ERROR = "error"
KILLED_BY_RUNNER = "killed_by_runner"
QUEUE_ERROR = "queue_error"
SOURCE_ERROR = "source_error"
BACKEND_ERROR = "backend_error"
TRANSFORM_ERROR = "transform_error"
DESTINATION_ERROR = "destination_error"
10 changes: 8 additions & 2 deletions bizon/engine/pipeline/producer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import ast
import multiprocessing
import threading
import traceback
from datetime import datetime
from time import sleep
from typing import Tuple
from typing import Tuple, Union

from loguru import logger
from pytz import UTC
Expand Down Expand Up @@ -99,7 +101,7 @@ def is_queue_full(self, cursor: Cursor) -> Tuple[bool, int, int]:

return False, queue_size, approximate_nb_records_in_queue

def run(self, job_id: int):
def run(self, job_id: int, stop_event: Union[multiprocessing.Event, threading.Event]) -> PipelineReturnStatus:

return_value: PipelineReturnStatus = PipelineReturnStatus.SUCCESS

Expand Down Expand Up @@ -128,6 +130,10 @@ def run(self, job_id: int):

while not cursor.is_finished:

if stop_event.is_set():
logger.info("Stop event is set, terminating producer ...")
return PipelineReturnStatus.KILLED_BY_RUNNER

timestamp_start_iteration = datetime.now(tz=UTC)

# Handle the case where last cursor already reach max_iterations
Expand Down
51 changes: 37 additions & 14 deletions bizon/engine/queue/adapters/python_queue/consumer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import multiprocessing
import threading
import traceback
from typing import Union

from loguru import logger

from bizon.destinations.destination import AbstractDestination
from bizon.engine.pipeline.consumer import AbstractQueueConsumer
from bizon.engine.pipeline.models import PipelineReturnStatus
from bizon.engine.queue.queue import QUEUE_TERMINATION, AbstractQueue, QueueMessage
from bizon.transform.transform import Transform

Expand All @@ -15,30 +21,47 @@ def __init__(
super().__init__(config, destination=destination, transform=transform)
self.queue = queue

def run(self) -> None:
def run(self, stop_event: Union[threading.Event, multiprocessing.Event]) -> PipelineReturnStatus:
while True:

if stop_event.is_set():
logger.info("Stop event is set, closing consumer ...")
return PipelineReturnStatus.KILLED_BY_RUNNER
# Retrieve the message from the queue
queue_message: QueueMessage = self.queue.get()

# Apply the transformation
df_source_records = self.transform.apply_transforms(df_source_records=queue_message.df_source_records)
try:
df_source_records = self.transform.apply_transforms(df_source_records=queue_message.df_source_records)
except Exception as e:
logger.error(f"Error applying transformation: {e}")
logger.error(traceback.format_exc())
return PipelineReturnStatus.TRANSFORM_ERROR

if queue_message.signal == QUEUE_TERMINATION:
logger.info("Received termination signal, waiting for destination to close gracefully ...")
try:
if queue_message.signal == QUEUE_TERMINATION:
logger.info("Received termination signal, waiting for destination to close gracefully ...")
self.destination.write_records_and_update_cursor(
df_source_records=df_source_records,
iteration=queue_message.iteration,
extracted_at=queue_message.extracted_at,
pagination=queue_message.pagination,
last_iteration=True,
)
break
except Exception as e:
logger.error(f"Error writing records to destination: {e}")
return PipelineReturnStatus.DESTINATION_ERROR

try:
self.destination.write_records_and_update_cursor(
df_source_records=df_source_records,
iteration=queue_message.iteration,
extracted_at=queue_message.extracted_at,
pagination=queue_message.pagination,
last_iteration=True,
)
break

self.destination.write_records_and_update_cursor(
df_source_records=df_source_records,
iteration=queue_message.iteration,
extracted_at=queue_message.extracted_at,
pagination=queue_message.pagination,
)
except Exception as e:
logger.error(f"Error writing records to destination: {e}")
return PipelineReturnStatus.DESTINATION_ERROR

self.queue.task_done()
return PipelineReturnStatus.SUCCESS
16 changes: 14 additions & 2 deletions bizon/engine/runner/adapters/process.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import concurrent.futures
import time
import traceback

from loguru import logger

Expand Down Expand Up @@ -68,8 +69,19 @@ def run(self):
result_producer = future_producer.result()
logger.info(f"Producer process stopped running with result: {result_producer}")

if result_producer.SUCCESS:
logger.info("Producer thread has finished successfully, will wait for consumer to finish ...")
else:
logger.error("Producer thread failed, stopping consumer ...")
executor.shutdown(wait=False)

if not future_consumer.running():
result_consumer = future_consumer.result()
logger.info(f"Consumer process stopped running with result: {result_consumer}")
try:
future_consumer.result()
except Exception as e:
logger.error(f"Consumer thread stopped running with error {e}")
logger.error(traceback.format_exc())
finally:
executor.shutdown(wait=False)

return True
35 changes: 26 additions & 9 deletions bizon/engine/runner/adapters/thread.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import concurrent.futures
import time
import traceback
from threading import Event

from loguru import logger

from bizon.common.models import BizonConfig
from bizon.engine.pipeline.models import PipelineReturnStatus
from bizon.engine.runner.config import RunnerStatus
from bizon.engine.runner.runner import AbstractRunner


Expand All @@ -25,7 +27,7 @@ def get_kwargs(self):

return extra_kwargs

def run(self) -> bool:
def run(self) -> RunnerStatus:
"""Run the pipeline with dedicated threads for source and destination"""

extra_kwargs = self.get_kwargs()
Expand All @@ -35,6 +37,10 @@ def run(self) -> bool:
result_producer = None
result_consumer = None

# Start the producer and consumer events
producer_stop_event = Event()
consumer_stop_event = Event()

extra_kwargs = self.get_kwargs()

with concurrent.futures.ThreadPoolExecutor(
Expand All @@ -46,6 +52,7 @@ def run(self) -> bool:
self.bizon_config,
self.config,
job.id,
producer_stop_event,
**extra_kwargs,
)
logger.info("Producer thread has started ...")
Expand All @@ -56,6 +63,7 @@ def run(self) -> bool:
AbstractRunner.instanciate_and_run_consumer,
self.bizon_config,
job.id,
consumer_stop_event,
**extra_kwargs,
)
logger.info("Consumer thread has started ...")
Expand All @@ -68,14 +76,23 @@ def run(self) -> bool:
self._is_running = False

if not future_producer.running():
result_producer = future_producer.result()
result_producer: PipelineReturnStatus = future_producer.result()
logger.info(f"Producer thread stopped running with result: {result_producer}")

if result_producer.SUCCESS:
logger.info("Producer thread has finished successfully, will wait for consumer to finish ...")
else:
logger.error("Producer thread failed, stopping consumer ...")
consumer_stop_event.set()

if not future_consumer.running():
try:
future_consumer.result()
except Exception as e:
logger.error(f"Consumer thread stopped running with error {e}")
logger.error(traceback.format_exc())
result_consumer = future_consumer.result()
logger.info(f"Consumer thread stopped running with result: {result_consumer}")

if result_consumer == PipelineReturnStatus.SUCCESS:
logger.info("Consumer thread has finished successfully")
else:
logger.error("Consumer thread failed, stopping producer ...")
producer_stop_event.set()

return True
return RunnerStatus(producer=future_producer.result(), consumer=future_consumer.result())
Loading

0 comments on commit 19127b5

Please sign in to comment.