How to set a threshold at inference time? #578
-
I would like to set a threshold at test/inference time so that I can tune the sensitivity of the model. How would it be possible to do that?
Replies: 8 comments 1 reply
-
You can set the threshold values on the model, e.g. `model.image_threshold.value` and `model.pixel_threshold.value`.
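As a minimal sketch of what that assignment looks like (the classes below are hypothetical stand-ins; in anomalib the attributes live on the model returned by `get_model(config)`, and the values are torch tensors rather than plain floats):

```python
# Hypothetical stand-in mirroring the threshold attributes discussed in this
# thread; only the attribute names come from the thread itself.
class Threshold:
    def __init__(self, value: float):
        self.value = value

class Model:
    def __init__(self):
        self.adaptive_threshold = True        # thresholds recomputed adaptively
        self.image_threshold = Threshold(0.0)
        self.pixel_threshold = Threshold(0.0)

model = Model()
model.adaptive_threshold = False              # keep the manual values
model.image_threshold.value = 0.5
model.pixel_threshold.value = 0.5
print(model.image_threshold.value, model.pixel_threshold.value)
```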
-
@alexriedel1 this is not enough: even if I set the threshold values manually, they do not seem to take effect.
-
Can you show me your testing / inference script? I'm pretty sure setting the thresholds this way should work.
-
@alexriedel1 This is the full implementation of my script for inference with PatchCore:

```python
# Import paths as of anomalib 0.3.x; adjust to your installed version.
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from anomalib.data.inference import InferenceDataset
from anomalib.models import get_model
from anomalib.utils.callbacks import get_callbacks

# `config`, `CUSTOM_NUMBER` and `my_args` are defined elsewhere in the script.
model = get_model(config)
callbacks = get_callbacks(config)
trainer = Trainer(callbacks=callbacks, **config.trainer)

# Set custom threshold
model.adaptive_threshold = False
model.pixel_threshold.value = torch.tensor(float(CUSTOM_NUMBER))
model.image_threshold.value = torch.tensor(float(CUSTOM_NUMBER))

transform_config = config.dataset.transform_config.val if "transform_config" in config.dataset.keys() else None
dataset = InferenceDataset(
    my_args['input'], image_size=tuple(config.dataset.image_size), transform_config=transform_config
)
dataloader = DataLoader(dataset)
trainer.predict(model=model, dataloaders=[dataloader])
```

At the end of this script my custom thresholds are overwritten.
-
I cannot reproduce the issue with a PatchCore model.
-
@alexriedel1 Printing the custom thresholds with `print(model.pixel_threshold.value, model.image_threshold.value)` shows that they are overwritten after `trainer.predict`.

This is a pip freeze of the env:
-
@alevangel The thresholds are overwritten when `on_predict_start` is called from `LoadModelCallback`. One idea would be to manually call `torch.load` and pop `LoadModelCallback` from the callbacks list; then updating the thresholds will work. This solution might work for your use case for now, but we will try to come up with a better design.

Here is a snippet of the change to the `infer` function of `lightning_inference.py`:

```python
def infer():
    """Run inference."""
    args = get_args()
    config = get_configurable_parameters(config_path=args.config)
    # This is commented out, as setting it adds LoadModelCallback to the callback list
    # config.trainer.resume_from_checkpoint = str(args.weights)
    config.visualization.show_images = args.show
    config.visualization.mode = args.visualization_mode
    if args.output:  # overwrite save path
        config.visualization.save_images = True
        config.visualization.image_save_path = args.output
    else:
        config.visualization.save_images = False

    model = get_model(config)
    model.load_state_dict(torch.load(args.weights)["state_dict"])  # manually load weights
    callbacks = get_callbacks(config)
    trainer = Trainer(callbacks=callbacks, **config.trainer)

    transform_config = config.dataset.transform_config.val if "transform_config" in config.dataset.keys() else None
    dataset = InferenceDataset(
        args.input, image_size=tuple(config.dataset.image_size), transform_config=transform_config
    )
    dataloader = DataLoader(dataset)

    model.adaptive_threshold = False
    model.pixel_threshold.value = torch.tensor(float(1))
    model.image_threshold.value = torch.tensor(float(1))
    print(model.pixel_threshold.value, model.image_threshold.value)
    trainer.predict(model=model, dataloaders=[dataloader])
    print(model.pixel_threshold.value, model.image_threshold.value)
```
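The "pop `LoadModelCallback` from the callbacks list" idea can be sketched as a plain list filter. The stand-in classes below are hypothetical; only the callback name comes from this thread:

```python
class LoadModelCallback:
    """Stand-in for anomalib's callback of the same name."""

class OtherCallback:
    """Stand-in for any other configured callback."""

callbacks = [OtherCallback(), LoadModelCallback()]

# Drop LoadModelCallback so trainer.predict() cannot reload the checkpoint
# and overwrite the manually set thresholds.
callbacks = [cb for cb in callbacks if type(cb).__name__ != "LoadModelCallback"]
print([type(cb).__name__ for cb in callbacks])  # -> ['OtherCallback']
```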
-
@ashwinvaidya17 Thanks, this works to avoid the threshold overwriting:

```python
model_state_dict = torch.load(my_args['weights'], map_location=device)["state_dict"]
# These keys otherwise prevent the state dict from loading
model_state_dict.pop("normalization_metrics.min", None)
model_state_dict.pop("normalization_metrics.max", None)
model.load_state_dict(model_state_dict)
```

But this way the heatmap is not normalized at all, even though the predicted mask is good.
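One way to get a presentable heatmap after dropping the stored normalization statistics is to min-max normalize each anomaly map yourself. This is a rough per-image sketch, not anomalib's own normalization scheme (which uses the popped `normalization_metrics.min`/`max` values):

```python
import numpy as np

def min_max_normalize(anomaly_map: np.ndarray) -> np.ndarray:
    """Rescale an anomaly map to [0, 1] using its own min/max.

    Per-image statistics are an assumption here; anomalib normally
    normalizes with statistics stored in the checkpoint.
    """
    amin, amax = anomaly_map.min(), anomaly_map.max()
    if amax == amin:  # constant map: nothing to rescale
        return np.zeros_like(anomaly_map)
    return (anomaly_map - amin) / (amax - amin)

heatmap = min_max_normalize(np.array([[0.2, 0.8], [1.4, 2.0]]))
print(heatmap)
```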
Beta Was this translation helpful? Give feedback.