Skip to content

Commit 74d74d8

Browse files
committed
Experimental ONNX runner
1 parent ff7b84e commit 74d74d8

File tree

4 files changed

+130
-142
lines changed

4 files changed

+130
-142
lines changed

kelp_o_matic/hann.py

Lines changed: 60 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,162 +1,151 @@
11
import math
2-
from abc import ABCMeta, abstractmethod
2+
from abc import ABC, abstractmethod
33
from pathlib import Path
44
from typing import Union
55

66
import numpy as np
77
import rasterio
8-
import torch
98
from rasterio.windows import Window
109

1110

1211
# Implementation of paper:
1312
# https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0229839#pone.0229839.ref007
1413

1514

class Kernel(ABC):
    """Abstract separable 2D window built from a 1D window.

    Subclasses provide the 1D window through ``_init_wi``. The same window
    is stored for rows (``wi``) and columns (``wj``); the 2D kernel is their
    outer product, so it has shape ``(size, size)``.
    """

    def __init__(self, size: int = 512):
        """Build the row and column windows for a ``size`` x ``size`` kernel."""
        super().__init__()
        self.size = size
        self.wi = self._init_wi(size)
        self.wj = self.wi.copy()

    @staticmethod
    @abstractmethod
    def _init_wi(size: int) -> np.ndarray:
        """Return the 1D window of length ``size``."""
        raise NotImplementedError

    def get_kernel(
        self,
        top: bool = False,
        bottom: bool = False,
        left: bool = False,
        right: bool = False,
    ) -> np.ndarray:
        """Return the 2D kernel, optionally flattened to 1 towards some edges.

        Each flag overwrites one half of a 1D window with 1 before the outer
        product is taken: ``top``/``bottom`` act on the first/second half of
        the row window, ``left``/``right`` on the first/second half of the
        column window.
        """
        # Work on copies so the stored windows are never mutated.
        wi, wj = self.wi.copy(), self.wj.copy()
        half = self.size // 2

        if top:
            wi[:half] = 1
        if bottom:
            wi[half:] = 1

        if left:
            wj[:half] = 1
        if right:
            wj[half:] = 1

        # Outer product of the two 1D windows: kernel[r, c] = wi[r] * wj[c].
        return np.outer(wi, wj)

    def forward(
        self,
        x: np.ndarray,
        top: bool = False,
        bottom: bool = False,
        left: bool = False,
        right: bool = False,
    ) -> np.ndarray:
        """Multiply ``x`` element-wise by the kernel selected by the flags.

        The kernel is broadcast against ``x`` by ``np.multiply``.
        """
        kernel = self.get_kernel(top=top, bottom=bottom, left=left, right=right)
        return np.multiply(x, kernel)

    def __call__(self, *args, **kwargs) -> np.ndarray:
        """Delegate to ``forward`` so a kernel instance can be called directly.

        Keeps ``kernel(x, ...)`` working as it did when this class was a
        ``torch.nn.Module``.
        """
        return self.forward(*args, **kwargs)
6158

6259

class HannKernel(Kernel):
    """Hann window: ``w[i] = (1 - cos(2 * pi * i / (size - 1))) / 2``."""

    @staticmethod
    def _init_wi(size: int) -> np.ndarray:
        """Return the 1D Hann window of length ``size``."""
        if size == 1:
            # The general formula divides by ``size - 1`` and would yield
            # 0/0 = NaN here; use a unit weight instead, as np.hanning(1) does.
            return np.ones(1)
        i = np.arange(0, size)
        return (1 - np.cos((2 * np.pi * i) / (size - 1))) / 2
6865

6966

