Skip to content

Commit cf67c9b

Browse files
committed
Fix bug with windowed processing
1 parent 931225a commit cf67c9b

File tree

5 files changed

+324
-106
lines changed

5 files changed

+324
-106
lines changed

kelp_o_matic/geotiff_io/geotiff_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def __init__(
4444
self.profile = src.profile
4545
self.block_shapes = src.block_shapes
4646

47-
self._y0s = list(range(0, self.height, self.stride))
48-
self._x0s = list(range(0, self.width, self.stride))
47+
self._y0s = list(range(0, self.height - self.stride, self.stride))
48+
self._x0s = list(range(0, self.width - self.stride, self.stride))
4949
self.y0x0 = list(itertools.product(self._y0s, self._x0s))
5050

5151
def __len__(self) -> int:

kelp_o_matic/hann.py

Lines changed: 107 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import math
22
from abc import ABCMeta, abstractmethod
3-
from pathlib import Path
4-
from typing import Union
3+
from typing import Type, Annotated
54

65
import numpy as np
7-
import rasterio
86
import torch
97
from rasterio.windows import Window
108

@@ -103,24 +101,21 @@ def _init_wi(size: int, device: torch.device.type) -> torch.Tensor:
103101
class TorchMemoryRegister(object):
104102
def __init__(
105103
self,
106-
image_path: Union[str, Path],
107-
reg_depth: int,
108-
window_size: int,
104+
image_width: Annotated[int, "Width of the image in pixels"],
105+
register_depth: Annotated[int, "Generally equal to the number of classes"],
106+
window_size: Annotated[int, "Moving window size"],
107+
kernel: Type[Kernel],
109108
device: torch.device.type,
110109
):
111110
super().__init__()
112-
self.image_path = Path(image_path)
113-
self.n = reg_depth
111+
self.n = register_depth
114112
self.ws = window_size
115113
self.hws = window_size // 2
114+
self.kernel = kernel(size=window_size, device=device)
116115
self.device = device
117116

118-
# Copy metadata from img
119-
with rasterio.open(str(image_path), "r") as src:
120-
src_width = src.width
121-
122117
self.height = self.ws
123-
self.width = (math.ceil(src_width / self.ws) * self.ws) + self.hws
118+
self.width = (math.ceil(image_width / self.ws) * self.ws) + self.hws
124119
self.register = torch.zeros(
125120
(self.n, self.height, self.width), device=self.device
126121
)
@@ -131,40 +126,110 @@ def _zero_chip(self):
131126
(self.n, self.hws, self.hws), dtype=torch.float, device=self.device
132127
)
133128

