Skip to content

Commit

Permalink
Merge pull request #32 from stefanklut/oom-to-cpu
Browse files Browse the repository at this point in the history
Inference in API on CPU when encountering OOM
  • Loading branch information
stefanklut authored Mar 28, 2024
2 parents bdea810 + 36460ca commit ed77518
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
29 changes: 27 additions & 2 deletions api/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from flask import Flask, Response, abort, jsonify, request
from prometheus_client import Counter, Gauge, generate_latest

sys.path.append(str(Path(__file__).resolve().parent.joinpath("..")))
sys.path.append(str(Path(__file__).resolve().parent.joinpath(".."))) # noqa: E402
from datasets.mapper import AugInput
from main import setup_cfg, setup_logging
from page_xml.output_pageXML import OutputPageXML
Expand Down Expand Up @@ -123,6 +123,31 @@ def setup_model(self, model_name: str, args: DummyArgs):
exception_predict_counter = Counter("exception_predict", "Exception thrown in predict() function")


def safe_predict(data, device):
"""
Attempt to predict on the speci
Args:
data: Data to predict on
device: Device to predict on
Returns:
Prediction output
"""

try:
return predict_gen_page_wrapper.predictor(data, device)
except Exception as exception:
# Catch CUDA out of memory errors
if isinstance(exception, torch.cuda.OutOfMemoryError) or (
isinstance(exception, RuntimeError) and "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(exception)
):
logger.warning("CUDA OOM encountered, falling back to CPU.")
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
return predict_gen_page_wrapper.predictor(data, "cpu")


def predict_image(
image: np.ndarray,
dpi: Optional[int],
Expand Down Expand Up @@ -172,7 +197,7 @@ def predict_image(
manual_dpi=predict_gen_page_wrapper.predictor.cfg.INPUT.DPI.MANUAL_DPI_TEST,
)

outputs = predict_gen_page_wrapper.predictor(data)
outputs = safe_predict(data, device=predict_gen_page_wrapper.predictor.cfg.MODEL.DEVICE)

output_image = outputs[0]["sem_seg"]
# output_image = torch.argmax(outputs[0]["sem_seg"], dim=-3).cpu().numpy()
Expand Down
27 changes: 21 additions & 6 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,49 +133,64 @@ def __init__(self, cfg: CfgNode):

# return predictions, height, width

def cpu_call(self, data: AugInput) -> tuple[dict, int, int]:
def cpu_call(self, data: AugInput, device: str = None) -> tuple[dict, int, int]:
"""
Run the model on the image with preprocessing on the cpu
Args:
data (AugInput): image to run the model on
device (str): device to run the model on
Returns:
tuple[dict, int, int]: predictions, height, width
"""
logger = logging.getLogger(get_logger_name())

# Default value of device should be the one in the config
if device is None:
device = self.cfg.MODEL.DEVICE

with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
# Apply pre-processing to image.

height, width, channels = data.image.shape
assert channels == 3, f"Must be a RBG image, found {channels} channels"
# In place augmentation
transform = self.aug(data)
image = torch.as_tensor(data.image, dtype=torch.float32, device=self.cfg.MODEL.DEVICE).permute(2, 0, 1)
image = torch.as_tensor(data.image, dtype=torch.float32, device=device).permute(2, 0, 1)

if self.cfg.INPUT.FORMAT == "BGR":
# whether the model expects BGR inputs or RGB
image = image[[2, 1, 0], :, :]

inputs = {"image": image, "height": image.shape[1], "width": image.shape[2]}

# If we predict on CPU, use full precision
precision = self.precision if device != "cpu" else torch.float32

with torch.autocast(
device_type=self.cfg.MODEL.DEVICE,
device_type=device,
enabled=self.cfg.MODEL.AMP_TEST.ENABLED,
dtype=self.precision,
dtype=precision,
):
if next(self.model.parameters()).device != device:
logger.info(f"Moving model to {device} device")
self.model.to(device)

predictions = self.model([inputs])[0]

# if torch.isnan(predictions["sem_seg"]).any():
# raise ValueError("NaN in predictions")

return predictions, height, width

def __call__(self, data: AugInput):
def __call__(self, data: AugInput, device: str = None) -> tuple[dict, int, int]:
"""
Run the model on the image with preprocessing
Args:
data (AugInput): image to run the model on
device (str): device to run the model on
Returns:
tuple[dict, int, int]: predictions, height, width
"""
Expand All @@ -186,7 +201,7 @@ def __call__(self, data: AugInput):
# return self.gpu_call(original_image)
# else:
# raise TypeError(f"Unknown image type: {type(original_image)}")
return self.cpu_call(data)
return self.cpu_call(data, device)


class LoadingDataset(Dataset):
Expand Down

0 comments on commit ed77518

Please sign in to comment.