32 changes: 24 additions & 8 deletions src/xturing/trainers/lightning_trainer.py
@@ -4,7 +4,14 @@
 
 import pytorch_lightning as pl
 import torch
-from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+try:
+    from deepspeed.ops.adam import DeepSpeedCPUAdam
+except ModuleNotFoundError as import_err:  # pragma: no cover - optional dependency
+    DeepSpeedCPUAdam = None
+    _DEEPSPEED_IMPORT_ERROR = import_err
+else:
+    _DEEPSPEED_IMPORT_ERROR = None
 from pytorch_lightning import callbacks
 from pytorch_lightning.loggers import Logger
 from pytorch_lightning.trainer.trainer import Trainer
@@ -51,6 +58,11 @@ def configure_optimizers(self):
                 self.pytorch_model.parameters(), lr=self.learning_rate
             )
         elif self.optimizer_name == "cpu_adam":
+            if DeepSpeedCPUAdam is None:
+                raise ModuleNotFoundError(
+                    "DeepSpeed is required for optimizer 'cpu_adam'. "
+                    "Install it with `pip install deepspeed`."
+                ) from _DEEPSPEED_IMPORT_ERROR
             optimizer = DeepSpeedCPUAdam(
                 self.pytorch_model.parameters(), lr=self.learning_rate
             )
@@ -164,13 +176,17 @@ def __init__(
         ]
 
         strategy = "auto"
-        if not IS_INTERACTIVE:
-            strategy = (
-                "deepspeed_stage_2_offload"
-                if optimizer_name == "cpu_adam"
-                else "deepspeed_stage_2"
-            )
-
+        if use_deepspeed:
+            if DeepSpeedCPUAdam is None:
+                raise ModuleNotFoundError(
+                    "use_deepspeed=True requires DeepSpeed. Install it with `pip install deepspeed`."
+                ) from _DEEPSPEED_IMPORT_ERROR
+            if not IS_INTERACTIVE:
+                strategy = (
+                    "deepspeed_stage_2_offload"
+                    if optimizer_name == "cpu_adam"
+                    else "deepspeed_stage_2"
+                )
         self.trainer = Trainer(
             num_nodes=1,
             accelerator="gpu",
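For readers skimming the diff, here is a minimal, self-contained sketch of the deferred-import pattern this change applies. `heavy_dep`, `FancyOptimizer`, and `make_optimizer` are hypothetical placeholder names for illustration only; the actual change guards `deepspeed.ops.adam.DeepSpeedCPUAdam`.

```python
# Capture the import failure at module load time instead of crashing there.
try:
    from heavy_dep import FancyOptimizer  # hypothetical optional dependency
except ModuleNotFoundError as import_err:
    FancyOptimizer = None
    _HEAVY_DEP_IMPORT_ERROR = import_err
else:
    _HEAVY_DEP_IMPORT_ERROR = None


def make_optimizer(params, lr=1e-3):
    # Fail lazily: raise only if the optional backend is actually requested,
    # chaining the original error so the traceback shows the root cause.
    if FancyOptimizer is None:
        raise ModuleNotFoundError(
            "heavy_dep is required for this optimizer. "
            "Install it with `pip install heavy-dep`."
        ) from _HEAVY_DEP_IMPORT_ERROR
    return FancyOptimizer(params, lr=lr)
```

The net effect mirrors the diff: importing `lightning_trainer` no longer requires DeepSpeed, and the `raise ... from` chaining surfaces the original import failure only when `cpu_adam` or `use_deepspeed=True` is actually used.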