class BartlettHannKernel(Kernel):
    """Bartlett-Hann window evaluated at ``i / size`` for ``i`` in ``[0, size)``."""

    @staticmethod
    def _init_wi(size: int) -> np.ndarray:
        """Return the 1D Bartlett-Hann window of length ``size``."""
        # Follows original paper:
        # Ha YH, Pearce JA. A new window and comparison to standard windows.
        # IEEE Transactions on Acoustics, Speech, and Signal Processing.
        # 1989;37(2):298–301.
        # Distance of each normalised sample position from the window centre.
        centre_dist = np.abs(np.arange(0, size) / size - 1 / 2)
        cosine_term = np.cos(2 * np.pi * centre_dist)
        return 0.62 - 0.48 * centre_dist + 0.38 * cosine_term
8380

8481

class TriangularKernel(Kernel):
    """Triangular window: ``w[i] = 1 - |2 * i / size - 1|``."""

    @staticmethod
    def _init_wi(size: int) -> np.ndarray:
        """Return the 1D triangular window of length ``size``."""
        # Map sample indices onto [-1, 1) and fold around zero.
        ramp = 2 * np.arange(0, size) / size - 1
        return 1 - np.abs(ramp)
9087

9188

class BlackmanKernel(Kernel):
    """Blackman window with coefficients 0.42, 0.5 and 0.08."""

    @staticmethod
    def _init_wi(size: int) -> np.ndarray:
        """Return the 1D Blackman window of length ``size``."""
        idx = np.arange(0, size)
        fundamental = np.cos(2 * np.pi * idx / size)
        second_harmonic = np.cos(4 * np.pi * idx / size)
        return 0.42 - 0.5 * fundamental + 0.08 * second_harmonic
10198

10299

103-
class TorchMemoryRegister(object):
100+
class NumpyMemoryRegister(object):
104101
def __init__(
105-
self,
106-
image_path: Union[str, Path],
107-
reg_depth: int,
108-
window_size: int,
109-
device: torch.device.type,
102+
self,
103+
image_path: Union[str, Path],
104+
reg_depth: int,
105+
window_size: int,
110106
):
111107
super().__init__()
112108
self.image_path = Path(image_path)
113109
self.n = reg_depth
114110
self.ws = window_size
115111
self.hws = window_size // 2
116-
self.device = device
117112

118113
# Copy metadata from img
119114
with rasterio.open(str(image_path), "r") as src:
120115
src_width = src.width
121116

122117
self.height = self.ws
123118
self.width = (math.ceil(src_width / self.ws) * self.ws) + self.hws
124-
self.register = torch.zeros(
125-
(self.n, self.height, self.width), device=self.device
126-
)
119+
self.register = np.zeros((self.n, self.height, self.width))
127120

128121
@property
129122
def _zero_chip(self):
130-
return torch.zeros(
131-
(self.n, self.hws, self.hws), dtype=torch.float, device=self.device
132-
)
123+
return np.zeros((self.n, self.hws, self.hws), dtype=float)
133124

134-
def step(self, new_logits: torch.Tensor, img_window: Window):
125+
def step(self, new_logits: "np.ndarray", img_window: Window):
135126
# 1. Read data from the registry to update with the new logits
136127
# |a|b| |
137128
# |c|d| |
138-
with torch.no_grad():
139-
logits_abcd = self.register[
140-
:, :, img_window.col_off : img_window.col_off + self.ws
141-
].clone()
142-
logits_abcd += new_logits
129+
logits_abcd = self.register[:, :,
130+
img_window.col_off: img_window.col_off + self.ws].copy()
131+
logits_abcd += new_logits
143132

144133
# Update the registry and pop information-complete data
145134
# |c|b| | + pop a
146135
# |0|d| |
147136
logits_a = logits_abcd[:, : self.hws, : self.hws]
148-
logits_c = logits_abcd[:, self.hws :, : self.hws]
149-
logits_c0 = torch.concat([logits_c, self._zero_chip], dim=1)
150-
logits_bd = logits_abcd[:, :, self.hws :]
137+
logits_c = logits_abcd[:, self.hws:, : self.hws]
138+
logits_c0 = np.concatenate([logits_c, self._zero_chip], axis=1)
139+
logits_bd = logits_abcd[:, :, self.hws:]
151140

