@@ -52,20 +52,35 @@ class CompressedPairwiseConv:
     # the 0 index is special: pconv[0] === 0.
     pconv: np.ndarray
     in_memory: bool = False
+    device: torch.device = torch.device("cpu")
 
     def __post_init__(self):
         assert self.shifts_a.ndim == self.shifts_b.ndim == 1
         assert self.shifts_a.shape == (self.shifted_template_index_a.shape[1],)
-        assert self.shifts_b.shape == (self.upsampled_shifted_template_index_b.shape[1],)
+        assert self.shifts_b.shape == (
+            self.upsampled_shifted_template_index_b.shape[1],
+        )
+
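+        # dense lookup tables mapping an integer shift value to its column index
+        # in shifted_template_index_a / upsampled_shifted_template_index_b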
+        self.a_shift_offset, self.offset_shift_a_to_ix = _get_shift_indexer(
+            self.shifts_a
+        )
+        self.b_shift_offset, self.offset_shift_b_to_ix = _get_shift_indexer(
+            self.shifts_b
+        )
+
+    def get_shift_ix_a(self, shifts_a):
+        return self.offset_shift_a_to_ix[shifts_a.to(int) + self.a_shift_offset]
+
+    def get_shift_ix_b(self, shifts_b):
+        return self.offset_shift_b_to_ix[shifts_b.to(int) + self.b_shift_offset]
 
     @classmethod
     def from_h5(cls, hdf5_filename, in_memory=True):
-        ff = [f for f in fields(cls) if not f.name == "in_memory"]
+        ff = [f for f in fields(cls) if f.name not in ("in_memory", "device")]
         if in_memory:
             with h5py.File(hdf5_filename, "r") as h5:
                 data = {f.name: torch.from_numpy(h5[f.name][:]) for f in ff}
             return cls(**data, in_memory=in_memory)
-
         _h5 = h5py.File(hdf5_filename, "r")
         data = {}
         for f in ff:
@@ -117,7 +132,7 @@ def from_template_data(
         )
         return cls.from_h5(hdf5_filename)
 
-    def at_shifts(self, shifts_a=None, shifts_b=None):
+    def at_shifts(self, shifts_a=None, shifts_b=None, device=None):
         """Subset this database to one set of shifts.
 
         The database becomes shiftless (not in the pejorative sense).
@@ -133,8 +148,8 @@ def at_shifts(self, shifts_a=None, shifts_b=None):
         n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape
 
         # active shifted and upsampled indices
-        shift_ix_a = torch.searchsorted(self.shifts_a, shifts_a)
-        shift_ix_b = torch.searchsorted(self.shifts_b, shifts_b)
+        shift_ix_a = self.get_shift_ix_a(shifts_a)
+        shift_ix_b = self.get_shift_ix_b(shifts_b)
         sub_shifted_temp_index_a = self.shifted_template_index_a[
             torch.arange(len(self.shifted_template_index_a))[:, None],
             shift_ix_a[:, None],
@@ -166,6 +181,8 @@ def at_shifts(self, shifts_a=None, shifts_b=None):
             sub_pconv = self.pconv[sub_pconv_indices.to(self.pconv.device)]
         else:
             sub_pconv = torch.from_numpy(batched_h5_read(self.pconv, sub_pconv_indices))
+        if device is not None:
+            sub_pconv = sub_pconv.to(device)
 
         # reindexing
         n_sub_shifted_temps_a = len(shifted_temp_ixs_a)
@@ -184,17 +201,30 @@ def at_shifts(self, shifts_a=None, shifts_b=None):
             pconv_index=sub_pconv_index,
             pconv=sub_pconv,
             in_memory=True,
+            device=self.device,
         )
 
-    def to(self, device=None, incl_pconv=False):
+    def to(self, device=None, incl_pconv=False, pin=False):
         """Become torch tensors on device."""
-        for f in fields(self):
-            if f.name == "pconv":
+        print(f"to {device=}")
+        for name in ["offset_shift_a_to_ix", "offset_shift_b_to_ix"] + [
+            f.name for f in fields(self)
+        ]:
+            if name == "pconv" and not incl_pconv:
                 continue
-            v = getattr(self, f.name)
+            v = getattr(self, name)
             if isinstance(v, np.ndarray) or torch.is_tensor(v):
-                setattr(self, name, torch.as_tensor(v, device=device))
+                setattr(self, name, torch.as_tensor(v, device=device))
         self.device = device
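+        # cudaHostRegister page-locks (pins) the existing pconv buffer in place,
+        # unlike .pin_memory(), which allocates and copies into a new pinned tensor;
+        # pinned host memory allows faster, asynchronous host-to-device copies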
+        if pin and self.device.type == "cuda" and torch.cuda.is_available() and not self.pconv.is_pinned():
+            # self.pconv.share_memory_()
+            print("pin")
+            torch.cuda.cudart().cudaHostRegister(
+                self.pconv.data_ptr(), self.pconv.numel() * self.pconv.element_size(), 0
+            )
+            # assert x.is_shared()
+            assert self.pconv.is_pinned()
+            # self.pconv = self.pconv.pin_memory()
         return self
 
     def query(
@@ -211,9 +241,9 @@ def query(
         device=None,
     ):
         if template_indices_a is None:
-            template_indices_a = torch.arange(
-                len(self.shifted_template_index_a), device=self.device
-            )
+            template_indices_a = torch.arange(
+                len(self.shifted_template_index_a), device=self.device
+            )
         template_indices_a = torch.atleast_1d(template_indices_a)
         template_indices_b = torch.atleast_1d(template_indices_b)
 
@@ -230,8 +260,8 @@ def query(
             shifted_template_index = shifted_template_index[:, 0]
             upsampled_shifted_template_index = upsampled_shifted_template_index[:, 0]
         else:
-            shift_indices_a = torch.searchsorted(self.shifts_a, shifts_a)
-            shift_indices_b = torch.searchsorted(self.shifts_b, shifts_b)
+            shift_indices_a = self.get_shift_ix_a(shifts_a)
+            shift_indices_b = self.get_shift_ix_b(shifts_b)
         a_ix = (template_indices_a, shift_indices_a)
         b_ix = (template_indices_b, shift_indices_b)
 
@@ -250,6 +280,9 @@ def query(
         up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix]
 
         # return convolutions between all ai,bj or just ai,bi?
+        print(f"{shifted_temp_ix_a.device=} {up_shifted_temp_ix_b.device=}")
+        print(f"{self.device=} {self.shifts_a.device=}")
+        print(f"{template_indices_a.device=} {template_indices_b.device=}")
         if grid:
             pconv_indices = self.pconv_index[
                 shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :]
@@ -258,9 +291,13 @@ def query(
                 template_indices_a, template_indices_b
             ).T
             if scalings_b is not None:
-                scalings_b = torch.broadcast_to(scalings_b[None], pconv_indices.shape).reshape(-1)
+                scalings_b = torch.broadcast_to(
+                    scalings_b[None], pconv_indices.shape
+                ).reshape(-1)
             if times_b is not None:
-                times_b = torch.broadcast_to(times_b[None], pconv_indices.shape).reshape(-1)
+                times_b = torch.broadcast_to(
+                    times_b[None], pconv_indices.shape
+                ).reshape(-1)
             pconv_indices = pconv_indices.view(-1)
         else:
             pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b]
@@ -279,7 +316,9 @@ def query(
         if self.in_memory:
             pconvs = self.pconv[pconv_indices.to(self.pconv.device)]
         else:
-            pconvs = torch.from_numpy(batched_h5_read(self.pconv, pconv_indices.numpy(force=True)))
+            pconvs = torch.from_numpy(
+                batched_h5_read(self.pconv, pconv_indices.numpy(force=True))
+            )
         if device is not None:
             pconvs = pconvs.to(device)
 
@@ -291,6 +330,7 @@ def query(
 
         return template_indices_a, template_indices_b, pconvs
 
+
 def batched_h5_read(dataset, indices, batch_size=1000):
     if indices.size < batch_size:
         return dataset[indices]
@@ -299,4 +339,19 @@ def batched_h5_read(dataset, indices, batch_size=1000):
     for bs in range(0, indices.size, batch_size):
         be = min(indices.size, bs + batch_size)
         out[bs:be] = dataset[indices[bs:be]]
-    return out
+    return out
+
+
+def _get_shift_indexer(shifts):
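+    # builds a dense lookup table: offset_shift_to_ix[shift + shift_offset] is the
+    # position of that integer shift in `shifts`; gaps between consecutive shift
+    # values are filled with the out-of-range sentinel len(shifts)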
+    assert torch.equal(shifts, torch.sort(shifts).values)
+    shift_offset = -int(shifts[0])
+    offset_shift_to_ix = []
+    for j, shift in enumerate(shifts):
+        ix = shift + shift_offset
+        assert len(offset_shift_to_ix) <= ix
+        assert 0 <= ix < len(shifts)
+        while len(offset_shift_to_ix) < ix:
+            offset_shift_to_ix.append(len(shifts))
+        offset_shift_to_ix.append(j)
+    offset_shift_to_ix = torch.tensor(offset_shift_to_ix, device=shifts.device)
+    return shift_offset, offset_shift_to_ix