-
Notifications
You must be signed in to change notification settings - Fork 6
/
calibrator.py
62 lines (48 loc) · 1.97 KB
/
calibrator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import numpy as np
from cuda import cudart
import tensorrt as trt
# Calibration samples are loaded once, at import time, from a .npy file that
# stores a pickled dict; the array under the 'in_tensor' key is what the
# calibrator slices into batches.
# NOTE(review): allow_pickle=True unpickles arbitrary objects -- only point
# ioFile at trusted data. The path is relative to the current working directory.
ioFile = "benchmark/data/calibration.npy"
ioData = np.load(ioFile, allow_pickle=True).item()
in_tensor = ioData["in_tensor"]
class MobileVitCalibrator(trt.IInt8EntropyCalibrator2):
    """INT8 entropy calibrator that feeds batches of ``in_tensor`` to TensorRT.

    Batches are sliced from the module-level ``in_tensor`` array along its
    first axis and copied into a single device buffer that is allocated once
    in ``__init__`` and reused for every batch.
    """

    def __init__(self, calibrationCount=25, inputShape=(4, 3, 256, 256),
                 cacheFile='./target/mobilevit.cacheFile'):
        """Allocate the device buffer and reset the batch counter.

        Args:
            calibrationCount: maximum number of batches ``get_batch`` serves.
            inputShape: shape of one batch; ``inputShape[0]`` is the batch size.
            cacheFile: path used to read/write the calibration cache.

        Raises:
            RuntimeError: if the device allocation fails.
        """
        # The base class is initialised explicitly, as in the original code.
        trt.IInt8EntropyCalibrator2.__init__(self)
        # Private copy so the instance never aliases the caller's sequence.
        self.shape = list(inputShape)
        self.cacheFile = cacheFile
        self.calibrationCount = calibrationCount
        # Bytes in one float32 batch. (Attribute name kept as-is for
        # backward compatibility, typo included.)
        self.buffeSize = trt.volume(inputShape) * trt.float32.itemsize
        self.count = 0
        # Set before allocating so __del__ is safe even if cudaMalloc fails.
        self.dIn = []
        status, device_ptr = cudart.cudaMalloc(self.buffeSize)
        if status != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(
                "cudaMalloc of %d bytes failed: %s" % (self.buffeSize, status))
        self.dIn.append(device_ptr)

    def __del__(self):
        """Free the device buffer(s) owned by this calibrator."""
        # getattr guards against __init__ having failed before dIn was set;
        # popping makes a repeated call a no-op instead of a double free.
        buffers = getattr(self, "dIn", None)
        while buffers:
            cudart.cudaFree(buffers.pop())

    def get_batch_size(self):  # do NOT change name
        """Return the batch size, i.e. the first entry of the input shape."""
        return self.shape[0]

    def get_batch(self, nameList=None):  # do NOT change name
        """Upload the next calibration batch and return the device pointers.

        Args:
            nameList: accepted but not used by this implementation.

        Returns:
            ``self.dIn`` (list holding the device pointer) while batches
            remain, or ``None`` once ``calibrationCount`` batches were served
            or ``in_tensor`` has no full batch left.

        Raises:
            ValueError: if a full batch does not match the configured size.
            RuntimeError: if the host-to-device copy fails.
        """
        if self.count >= self.calibrationCount:
            return None

        batch_size = self.shape[0]
        start_idx = self.count * batch_size
        end_idx = start_idx + batch_size
        # cudaMemcpy reads exactly buffeSize raw bytes from the host pointer,
        # so the host data must be C-contiguous float32 of the full size.
        host_batch = np.ascontiguousarray(
            in_tensor[start_idx:end_idx, ...], dtype=np.float32)

        if host_batch.shape[0] < batch_size:
            # Data exhausted: a partial batch would make cudaMemcpy read past
            # the end of the host array, so stop calibration instead.
            return None
        if host_batch.nbytes != self.buffeSize:
            raise ValueError(
                "calibration batch has shape %s (%d bytes) but %d bytes are "
                "expected for input shape %s"
                % (host_batch.shape, host_batch.nbytes, self.buffeSize,
                   self.shape))

        status = cudart.cudaMemcpy(
            self.dIn[0], host_batch.ctypes.data, self.buffeSize,
            cudart.cudaMemcpyKind.cudaMemcpyHostToDevice)[0]
        if status != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("cudaMemcpy to device failed: %s" % (status,))

        self.count += 1
        return self.dIn

    def read_calibration_cache(self):  # do NOT change name
        """Return the bytes of the cache file, or ``None`` if it is missing."""
        if not os.path.exists(self.cacheFile):
            print("Failed finding int8 cache!")
            return None
        print("Succeed finding cahce file: %s" % (self.cacheFile))
        with open(self.cacheFile, "rb") as f:
            return f.read()

    def write_calibration_cache(self, cache):  # do NOT change name
        """Write ``cache`` to the cache file, creating its directory if needed."""
        cache_dir = os.path.dirname(self.cacheFile)
        if cache_dir:
            # Without this, a missing directory (e.g. the default ./target)
            # would raise and the calibration result would be lost.
            os.makedirs(cache_dir, exist_ok=True)
        with open(self.cacheFile, "wb") as f:
            f.write(cache)
        print("Succeed saving int8 cache!")
if __name__ == "__main__":
    # Manual smoke test: synchronize the device, build a calibrator with its
    # default arguments and request a single batch (the argument is unused
    # by get_batch).
    cudart.cudaDeviceSynchronize()
    calibrator = MobileVitCalibrator()
    calibrator.get_batch("ttt")