152141
# write c0
153142
self.register[
154-
:, :, img_window.col_off : img_window.col_off + self.hws
143+
:, :, img_window.col_off: img_window.col_off + self.hws
155144
] = logits_c0
156145

157146
# write bd
158147
col_off_bd = img_window.col_off + self.hws
159-
self.register[:, :, col_off_bd : col_off_bd + self.hws] = logits_bd
148+
self.register[:, :, col_off_bd: col_off_bd + self.hws] = logits_bd
160149

161150
# Return the information-complete predictions
162151
preds_win = Window(
@@ -165,6 +154,9 @@ def step(self, new_logits: torch.Tensor, img_window: Window):
165154
height=min(self.hws, img_window.height),
166155
width=min(self.hws, img_window.width),
167156
)
168-
preds = logits_a[:, : img_window.height, : img_window.width].softmax(axis=0)
157+
preds = logits_a[:, : img_window.height, : img_window.width]
158+
159+
# Numpy softmax on axis 0
160+
preds = np.exp(preds) / np.sum(np.exp(preds), axis=0)
169161

170162
return preds, preds_win

kelp_o_matic/managers.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
from pathlib import Path
33
from typing import Union
44

5+
import numpy as np
56
import rasterio
6-
import torch
77
from rich import print
88
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
99

1010
from kelp_o_matic.geotiff_io import GeotiffReader, GeotiffWriter
11-
from kelp_o_matic.hann import TorchMemoryRegister, BartlettHannKernel
11+
from kelp_o_matic.hann import NumpyMemoryRegister, BartlettHannKernel
1212
from kelp_o_matic.models import _Model
1313
from kelp_o_matic.utils import all_same
1414

@@ -54,9 +54,9 @@ def __init__(
5454
dtype="uint8",
5555
nodata=0,
5656
)
57-
self.kernel = BartlettHannKernel(crop_size, self.model.device)
58-
self.register = TorchMemoryRegister(
59-
self.input_path, self.model.register_depth, crop_size, self.model.device
57+
self.kernel = BartlettHannKernel(crop_size)
58+
self.register = NumpyMemoryRegister(
59+
self.input_path, self.model.register_depth, crop_size
6060
)
6161

6262
def __call__(self):
@@ -74,17 +74,17 @@ def __call__(self):
7474
if self.model.transform:
7575
crop = self.model.transform(crop / self._max_value)
7676

77-
if torch.all(crop == 0):
77+
if np.all(crop == 0):
7878
logits = self.model.shortcut(self.reader.crop_size)
7979
else:
8080
# Zero pad to correct shape
81-
_, h, w = crop.shape
82-
crop = torch.nn.functional.pad(
83-
crop, (0, self.crop_size - w, 0, self.crop_size - h), value=0
84-
)
85-
logits = self.model(crop.unsqueeze(0))[0]
81+
c, h, w = crop.shape
82+
zeros = np.zeros((c, self.crop_size, self.crop_size), dtype=crop.dtype)
83+
zeros[:, :h, :w] = crop
84+
crop = zeros
85+
logits = self.model(np.expand_dims(crop,0))[0]
8686

87-
logits = self.kernel(
87+
logits = self.kernel.forward(
8888
logits,
8989
top=self.reader.is_top_window(read_window),
9090
bottom=self.reader.is_bottom_window(read_window),
@@ -198,10 +198,6 @@ def __call__(self):
198198
with self.progress:
199199
super().__call__()
200200

201-
def on_start(self):
202-
device_emoji = ":rocket:" if self.model.device.type == "cuda" else ":snail:"
203-
print(f"Running with [magenta]{self.model.device} {device_emoji}")
204-
205201
def on_tile_write(self, index: int):
206202
self.progress.update(self.processing_task, completed=index)
207203

0 commit comments

Comments
 (0)