Skip to content

Commit

Permalink
update script
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Oct 15, 2024
1 parent c8f7ce5 commit f56a16c
Showing 1 changed file with 16 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import onnxruntime
from onnxruntime.quantization import CalibrationDataReader, create_calibrator, write_calibration_table

# onnxruntime.set_default_logger_severity(0)

class ImageNetDataReader(CalibrationDataReader):
def __init__(self,
Expand Down Expand Up @@ -126,12 +127,13 @@ def preprocess_imagenet(self, images_folder, height, width, start_index=0, size_
return: list of matrices characterizing multiple images
'''
def preprocess_images(input, channels=3, height=224, width=224):
image = input.resize((width, height), Image.ANTIALIAS)
image = input.resize((width, height), Image.Resampling.LANCZOS) # Image.ANTIALIAS was removed in Pillow 10.0.0
input_data = np.asarray(image).astype(np.float32)
if len(input_data.shape) != 2:
input_data = input_data.transpose([2, 0, 1])
else:
input_data = np.stack([input_data] * 3)
# image normalization
mean = np.array([0.079, 0.05, 0]) + 0.406
std = np.array([0.005, 0, 0.001]) + 0.224
for channel in range(input_data.shape[0]):
Expand All @@ -153,7 +155,8 @@ def preprocess_images(input, channels=3, height=224, width=224):

for image_name in batch_filenames:
image_filepath = images_folder + '/' + image_name
img = Image.open(image_filepath)
# Note: There is one image ILSVRC2012_val_00019877.JPEG which has 4 channels, so here we convert it to RGB with 3 channels for all images
img = Image.open(image_filepath).convert("RGB")
image_data = preprocess_images(img)
image_data = np.expand_dims(image_data, 0)
unconcatenated_batch_data.append(image_data)
Expand All @@ -163,7 +166,7 @@ def preprocess_images(input, channels=3, height=224, width=224):
return batch_data, batch_filenames, image_size_list

def get_synset_id(self, image_folder, offset, dataset_size):
ilsvrc2012_meta = scipy.io.loadmat(image_folder + "/devkit/data/meta.mat")
ilsvrc2012_meta = scipy.io.loadmat(image_folder + "/ILSVRC2012_devkit_t12/data/meta.mat")
id_to_synset = {}
for i in range(1000):
id = int(ilsvrc2012_meta["synsets"][i, 0][0][0][0])
Expand All @@ -178,7 +181,7 @@ def get_synset_id(self, image_folder, offset, dataset_size):
index = index + 1
file.close()

file = open(image_folder + "/devkit/data/ILSVRC2012_validation_ground_truth.txt", "r")
file = open(image_folder + "/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt", "r")
id = file.read().strip().split("\n")
id = list(map(int, id))
file.close()
Expand Down Expand Up @@ -318,11 +321,14 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
calibration_table_generation_enable = True # Enable/Disable INT8 calibration

# TensorRT EP INT8 settings
os.environ["ORT_TENSORRT_FP16_ENABLE"] = "1" # Enable FP16 precision
os.environ["ORT_TENSORRT_INT8_ENABLE"] = "1" # Enable INT8 precision
os.environ["ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
os.environ["ORT_TENSORRT_ENGINE_CACHE_ENABLE"] = "1" # Enable engine caching
execution_provider = ["TensorrtExecutionProvider"]
execution_provider = [
('TensorrtExecutionProvider', {
'trt_int8_enable': True,
'trt_fp16_enable': True,
'trt_engine_cache_enable': True,
'trt_int8_calibration_table_name': 'calibration.flatbuffers', # The implicit quantization is deprecated in TRT 10
})
]

# Convert static batch to dynamic batch
[new_model_path, input_name] = convert_model_batch_to_dynamic(model_path)
Expand All @@ -343,7 +349,7 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
model_path=augmented_model_path,
input_name=input_name)
calibrator.collect_data(data_reader)
write_calibration_table(calibrator.compute_range())
write_calibration_table(calibrator.compute_data())

# Run prediction in Tensorrt EP
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
Expand Down

0 comments on commit f56a16c

Please sign in to comment.