diff --git a/airflow/dags/DAG_PHT_run_train.py b/airflow/dags/DAG_PHT_run_train.py index 31e4d3e..333c066 100644 --- a/airflow/dags/DAG_PHT_run_train.py +++ b/airflow/dags/DAG_PHT_run_train.py @@ -273,6 +273,10 @@ def execute_container(train_state): raise ValueError( f"Invalid gpu configuration: {gpu_config}. Must be a list of integers or 'all'" ) + elif isinstance(gpu_config, int): + device_request = docker.types.DeviceRequest( + device_ids=[str(gpu_config)], capabilities=[["gpu"]] + ) else: raise ValueError( f"Invalid gpu configuration: {gpu_config}. Must be a list of integers or 'all'"