134-
def step(self, new_logits: torch.Tensor, img_window: Window):
135-
# 1. Read data from the registry to update with the new logits
129+
def step(
130+
self,
131+
new_logits: torch.Tensor,
132+
img_window: Window,
133+
*,
134+
top: bool,
135+
bottom: bool,
136+
left: bool,
137+
right: bool,
138+
):
139+
# Read data from the registry to update with the new logits
136140
# |a|b| |
137141
# |c|d| |
138142
with torch.no_grad():
139143
logits_abcd = self.register[
140144
:, :, img_window.col_off : img_window.col_off + self.ws
141145
].clone()
142-
logits_abcd += new_logits
143-
144-
# Update the registry and pop information-complete data
145-
# |c|b| | + pop a
146-
# |0|d| |
147-
logits_a = logits_abcd[:, : self.hws, : self.hws]
148-
logits_c = logits_abcd[:, self.hws :, : self.hws]
149-
logits_c0 = torch.concat([logits_c, self._zero_chip], dim=1)
150-
logits_bd = logits_abcd[:, :, self.hws :]
151-
152-
# write c0
153-
self.register[:, :, img_window.col_off : img_window.col_off + self.hws] = (
154-
logits_c0
155-
)
156-
157-
# write bd
158-
col_off_bd = img_window.col_off + self.hws
159-
self.register[:, :, col_off_bd : col_off_bd + (self.ws - self.hws)] = logits_bd
160-
161-
# Return the information-complete predictions
162-
logits_win = Window(
163-
col_off=img_window.col_off,
164-
row_off=img_window.row_off,
165-
height=min(self.hws, img_window.height),
166-
width=min(self.hws, img_window.width),
167-
)
168-
logits = logits_a[:, : img_window.height, : img_window.width]
146+
logits_abcd += self.kernel(
147+
new_logits, top=top, bottom=bottom, left=left, right=right
148+
)
149+
150+
if right and bottom:
151+
# Need to return entire window
152+
logits_win = img_window
153+
logits = logits_abcd[:, : img_window.height, : img_window.width]
154+
155+
elif right:
156+
# Need to return a and b sections
157+
158+
# Update the registry and pop information-complete data
159+
# |c|d| | + pop a+b
160+
# |0|0| |
161+
logits_ab = logits_abcd[:, : self.hws, :]
162+
logits_cd = logits_abcd[:, self.hws :, :]
163+
logits_00 = torch.concat([self._zero_chip, self._zero_chip], dim=2)
164+
165+
# write cd and 00
166+
self.register[
167+
:, : self.hws, img_window.col_off : img_window.col_off + self.ws
168+
] = logits_cd
169+
self.register[
170+
:, self.hws :, img_window.col_off : img_window.col_off + self.ws
171+
] = logits_00
172+
173+
logits_win = Window(
174+
col_off=img_window.col_off,
175+
row_off=img_window.row_off,
176+
height=min(self.hws, img_window.height),
177+
width=min(self.ws, img_window.width),
178+
)
179+
logits = logits_ab[:, : logits_win.height, : logits_win.width]
180+
elif bottom:
181+
# Need to return a and c sections only
182+
183+
# Update the registry and pop information-complete data
184+
# |0|b| | + pop a+c
185+
# |0|d| |
186+
logits_ac = logits_abcd[:, :, : self.hws]
187+
logits_00 = torch.concat([self._zero_chip, self._zero_chip], dim=1)
188+
logits_bd = logits_abcd[:, :, self.hws :]
189+
190+
# write 00 and bd
191+
self.register[:, :, img_window.col_off : img_window.col_off + self.hws] = (
192+
logits_00 # Not really necessary since this is the last row
193+
)
194+
self.register[
195+
:, :, img_window.col_off + self.hws : img_window.col_off + self.ws
196+
] = logits_bd
197+
198+
logits_win = Window(
199+
col_off=img_window.col_off,
200+
row_off=img_window.row_off,
201+
height=min(self.ws, img_window.height),
202+
width=min(self.hws, img_window.width),
203+
)
204+
logits = logits_ac[:, : img_window.height, : img_window.width]
205+
else:
206+
# Need to return "a" section only
207+
208+
# Update the registry and pop information-complete data
209+
# |c|b| | + pop a
210+
# |0|d| |
211+
logits_a = logits_abcd[:, : self.hws, : self.hws]
212+
logits_c = logits_abcd[:, self.hws :, : self.hws]
213+
logits_c0 = torch.concat([logits_c, self._zero_chip], dim=1)
214+
logits_bd = logits_abcd[:, :, self.hws :]
215+
216+
# write c0
217+
self.register[:, :, img_window.col_off : img_window.col_off + self.hws] = (
218+
logits_c0
219+
)
220+
221+
# write bd
222+
col_off_bd = img_window.col_off + self.hws
223+
self.register[:, :, col_off_bd : col_off_bd + (self.ws - self.hws)] = (
224+
logits_bd
225+
)
226+
227+
logits_win = Window(
228+
col_off=img_window.col_off,
229+
row_off=img_window.row_off,
230+
height=min(self.hws, img_window.height),
231+
width=min(self.hws, img_window.width),
232+
)
233+
logits = logits_a[:, : img_window.height, : img_window.width]
169234

170235
return logits, logits_win

kelp_o_matic/managers.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,12 @@ def __init__(
5757
dtype="uint8",
5858
nodata=0,
5959
)
60-
self.kernel = BartlettHannKernel(crop_size, self.model.device)
6160
self.register = TorchMemoryRegister(
62-
self.input_path, self.model.register_depth, crop_size, self.model.device
61+
image_width=self.reader.width,
62+
register_depth=self.model.register_depth,
63+
window_size=crop_size,
64+
kernel=BartlettHannKernel,
65+
device=self.model.device,
6366
)
6467

6568
def __call__(self):
@@ -107,14 +110,14 @@ def __call__(self):
107110
else:
108111
logits = self.model(crop.unsqueeze(0))[0]
109112

110-
logits = self.kernel(
113+
write_logits, write_window = self.register.step(
111114
logits,
115+
read_window,
112116
top=self.reader.is_top_window(read_window),
113117
bottom=self.reader.is_bottom_window(read_window),
114118
left=self.reader.is_left_window(read_window),
115119
right=self.reader.is_right_window(read_window),
116120
)
117-
write_logits, write_window = self.register.step(logits, read_window)
118121
labels = self.model.post_process(write_logits)
119122

120123
# Write outputs

poetry.lock

Lines changed: 15 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)