Skip to content

Commit

Permalink
Merge pull request #23 from stefanklut/half-precision-inference
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanklut authored Feb 1, 2024
2 parents 37bb3ce + 493ecd1 commit 51f206a
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 12 deletions.
7 changes: 7 additions & 0 deletions configs/extra_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
_C.MODEL.SEM_SEG_HEAD = CN()
_C.MODEL.SEM_SEG_HEAD.WEIGHT = [1.0]

_C.MODEL.AMP_TRAIN = CN()
_C.MODEL.AMP_TRAIN.ENABLED = False
_C.MODEL.AMP_TRAIN.PRECISION = "bfloat16"
_C.MODEL.AMP_TEST = CN()
_C.MODEL.AMP_TEST.ENABLED = True
_C.MODEL.AMP_TEST.PRECISION = "bfloat16"

# Weights
_C.TRAIN = CN()
_C.TRAIN.WEIGHTS = ""
Expand Down
12 changes: 11 additions & 1 deletion core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,17 @@ def __init__(self, cfg: CfgNode):
data_loader = self.build_train_loader(cfg)

model = create_ddp_model(model, broadcast_buffers=False)
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(model, data_loader, optimizer)
self._trainer = (AMPTrainer if cfg.MODEL.AMP_TRAIN.ENABLED else SimpleTrainer)(model, data_loader, optimizer)
if isinstance(self._trainer, AMPTrainer):
precision_converter = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
precision = precision_converter.get(cfg.AMP_TRAIN.PRECISION, None)
if precision is None:
raise ValueError(f"Unrecognized precision: {cfg.AMP_TRAIN.PRECISION}")
self._trainer.precision = precision

self.scheduler = self.build_lr_scheduler(cfg, optimizer)

Expand Down
74 changes: 63 additions & 11 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@ def get_arguments() -> argparse.Namespace:
detectron2_args.add_argument("--opts", nargs="+", help="optional args to change", action="extend", default=[])

io_args = parser.add_argument_group("IO")
io_args.add_argument("-i", "--input", nargs="+", help="Input folder", type=str, action="extend", required=True)
io_args.add_argument(
"-i",
"--input",
nargs="+",
help="Input folder",
type=str,
action="extend",
required=True,
)
io_args.add_argument("-o", "--output", help="Output folder", type=str, required=True)

parser.add_argument("-w", "--whitelist", nargs="+", help="Input folder", type=str, action="extend")
Expand All @@ -58,13 +66,23 @@ def __init__(self, cfg):
cfg (CfgNode): config
"""
self.cfg = cfg.clone() # cfg can be modified by model

self.model = build_model(self.cfg)
self.model.eval()

if len(cfg.DATASETS.TEST):
self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

self.input_format = cfg.INPUT.FORMAT
assert self.input_format in ["RGB", "BGR"], self.input_format
precision_converter = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
self.precision = precision_converter.get(cfg.MODEL.AMP_TEST.PRECISION, None)
if self.precision is None:
raise ValueError(f"Unrecognized precision: {cfg.MODEL.AMP_TEST.PRECISION}")

assert self.cfg.INPUT.FORMAT in ["RGB", "BGR"], self.cfg.INPUT.FORMAT

checkpointer = DetectionCheckpointer(self.model)
if not cfg.TEST.WEIGHTS:
Expand All @@ -73,7 +91,8 @@ def __init__(self, cfg):
checkpointer.load(cfg.TEST.WEIGHTS)

if cfg.INPUT.RESIZE_MODE == "none":
self.aug = ResizeScaling(scale=1) # HACK percentage of 1 is no scaling
# HACK percentage of 1 is no scaling
self.aug = ResizeScaling(scale=1)
elif cfg.INPUT.RESIZE_MODE in ["shortest_edge", "longest_edge"]:
if cfg.INPUT.RESIZE_MODE == "shortest_edge":
self.aug = ResizeShortestEdge(cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, "choice")
Expand Down Expand Up @@ -102,7 +121,10 @@ def get_image_size(self, height: int, width: int) -> tuple[int, int]:
new_height, new_width = height, width
elif self.cfg.INPUT.RESIZE_MODE in ["shortest_edge", "longest_edge"]:
new_height, new_width = self.aug.get_output_shape(
height, width, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST
height,
width,
self.cfg.INPUT.MIN_SIZE_TEST,
self.cfg.INPUT.MAX_SIZE_TEST,
)
elif self.cfg.INPUT.RESIZE_MODE == "scaling":
new_height, new_width = self.aug.get_output_shape(
Expand All @@ -128,7 +150,8 @@ def gpu_call(self, original_image: torch.Tensor):
channels, height, width = original_image.shape
assert channels == 3, f"Must be a BGR image, found {channels} channels"
image = torch.as_tensor(original_image, dtype=torch.float32, device=self.cfg.MODEL.DEVICE)
if self.input_format == "BGR":

if self.cfg.INPUT.FORMAT == "BGR":
# whether the model expects BGR inputs or RGB
image = image[[2, 1, 0], :, :]

Expand All @@ -138,7 +161,17 @@ def gpu_call(self, original_image: torch.Tensor):
image = torch.nn.functional.interpolate(image[None], mode="bilinear", size=(new_height, new_width))[0]

inputs = {"image": image, "height": new_height, "width": new_width}
predictions = self.model([inputs])[0]

with torch.autocast(
device_type=self.cfg.MODEL.DEVICE,
enabled=self.cfg.MODEL.AMP_TEST.ENABLED,
dtype=self.precision,
):
predictions = self.model([inputs])[0]

# if torch.isnan(predictions["sem_seg"]).any():
# raise ValueError("NaN in predictions")

return predictions, height, width

def cpu_call(self, original_image: np.ndarray):
Expand All @@ -157,12 +190,22 @@ def cpu_call(self, original_image: np.ndarray):
assert channels == 3, f"Must be a RBG image, found {channels} channels"
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image, dtype=torch.float32, device=self.cfg.MODEL.DEVICE).permute(2, 0, 1)
if self.input_format == "BGR":

if self.cfg.INPUT.FORMAT == "BGR":
# whether the model expects BGR inputs or RGB
image = image[[2, 1, 0], :, :]

inputs = {"image": image, "height": image.shape[1], "width": image.shape[2]}
predictions = self.model([inputs])[0]

with torch.autocast(
device_type=self.cfg.MODEL.DEVICE,
enabled=self.cfg.MODEL.AMP_TEST.ENABLED,
dtype=self.precision,
):
predictions = self.model([inputs])[0]

# if torch.isnan(predictions["sem_seg"]).any():
# raise ValueError("NaN in predictions")

return predictions, height, width

Expand Down Expand Up @@ -205,7 +248,11 @@ def __getitem__(self, index):
def collate_numpy(batch):
collate_map = default_collate_fn_map

def new_map(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
def new_map(
batch,
*,
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
return batch

collate_map.update({np.ndarray: new_map, type(None): new_map})
Expand Down Expand Up @@ -324,7 +371,12 @@ def process(self):

dataset = LoadingDataset(self.input_paths)
dataloader = DataLoader(
dataset, shuffle=False, batch_size=None, num_workers=16, pin_memory=False, collate_fn=collate_numpy
dataset,
shuffle=False,
batch_size=None,
num_workers=16,
pin_memory=False,
collate_fn=collate_numpy,
)
for inputs in tqdm(dataloader, desc="Predicting PageXML"):
self.save_prediction(inputs[0], inputs[1])
Expand Down

0 comments on commit 51f206a

Please sign in to comment.