1
1
import math
2
2
from abc import ABCMeta , abstractmethod
3
- from pathlib import Path
4
- from typing import Union
3
+ from typing import Type , Annotated
5
4
6
5
import numpy as np
7
- import rasterio
8
6
import torch
9
7
from rasterio .windows import Window
10
8
@@ -103,24 +101,21 @@ def _init_wi(size: int, device: torch.device.type) -> torch.Tensor:
103
101
class TorchMemoryRegister (object ):
104
102
def __init__ (
105
103
self ,
106
- image_path : Union [str , Path ],
107
- reg_depth : int ,
108
- window_size : int ,
104
+ image_width : Annotated [int , "Width of the image in pixels" ],
105
+ register_depth : Annotated [int , "Generally equal to the number of classes" ],
106
+ window_size : Annotated [int , "Moving window size" ],
107
+ kernel : Type [Kernel ],
109
108
device : torch .device .type ,
110
109
):
111
110
super ().__init__ ()
112
- self .image_path = Path (image_path )
113
- self .n = reg_depth
111
+ self .n = register_depth
114
112
self .ws = window_size
115
113
self .hws = window_size // 2
114
+ self .kernel = kernel (size = window_size , device = device )
116
115
self .device = device
117
116
118
- # Copy metadata from img
119
- with rasterio .open (str (image_path ), "r" ) as src :
120
- src_width = src .width
121
-
122
117
self .height = self .ws
123
- self .width = (math .ceil (src_width / self .ws ) * self .ws ) + self .hws
118
+ self .width = (math .ceil (image_width / self .ws ) * self .ws ) + self .hws
124
119
self .register = torch .zeros (
125
120
(self .n , self .height , self .width ), device = self .device
126
121
)
@@ -131,40 +126,110 @@ def _zero_chip(self):
131
126
(self .n , self .hws , self .hws ), dtype = torch .float , device = self .device
132
127
)
133
128
134
- def step (self , new_logits : torch .Tensor , img_window : Window ):
135
- # 1. Read data from the registry to update with the new logits
129
+ def step (
130
+ self ,
131
+ new_logits : torch .Tensor ,
132
+ img_window : Window ,
133
+ * ,
134
+ top : bool ,
135
+ bottom : bool ,
136
+ left : bool ,
137
+ right : bool ,
138
+ ):
139
+ # Read data from the registry to update with the new logits
136
140
# |a|b| |
137
141
# |c|d| |
138
142
with torch .no_grad ():
139
143
logits_abcd = self .register [
140
144
:, :, img_window .col_off : img_window .col_off + self .ws
141
145
].clone ()
142
- logits_abcd += new_logits
143
-
144
- # Update the registry and pop information-complete data
145
- # |c|b| | + pop a
146
- # |0|d| |
147
- logits_a = logits_abcd [:, : self .hws , : self .hws ]
148
- logits_c = logits_abcd [:, self .hws :, : self .hws ]
149
- logits_c0 = torch .concat ([logits_c , self ._zero_chip ], dim = 1 )
150
- logits_bd = logits_abcd [:, :, self .hws :]
151
-
152
- # write c0
153
- self .register [:, :, img_window .col_off : img_window .col_off + self .hws ] = (
154
- logits_c0
155
- )
156
-
157
- # write bd
158
- col_off_bd = img_window .col_off + self .hws
159
- self .register [:, :, col_off_bd : col_off_bd + (self .ws - self .hws )] = logits_bd
160
-
161
- # Return the information-complete predictions
162
- logits_win = Window (
163
- col_off = img_window .col_off ,
164
- row_off = img_window .row_off ,
165
- height = min (self .hws , img_window .height ),
166
- width = min (self .hws , img_window .width ),
167
- )
168
- logits = logits_a [:, : img_window .height , : img_window .width ]
146
+ logits_abcd += self .kernel (
147
+ new_logits , top = top , bottom = bottom , left = left , right = right
148
+ )
149
+
150
+ if right and bottom :
151
+ # Need to return entire window
152
+ logits_win = img_window
153
+ logits = logits_abcd [:, : img_window .height , : img_window .width ]
154
+
155
+ elif right :
156
+ # Need to return a and b sections
157
+
158
+ # Update the registry and pop information-complete data
159
+ # |c|d| | + pop a+b
160
+ # |0|0| |
161
+ logits_ab = logits_abcd [:, : self .hws , :]
162
+ logits_cd = logits_abcd [:, self .hws :, :]
163
+ logits_00 = torch .concat ([self ._zero_chip , self ._zero_chip ], dim = 2 )
164
+
165
+ # write cd and 00
166
+ self .register [
167
+ :, : self .hws , img_window .col_off : img_window .col_off + self .ws
168
+ ] = logits_cd
169
+ self .register [
170
+ :, self .hws :, img_window .col_off : img_window .col_off + self .ws
171
+ ] = logits_00
172
+
173
+ logits_win = Window (
174
+ col_off = img_window .col_off ,
175
+ row_off = img_window .row_off ,
176
+ height = min (self .hws , img_window .height ),
177
+ width = min (self .ws , img_window .width ),
178
+ )
179
+ logits = logits_ab [:, : logits_win .height , : logits_win .width ]
180
+ elif bottom :
181
+ # Need to return a and c sections only
182
+
183
+ # Update the registry and pop information-complete data
184
+ # |0|b| | + pop a+c
185
+ # |0|d| |
186
+ logits_ac = logits_abcd [:, :, : self .hws ]
187
+ logits_00 = torch .concat ([self ._zero_chip , self ._zero_chip ], dim = 1 )
188
+ logits_bd = logits_abcd [:, :, self .hws :]
189
+
190
+ # write 00 and bd
191
+ self .register [:, :, img_window .col_off : img_window .col_off + self .hws ] = (
192
+ logits_00 # Not really necessary since this is the last row
193
+ )
194
+ self .register [
195
+ :, :, img_window .col_off + self .hws : img_window .col_off + self .ws
196
+ ] = logits_bd
197
+
198
+ logits_win = Window (
199
+ col_off = img_window .col_off ,
200
+ row_off = img_window .row_off ,
201
+ height = min (self .ws , img_window .height ),
202
+ width = min (self .hws , img_window .width ),
203
+ )
204
+ logits = logits_ac [:, : img_window .height , : img_window .width ]
205
+ else :
206
+ # Need to return "a" section only
207
+
208
+ # Update the registry and pop information-complete data
209
+ # |c|b| | + pop a
210
+ # |0|d| |
211
+ logits_a = logits_abcd [:, : self .hws , : self .hws ]
212
+ logits_c = logits_abcd [:, self .hws :, : self .hws ]
213
+ logits_c0 = torch .concat ([logits_c , self ._zero_chip ], dim = 1 )
214
+ logits_bd = logits_abcd [:, :, self .hws :]
215
+
216
+ # write c0
217
+ self .register [:, :, img_window .col_off : img_window .col_off + self .hws ] = (
218
+ logits_c0
219
+ )
220
+
221
+ # write bd
222
+ col_off_bd = img_window .col_off + self .hws
223
+ self .register [:, :, col_off_bd : col_off_bd + (self .ws - self .hws )] = (
224
+ logits_bd
225
+ )
226
+
227
+ logits_win = Window (
228
+ col_off = img_window .col_off ,
229
+ row_off = img_window .row_off ,
230
+ height = min (self .hws , img_window .height ),
231
+ width = min (self .hws , img_window .width ),
232
+ )
233
+ logits = logits_a [:, : img_window .height , : img_window .width ]
169
234
170
235
return logits , logits_win
0 commit comments