def safe_predict(data, device):
    """
    Run the predictor on ``device``, retrying once on the CPU after a CUDA out-of-memory error.

    Args:
        data: input forwarded unchanged to ``predict_gen_page_wrapper.predictor``.
        device: device identifier passed to the predictor for the first attempt.

    Returns:
        Whatever ``predict_gen_page_wrapper.predictor`` returns, either from the first
        attempt or from the CPU retry.

    Raises:
        Exception: an exception from the first attempt that is not recognised as a CUDA
            out-of-memory condition is re-raised unchanged. Errors raised by the CPU
            retry propagate as well.
    """
    try:
        return predict_gen_page_wrapper.predictor(data, device)
    except Exception as exception:
        # Two signatures are treated as "CUDA ran out of memory": the dedicated
        # OutOfMemoryError, and a RuntimeError carrying the NVML internal-assert text.
        is_cuda_oom = isinstance(exception, torch.cuda.OutOfMemoryError) or (
            isinstance(exception, RuntimeError) and "NVML_SUCCESS == r INTERNAL ASSERT FAILED" in str(exception)
        )
        if not is_cuda_oom:
            # Anything else must surface to the caller. Falling through here would
            # implicitly return None and hide the real failure.
            raise

        logger.warning("CUDA OOM encountered, falling back to CPU.")
        # Release cached GPU memory and reset the peak statistics before retrying on the CPU.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        return predict_gen_page_wrapper.predictor(data, "cpu")
int, int]: + def cpu_call(self, data: AugInput, device: str = None) -> tuple[dict, int, int]: """ Run the model on the image with preprocessing on the cpu @@ -143,13 +143,19 @@ def cpu_call(self, data: AugInput) -> tuple[dict, int, int]: Returns: tuple[dict, int, int]: predictions, height, width """ + logger = logging.getLogger(get_logger_name()) + + # Default value of device should be the one in the config + if device is None: + device = self.cfg.MODEL.DEVICE + with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 # Apply pre-processing to image. height, width, channels = data.image.shape assert channels == 3, f"Must be a RBG image, found {channels} channels" transform = self.aug(data) - image = torch.as_tensor(data.image, dtype=torch.float32, device=self.cfg.MODEL.DEVICE).permute(2, 0, 1) + image = torch.as_tensor(data.image, dtype=torch.float32, device=device).permute(2, 0, 1) if self.cfg.INPUT.FORMAT == "BGR": # whether the model expects BGR inputs or RGB @@ -157,11 +163,18 @@ def cpu_call(self, data: AugInput) -> tuple[dict, int, int]: inputs = {"image": image, "height": image.shape[1], "width": image.shape[2]} + # If we predict on CPU, use full precision + precision = self.precision if device != "cpu" else torch.float32 + with torch.autocast( - device_type=self.cfg.MODEL.DEVICE, + device_type=device, enabled=self.cfg.MODEL.AMP_TEST.ENABLED, - dtype=self.precision, + dtype=precision, ): + if next(self.model.parameters()).device != device: + logger.info(f"Moving model to {device} device") + self.model.to(device) + predictions = self.model([inputs])[0] # if torch.isnan(predictions["sem_seg"]).any(): @@ -169,7 +182,7 @@ def cpu_call(self, data: AugInput) -> tuple[dict, int, int]: return predictions, height, width - def __call__(self, data: AugInput): + def __call__(self, data: AugInput, device: str = None): """ Run the model on the image with preprocessing @@ -185,7 +198,7 @@ def __call__(self, data: AugInput): # return 
self.gpu_call(original_image) # else: # raise TypeError(f"Unknown image type: {type(original_image)}") - return self.cpu_call(data) + return self.cpu_call(data, device) class LoadingDataset(Dataset):