Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference in API on CPU when encountering OOM #32

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions api/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from flask import Flask, Response, abort, jsonify, request
from prometheus_client import Counter, Gauge, generate_latest

sys.path.append(str(Path(__file__).resolve().parent.joinpath("..")))
sys.path.append(str(Path(__file__).resolve().parent.joinpath(".."))) # noqa: E402
from datasets.mapper import AugInput
from main import setup_cfg, setup_logging
from page_xml.output_pageXML import OutputPageXML
Expand Down Expand Up @@ -123,6 +123,31 @@ def setup_model(self, model_name: str, args: DummyArgs):
exception_predict_counter = Counter("exception_predict", "Exception thrown in predict() function")


def safe_predict(data, device):
    """
    Attempt to predict on the specified device, retrying on the CPU when CUDA runs out of memory.

    Args:
        data: Data to predict on
        device: Device to predict on

    Returns:
        Prediction output

    Raises:
        Exception: Any error raised by the predictor that is not recognised as a CUDA
            out-of-memory condition is propagated unchanged. Errors raised by the CPU
            retry are propagated as well.
    """

    try:
        return predict_gen_page_wrapper.predictor(data, device)
    except (torch.cuda.OutOfMemoryError, RuntimeError) as exception:
        # CUDA OOM shows up either as the dedicated exception type or as a RuntimeError
        # carrying the NVML assertion message.
        is_cuda_oom = isinstance(exception, torch.cuda.OutOfMemoryError) or (
            "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(exception)
        )
        if not is_cuda_oom:
            # Do not swallow unrelated failures: previously they fell through and the
            # function silently returned None.
            raise

    # The retry happens outside the except block so the caught exception (and the
    # traceback it holds) is no longer referenced while the cache is being released.
    logger.warning("CUDA OOM encountered, falling back to CPU.")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    return predict_gen_page_wrapper.predictor(data, "cpu")


def predict_image(
image: np.ndarray,
dpi: Optional[int],
Expand Down Expand Up @@ -172,7 +197,7 @@ def predict_image(
manual_dpi=predict_gen_page_wrapper.predictor.cfg.INPUT.DPI.MANUAL_DPI_TEST,
)

outputs = predict_gen_page_wrapper.predictor(data)
outputs = safe_predict(data, device=predict_gen_page_wrapper.predictor.cfg.MODEL.DEVICE)

output_image = outputs[0]["sem_seg"]
# output_image = torch.argmax(outputs[0]["sem_seg"], dim=-3).cpu().numpy()
Expand Down
27 changes: 21 additions & 6 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,48 +133,63 @@ def __init__(self, cfg: CfgNode):

# return predictions, height, width

def cpu_call(self, data: AugInput, device: str = None) -> tuple[dict, int, int]:
    """
    Run the model on the image with preprocessing on the cpu

    Args:
        data (AugInput): image to run the model on
        device (str): device to run the model on. When None, the device from the
            config (cfg.MODEL.DEVICE) is used.

    Returns:
        tuple[dict, int, int]: predictions, height, width (height and width are those
            of the image before the augmentations are applied)
    """
    logger = logging.getLogger(get_logger_name())

    # Default value of device should be the one in the config
    if device is None:
        device = self.cfg.MODEL.DEVICE

    # Normalise to a torch.device so that "cuda", "cuda:0" and torch.device objects are
    # all handled the same way below.
    target_device = torch.device(device)

    with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
        # Apply pre-processing to image.

        height, width, channels = data.image.shape
        assert channels == 3, f"Must be a RBG image, found {channels} channels"
        # NOTE(review): presumably the augmentation updates data.image in place; the
        # returned transform was never used, so it is not kept.
        self.aug(data)
        image = torch.as_tensor(data.image, dtype=torch.float32, device=target_device).permute(2, 0, 1)

        if self.cfg.INPUT.FORMAT == "BGR":
            # whether the model expects BGR inputs or RGB
            image = image[[2, 1, 0], :, :]

        inputs = {"image": image, "height": image.shape[1], "width": image.shape[2]}

        # If we predict on CPU, use full precision
        precision = self.precision if target_device.type != "cpu" else torch.float32

        # Compare device type (and index, when one was requested) instead of comparing
        # a torch.device against a string, which reported a mismatch on every call.
        current_device = next(self.model.parameters()).device
        needs_move = current_device.type != target_device.type or (
            target_device.index is not None and current_device.index != target_device.index
        )
        if needs_move:
            logger.info(f"Moving model to {device} device")
            self.model.to(target_device)

        # autocast expects the bare device type ("cuda"/"cpu"), not an indexed string.
        with torch.autocast(
            device_type=target_device.type,
            enabled=self.cfg.MODEL.AMP_TEST.ENABLED,
            dtype=precision,
        ):
            predictions = self.model([inputs])[0]

        # if torch.isnan(predictions["sem_seg"]).any():
        #     raise ValueError("NaN in predictions")

    return predictions, height, width

def __call__(self, data: AugInput):
def __call__(self, data: AugInput, device: str = None) -> tuple[dict, int, int]:
"""
Run the model on the image with preprocessing

Args:
data (AugInput): image to run the model on
device (str): device to run the model on
Returns:
tuple[dict, int, int]: predictions, height, width
"""
Expand All @@ -185,7 +200,7 @@ def __call__(self, data: AugInput):
# return self.gpu_call(original_image)
# else:
# raise TypeError(f"Unknown image type: {type(original_image)}")
return self.cpu_call(data)
return self.cpu_call(data, device)


class LoadingDataset(Dataset):
Expand Down
Loading