Hello,
I'm a Kubeflow Pipelines maintainer and I'm working on making it easier to launch training jobs from a pipeline component/step as a side project.
The high-level goals of the integration:
- Allow setting resource requests (CPU, GPU, memory), environment variables, node selectors, PVCs, etc. on the component but have it apply to the TrainJob instead (done in feat(sdk/backend): Support forwarding task configuration to external workloads pipelines#12185)
- Leverage KFP's artifact concepts for outputs produced by the training job (share a PVC between the component and the train job)
- Pass component input parameters from the pipeline step to the training job
- Abstract creating the train job with the user's input training function, stream the logs of the train job, and poll until the train job completes
Potential enhancements to the Kubeflow SDK:
- Enhance `CustomTrainer` with more options that are available in the CRD API (e.g. `podSpecOverrides`)
- Enhance `CustomTrainer` with an optional argument of `train_job_patch: dict` that deep-merges into the TrainJob spec before submission. This enables advanced users to tweak any CRD field safely without forking the SDK.
- Allow setting a name for the `TrainJob`
- Enhance `wait_for_job_status` with a `stream_logs=True` keyword argument to stream logs of the main train job pod in the background. See `stream_main_worker_logs` in the code snippet below.
- A high-level single call similar to the `submit_training_job` function in the code snippet below (see the sketch after this list).
- Encode training kwargs as JSON in an env var (e.g. `TRAINING_CONFIG`) and decode it inside the ephemeral script before calling the function. This is more likely to succeed than literally injecting the dictionary.
- Set a default `HOME` environment variable to a writeable location to ensure pip installs work when the container isn't running as root.
- Ensure pip installs in the ephemeral script always use `--user` to work when the container isn't running as root.
- Enhance `CustomTrainer` to accept the new KFP type `dsl.TaskConfig` (feat(sdk/backend): Support forwarding task configuration to external workloads pipelines#12185) so users can leverage the Kubeflow SDK directly in the pipeline.
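To make these proposals concrete, here is a rough sketch of how they could surface in the Kubeflow SDK. The `train_job_patch` and `stream_logs` arguments are the proposed additions and don't exist today, and the surrounding `TrainerClient` / `CustomTrainer` calls are only assumed to roughly match the current SDK shape:
from kubeflow.trainer import CustomTrainer, TrainerClient

client = TrainerClient()

job_name = client.train(
    runtime=client.get_runtime("torch-distributed"),
    trainer=CustomTrainer(
        func=train_model_func,
        num_nodes=2,
        packages_to_install=["transformers", "peft"],
        # Proposed: deep-merged into the generated TrainJob spec before submission.
        train_job_patch={
            "spec": {
                "podSpecOverrides": [
                    {"targetJobs": [{"name": "node"}], "nodeSelector": {"nvidia.com/gpu.present": "true"}},
                ],
            },
        },
    ),
)

# Proposed: stream the main worker's logs in the background while polling for completion.
client.wait_for_job_status(name=job_name, stream_logs=True)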
This is the user experience I got working in a proof of concept:
from kfp import dsl
from typing import Optional
import kfp
import kfp.kubernetes
@dsl.component(
# Required to create a Kubeflow Trainer job
packages_to_install=["kubernetes"],
# This is a new feature that hasn't been released yet.
# Required to pass through the Kubernetes config to the Kubeflow Trainer job
task_config_passthroughs=[
dsl.TaskConfigField.RESOURCES,
dsl.TaskConfigField.KUBERNETES_TOLERATIONS,
dsl.TaskConfigField.KUBERNETES_NODE_SELECTOR,
dsl.TaskConfigField.KUBERNETES_AFFINITY,
# Pass through the environment variables for Hugging Face authentication
dsl.TaskConfigPassthrough(field=dsl.TaskConfigField.ENV, apply_to_task=True),
# Pass through the PVCs for data sharing between the training job and the component
dsl.TaskConfigPassthrough(field=dsl.TaskConfigField.KUBERNETES_VOLUMES, apply_to_task=True),
],
)
def train_model(
# Downloads the dataset from S3
input_dataset: dsl.Input[dsl.Dataset],
# Declares the output model artifact for uploading to S3
output_model: dsl.Output[dsl.Model],
model_name: str,
pvc_path: str,
trainer_runtime: str,
num_nodes: int = 2,
epochs: int = 10,
learning_rate: float = 3e-4,
kubernetes_config: dsl.TaskConfig = None,
):
import json
import os
import shutil
def train_model_func(
learning_rate: float,
model_name: str,
dataset_path: str,
epochs: int,
output_model_path: str,
):
print("Training code goes here leverage the dataset copied to the PVC")
print("Copying dataset to PVC...")
dataset_path = os.path.join(pvc_path, "dataset", "train")
os.makedirs(dataset_path, exist_ok=True)
shutil.copytree(
input_dataset.path,
dataset_path,
dirs_exist_ok=True,
)
print(f"Dataset copied successfully from {input_dataset.path} to {dataset_path}")
print("=== Starting TrainJob creation process ===")
# TODO: This doesn't actually exist in the SDK today
from kfp.training_utils import submit_training_job
model_source = os.path.join(pvc_path, "adapter")
# Creates the `TrainJob`, streams the logs from the main trainer worker pod, and polls until the TrainJob completes
submit_training_job(
train_func=train_model_func,
runtime_ref=trainer_runtime,
num_nodes=num_nodes,
packages_to_install=["transformers", "peft", "accelerate", "trl"],
kubernetes_config=kubernetes_config,
# All these parameters are passed to the training job
learning_rate=learning_rate,
model_name=model_name,
dataset_path=dataset_path,
epochs=epochs,
output_model_path=model_source,
)
print("Processing training results...")
print("Copying trained model from PVC to Kubeflow output path...")
print(f"Model source: {model_source}")
print(f"Destination: {output_model.path}")
output_model.name = f"{model_name}-adapter"
shutil.copytree(model_source, output_model.path, dirs_exist_ok=True)
print(f"Model copied successfully from {model_source} to {output_model.path}")
@dsl.pipeline(
name="Train and evaluate",
description="Provides complete training and evaluation of an LLM model",
)
def train_model_pipeline(
model_name: str = "meta-llama/Llama-3.2-3B-Instruct",
train_epochs: int = 10,
train_learning_rate: float = 3e-4,
train_split_ratio: float = 0.8,
train_num_nodes: int = 2,
train_node_cpu_request: str = "2",
train_node_gpu_request: str = "1",
train_node_memory_request: str = "100Gi",
trainer_runtime: str = "torch-distributed",
):
# Definition of prepare_dataset not shown for brevity
prepare_dataset_op = (
prepare_dataset(
train_split_ratio=train_split_ratio,
)
.set_caching_options(enable_caching=False)
.set_retry(3)
)
train_model_op = (
train_model(
model_name=model_name,
epochs=train_epochs,
input_dataset=prepare_dataset_op.outputs["yoda_train_dataset"],
# Use the workspace, which is a PVC for the lifetime of the pipeline run
pvc_path=dsl.WORKSPACE_PATH_PLACEHOLDER,
learning_rate=train_learning_rate,
num_nodes=train_num_nodes,
trainer_runtime=trainer_runtime,
)
.after(prepare_dataset_op)
.set_caching_options(enable_caching=False)
.set_cpu_request(train_node_cpu_request)
.set_memory_request(train_node_memory_request)
.set_cpu_limit(train_node_cpu_request)
.set_memory_limit(train_node_memory_request)
.set_accelerator_type("nvidia.com/gpu")
.set_accelerator_limit(train_node_gpu_request)
)
kfp.kubernetes.use_secret_as_env(
task=train_model_op,
secret_name="hf-token",
secret_key_to_env={"HF_TOKEN": "HF_TOKEN"},
)
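For completeness, nothing about compiling or submitting this pipeline changes; it runs like any other KFP v2 pipeline (the endpoint below is a placeholder):
import kfp
from kfp import compiler

compiler.Compiler().compile(
    pipeline_func=train_model_pipeline,
    package_path="train_model_pipeline.yaml",
)

# Or submit directly to a KFP instance.
client = kfp.Client(host="http://localhost:8080")
client.create_run_from_pipeline_func(
    train_model_pipeline,
    arguments={"model_name": "meta-llama/Llama-3.2-3B-Instruct"},
)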
Here are the contents of `training_utils.py` with the `submit_training_job` function:
from typing import Callable, List
from kfp import dsl
def _deep_merge_dicts(base: dict, patch: dict) -> dict:
"""Deep-merge patch into base, merging dicts and intelligently merging
lists.
- Dicts: merged recursively.
- Lists of dicts with a "name" key: merged by name with deep-merge of items.
- Other lists: append non-duplicates while preserving order.
"""
def _is_named_dict(value):
return isinstance(value, dict) and "name" in value
def _merge_lists(base_list, patch_list):
patch_has_named = any(_is_named_dict(item) for item in patch_list)
if patch_has_named:
name_to_index = {
item["name"]: idx
for idx, item in enumerate(base_list)
if _is_named_dict(item)
}
for patch_item in patch_list:
if _is_named_dict(patch_item):
item_name = patch_item["name"]
if item_name in name_to_index:
base_item = base_list[name_to_index[item_name]]
_deep_merge_dicts(base_item, patch_item)
else:
base_list.append(patch_item)
else:
if patch_item not in base_list:
base_list.append(patch_item)
return base_list
for patch_item in patch_list:
if patch_item not in base_list:
base_list.append(patch_item)
return base_list
for key, patch_value in patch.items():
base_value = base.get(key)
if isinstance(base_value, dict) and isinstance(patch_value, dict):
_deep_merge_dicts(base_value, patch_value)
elif isinstance(base_value, list) and isinstance(patch_value, list):
_merge_lists(base_value, patch_value)
else:
base[key] = patch_value
return base
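# Illustrative example of the merge semantics above (not part of the module):
#   base  = {"trainer": {"env": [{"name": "HOME", "value": "/tmp"}]}}
#   patch = {"trainer": {"env": [{"name": "HOME", "value": "/root"},
#                                {"name": "EXTRA", "value": "1"}]}}
#   _deep_merge_dicts(base, patch)
#   # HOME is updated in place (matched by "name") and EXTRA is appended:
#   # {"trainer": {"env": [{"name": "HOME", "value": "/root"},
#   #                      {"name": "EXTRA", "value": "1"}]}}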
def stream_main_worker_logs(api_client, namespace: str, run_name: str,
stop_event):
"""Stream logs from the main worker pod (completion index 0) in background.
Retries pod discovery and reconnects to log stream on
errors/closures until stop_event is set. No exceptions are raised;
errors are printed and retried.
"""
import time
from kubernetes import client as k8s_client # local import to avoid hard dep at module load
core_v1 = k8s_client.CoreV1Api(api_client)
label_selector = (f"jobset.sigs.k8s.io/jobset-name={run_name},"
f"batch.kubernetes.io/job-completion-index=0")
last_pod_name = None
print(
"Background log streamer started. Waiting for the main worker pod to appear..."
)
while not stop_event.is_set():
try:
pods = core_v1.list_namespaced_pod(
namespace=namespace,
label_selector=label_selector,
).items
if not pods:
time.sleep(5)
continue
pod = pods[0]
pod_name = pod.metadata.name
if pod_name != last_pod_name:
print(f"Connecting to logs of pod {pod_name}")
last_pod_name = pod_name
# Only check the pod phase. Avoid reading logs while phase is Pending/Unknown
status = getattr(pod, "status", None)
phase = getattr(status, "phase",
None) if status is not None else None
if phase not in ("Running", "Succeeded", "Failed"):
print(
f"Pod {pod_name} phase is {phase or 'unknown'}; waiting before streaming logs..."
)
time.sleep(5)
continue
try:
resp = core_v1.read_namespaced_pod_log(
name=pod_name,
namespace=namespace,
follow=True,
_preload_content=False,
)
for line in resp.stream(decode_content=True):
if stop_event.is_set():
break
try:
text = line.decode("utf-8", errors="replace").rstrip()
except AttributeError:
text = str(line).rstrip()
if text:
print(f"[{pod_name}] {text}")
except Exception as e:
print(f"Error streaming logs from pod {pod_name}: {e}")
time.sleep(5)
except Exception as e:
print(f"Error finding main worker pod: {e}")
time.sleep(5)
print("Background log streamer stopped.")
def monitor_train_job(custom_objects_api, namespace: str,
job_name: str) -> None:
"""Monitor a TrainJob until completion or failure.
Polls status every 10 seconds, printing condition updates. Returns
on completion, raises RuntimeError on failure or cancellation.
"""
import time
from kubernetes.client.rest import ApiException
check_count = 0
while True:
check_count += 1
try:
print(f"Checking job status (attempt {check_count})...")
job_status = custom_objects_api.get_namespaced_custom_object(
group="trainer.kubeflow.org",
version="v1alpha1",
namespace=namespace,
plural="trainjobs",
name=job_name,
)
status = job_status.get("status", {})
conditions = status.get("conditions", [])
print(f"Job status conditions: {conditions}")
completed = False
failed = False
for condition in conditions:
condition_type = condition.get("type", "")
condition_status = condition.get("status", "")
condition_reason = condition.get("reason", "")
condition_message = condition.get("message", "")
print(
f"Condition: type={condition_type}, status={condition_status}, reason={condition_reason}"
)
if condition_type == "Complete" and condition_status == "True":
print(
f"Training job {job_name} completed successfully: {condition_message}"
)
completed = True
break
elif condition_type == "Failed" and condition_status == "True":
print(
f"Training job {job_name} failed: {condition_message}")
failed = True
break
elif condition_type == "Cancelled" and condition_status == "True":
print(
f"Training job {job_name} was cancelled: {condition_message}"
)
failed = True
break
if completed:
return
elif failed:
raise RuntimeError(
f"Training job {job_name} failed or was cancelled")
else:
print(f"Job is still running, continuing to wait...")
except ApiException as e:
print(f"Error checking job status: {e}")
print(f"Error details: {e.body}")
print(f"Waiting 10 seconds before next check...")
time.sleep(10)
def build_train_job(
*,
runtime_ref: str,
run_name: str,
namespace: str,
num_nodes: int,
kubernetes_config: dsl.TaskConfig,
train_func: Callable,
packages_to_install: List[str] = None,
training_kwargs: dict,
train_job_patch: dict = None,
) -> dict:
"""Construct the TrainJob custom resource body.
Returns a dict ready to be submitted to the Kubernetes API.
"""
import inspect
import json
import textwrap
print("Generating command...")
func_code = inspect.getsource(train_func)
func_code = textwrap.dedent(func_code)
func_call_code = f"""
import os
import json
# Parse function arguments from environment variable
config_json = os.environ.get("TRAINING_CONFIG", "{{}}")
func_args = json.loads(config_json)
# Call the training function with parsed arguments
{train_func.__name__}(**func_args)
"""
func_code = f"{func_code}\n{func_call_code}"
packages_str = ""
if packages_to_install:
safe_packages = [str(p) for p in packages_to_install]
packages_str = (
"\n\nif ! [ -x \"$(command -v pip)\" ]; then\n"
"echo \"Installing pip...\"\n"
"python -m ensurepip || python -m ensurepip --user\n"
"fi\n\n"
"echo \"Installing Python packages...\"\n"
f"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --user --quiet --no-warn-script-location {' '.join(safe_packages)}\n"
)
install_script = f"""set -e
set -o pipefail
echo "=== Starting container setup ==="
echo "Python version: $(python --version)"
{packages_str}
echo "Creating training script..."
cat > ephemeral_component.py << 'EOF'
{func_code}
EOF
echo "Starting distributed training..."
torchrun ephemeral_component.py"""
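# The generated install_script roughly expands to (illustrative):
#   set -e
#   set -o pipefail
#   python -m pip install --user --quiet <packages>   # only if packages_to_install is set
#   cat > ephemeral_component.py << 'EOF'
#   <source of train_func, followed by a launcher that json.loads() the
#    TRAINING_CONFIG env var and calls train_func(**func_args)>
#   EOF
#   torchrun ephemeral_component.py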
command = ["bash", "-c", install_script]
print(f"Generated command: {command}")
print(f"Command length: {len(command)}")
print(f"Command type: {type(command)}")
env = kubernetes_config.env
if env is None:
env = []
else:
env = list(env)
env.append({
"name": "TRAINING_CONFIG",
"value": json.dumps(training_kwargs),
})
home_set = False
for env_var in env:
if env_var["name"] == "HOME":
home_set = True
break
if not home_set:
env.append({
"name": "HOME",
"value": "/tmp",
})
train_job = {
"apiVersion": "trainer.kubeflow.org/v1alpha1",
"kind": "TrainJob",
"metadata": {
"name": run_name,
"namespace": namespace
},
"spec": {
"runtimeRef": {
"name": runtime_ref
},
"trainer": {
"numNodes": num_nodes,
"resourcesPerNode": kubernetes_config.resources,
"env": env,
"command": command,
},
"podSpecOverrides": [{
"targetJobs": [{
"name": "node"
}],
"volumes": kubernetes_config.volumes,
"containers": [{
"name": "node",
"volumeMounts": kubernetes_config.volume_mounts,
}],
"nodeSelector": kubernetes_config.node_selector,
"tolerations": kubernetes_config.tolerations,
}],
},
}
if train_job_patch:
print("Merging user-provided train_job_patch into TrainJob spec...")
train_job = _deep_merge_dicts(train_job, train_job_patch)
return train_job
def submit_training_job(
train_func: Callable,
runtime_ref: str,
num_nodes: int = 1,
packages_to_install: List[str] = None,
run_name: str = None,
# TODO: Replace this with a more generic type once it lives in the Kubeflow SDK.
kubernetes_config: dsl.TaskConfig = None,
train_job_patch: dict = None,
**kwargs,
):
import logging
import uuid
from kubernetes import client as k8s_client
from kubernetes.client.rest import ApiException
import kubernetes.config as config
if kubernetes_config is None:
kubernetes_config = dsl.TaskConfig()
if run_name is None:
run_name = f"kfp-{uuid.uuid4()}"
if train_job_patch is None:
train_job_patch = {}
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace",
"r") as ns_file:
namespace = ns_file.readline().strip()
# Command construction is handled by build_train_job
print("Loading Kubernetes configuration...")
try:
config.load_incluster_config()
print("Loaded in-cluster Kubernetes configuration")
except config.ConfigException:
config.load_kube_config()
print("Loaded kubeconfig Kubernetes configuration")
print("Creating Kubernetes API client...")
api_client = k8s_client.ApiClient()
custom_objects_api = k8s_client.CustomObjectsApi(api_client)
print("Successfully created Kubernetes API client")
print("Defining TrainJob resource...")
train_job = build_train_job(
runtime_ref=runtime_ref,
run_name=run_name,
namespace=namespace,
num_nodes=num_nodes,
kubernetes_config=kubernetes_config,
train_func=train_func,
packages_to_install=packages_to_install,
training_kwargs=kwargs,
train_job_patch=train_job_patch,
)
if kubernetes_config.affinity:
logging.warning("Affinity is not supported for training jobs")
print(f"TrainJob definition created:")
print(f" - Name: {run_name}")
print(f" - Namespace: {namespace}")
print("Submitting TrainJob to Kubernetes...")
try:
response = custom_objects_api.create_namespaced_custom_object(
group="trainer.kubeflow.org",
version="v1alpha1",
namespace=namespace,
plural="trainjobs",
body=train_job,
)
job_name = response["metadata"]["name"]
print(f"TrainJob {job_name} created successfully")
print(f"Response metadata: {response.get('metadata', {})}")
except ApiException as e:
print(f"Error creating TrainJob: {e}")
print(f"Error details: {e.body}")
print(f"Error status: {e.status}")
raise
# Start streaming logs from the main worker pod in the background
import threading
stop_log_stream_event = threading.Event()
log_thread = threading.Thread(
target=stream_main_worker_logs,
args=(api_client, namespace, run_name, stop_log_stream_event),
daemon=True,
)
print("Starting log streaming from main worker pod...")
log_thread.start()
print(f"Starting to monitor TrainJob {job_name} status...")
try:
monitor_train_job(custom_objects_api, namespace, job_name)
finally:
stop_log_stream_event.set()
try:
log_thread.join(timeout=10)
except Exception:
pass