Skip to content

Commit 96c18ac

Browse files
Rasterized numpy array
1 parent d07e597 commit 96c18ac

File tree

4 files changed

+129
-96
lines changed

4 files changed

+129
-96
lines changed

threedigrid_builder/grid/grid.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from threedigrid_builder.constants import ContentType, LineType, NodeType, WKT_VERSION
2424
from threedigrid_builder.exceptions import SchematisationError
2525
from threedigrid_builder.grid import zero_d
26+
from threedigrid_builder.utils import Dataset
2627

2728
from . import connection_nodes as connection_nodes_module
2829
from . import dem_average_area as dem_average_area_module
@@ -920,23 +921,10 @@ def apply_cutlines(self, cutlines, dem_path):
920921
# Now we have the list of all cut cells and their corresponding fragments
921922
print([idx + 1 for idx in node_fragment.keys()]) # QGIS idx start at 1
922923

923-
# Generate mask
924-
924+
# read DEM to retrieve properties
925925
raster_dataset = gdal.Open(str(dem_path), gdal.GA_ReadOnly)
926-
# array = np.full(shape=(1, subgrid_meta["height"], subgrid_meta["width"]), fill_value=-1, dtype=np.int32)
927-
target_ds = gdal.GetDriverByName("GTiff").Create(
928-
"test2.tif",
929-
raster_dataset.RasterXSize,
930-
raster_dataset.RasterYSize,
931-
1,
932-
gdal.GDT_Int32,
933-
)
934-
target_ds.SetGeoTransform(raster_dataset.GetGeoTransform())
935-
target_ds.SetProjection(raster_dataset.GetProjection())
936-
937-
band = target_ds.GetRasterBand(1)
938-
band.SetNoDataValue(-9999)
939926

927+
# Create an OGR memory layer with the geometry
940928
driver = ogr.GetDriverByName("Memory")
941929
data_source = driver.CreateDataSource("")
942930
layer = data_source.CreateLayer(
@@ -956,7 +944,38 @@ def apply_cutlines(self, cutlines, dem_path):
956944
feature.SetGeometry(polygon)
957945
layer.CreateFeature(feature)
958946

959-
gdal.RasterizeLayer(target_ds, [1], layer, options=["ATTRIBUTE=id"])
947+
no_data_value = -9999
948+
# Write to tiff
949+
export_tiff = False
950+
if export_tiff:
951+
target_ds = gdal.GetDriverByName("GTiff").Create(
952+
"test.tif",
953+
raster_dataset.RasterXSize,
954+
raster_dataset.RasterYSize,
955+
1,
956+
gdal.GDT_Int32,
957+
)
958+
target_ds.SetGeoTransform(raster_dataset.GetGeoTransform())
959+
target_ds.SetProjection(raster_dataset.GetProjection())
960+
961+
band = target_ds.GetRasterBand(1)
962+
band.SetNoDataValue(no_data_value)
963+
gdal.RasterizeLayer(target_ds, [1], layer, options=["ATTRIBUTE=id"])
964+
965+
# Write to array
966+
array = np.full(
967+
shape=(1, raster_dataset.RasterYSize, raster_dataset.RasterXSize),
968+
fill_value=no_data_value,
969+
dtype=np.int32,
970+
)
971+
dataset_kwargs = {
972+
"no_data_value": no_data_value,
973+
"geo_transform": raster_dataset.GetGeoTransform(),
974+
}
975+
976+
with Dataset(array, **dataset_kwargs) as dataset:
977+
gdal.RasterizeLayer(dataset, (1,), layer, options=["ATTRIBUTE=id"])
978+
960979
assert True
961980

962981
@staticmethod

threedigrid_builder/grid/grid_refinement.py

Lines changed: 2 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import numpy as np
22
import shapely
3-
from osgeo import gdal, gdal_array, ogr
3+
from osgeo import gdal, ogr
44

55
from threedigrid_builder.base import Array
6+
from threedigrid_builder.utils import Dataset
67

78

89
class GridRefinement:
@@ -92,82 +93,3 @@ def rasterize(geoms, values, origin, width, height, cell_size, no_data_value):
9293
gdal.RasterizeLayer(dataset, (1,), layer, options=options)
9394

9495
return array[0]
95-
96-
97-
def create_dataset(array, geo_transform=None, projection=None, no_data_value=None):
98-
"""
99-
Create and return a gdal dataset.
100-
101-
:param array: A numpy array.
102-
:param geo_transform: 6-tuple of floats
103-
:param projection: wkt projection string
104-
:param no_data_value: integer or float
105-
106-
This is the fastest way to get a gdal dataset from a numpy array, but
107-
keep a reference to the array around, or a segfault will occur. Also,
108-
don't forget to call FlushCache() on the dataset after any operation
109-
that affects the array.
110-
"""
111-
# prepare dataset name pointing to array
112-
datapointer = array.ctypes.data
113-
bands, lines, pixels = array.shape
114-
datatypecode = gdal_array.NumericTypeCodeToGDALTypeCode(array.dtype.type)
115-
datatype = gdal.GetDataTypeName(datatypecode)
116-
bandoffset, lineoffset, pixeloffset = array.strides
117-
118-
dataset_name_template = (
119-
"MEM:::"
120-
"DATAPOINTER={datapointer},"
121-
"PIXELS={pixels},"
122-
"LINES={lines},"
123-
"BANDS={bands},"
124-
"DATATYPE={datatype},"
125-
"PIXELOFFSET={pixeloffset},"
126-
"LINEOFFSET={lineoffset},"
127-
"BANDOFFSET={bandoffset}"
128-
)
129-
dataset_name = dataset_name_template.format(
130-
datapointer=datapointer,
131-
pixels=pixels,
132-
lines=lines,
133-
bands=bands,
134-
datatype=datatype,
135-
pixeloffset=pixeloffset,
136-
lineoffset=lineoffset,
137-
bandoffset=bandoffset,
138-
)
139-
140-
# access the array memory as gdal dataset
141-
dataset = gdal.Open(dataset_name, gdal.GA_Update)
142-
143-
# set additional properties from kwargs
144-
if geo_transform is not None:
145-
dataset.SetGeoTransform(geo_transform)
146-
if projection is not None:
147-
dataset.SetProjection(projection)
148-
if no_data_value is not None:
149-
for i in range(len(array)):
150-
dataset.GetRasterBand(i + 1).SetNoDataValue(no_data_value)
151-
152-
return dataset
153-
154-
155-
class Dataset(object):
156-
"""
157-
Usage:
158-
>>> with Dataset(array) as dataset:
159-
... # do gdal things.
160-
"""
161-
162-
def __init__(self, array, **kwargs):
163-
self.array = array
164-
self.dataset = create_dataset(array, **kwargs)
165-
166-
def __enter__(self):
167-
return self.dataset
168-
169-
def __exit__(self, exc_type, exc_value, traceback):
170-
self.close()
171-
172-
def close(self):
173-
self.dataset.FlushCache()

threedigrid_builder/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Some miscelanious utils (also to prevent circular imports)
2+
"""
3+
4+
from .array_dataset import * # NOQA
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
try:
2+
from osgeo import gdal, gdal_array
3+
4+
gdal.UseExceptions()
5+
except ImportError:
6+
gdal = None
7+
8+
__all__ = ["Dataset"]
9+
10+
11+
# This code is a direct copy of dask-geomodelling
12+
def create_dataset(array, geo_transform=None, projection=None, no_data_value=None):
13+
"""
14+
Create and return a gdal dataset.
15+
16+
:param array: A numpy array.
17+
:param geo_transform: 6-tuple of floats
18+
:param projection: wkt projection string
19+
:param no_data_value: integer or float
20+
21+
This is the fastest way to get a gdal dataset from a numpy array, but
22+
keep a reference to the array around, or a segfault will occur. Also,
23+
don't forget to call FlushCache() on the dataset after any operation
24+
that affects the array.
25+
"""
26+
# prepare dataset name pointing to array
27+
datapointer = array.ctypes.data
28+
bands, lines, pixels = array.shape
29+
datatypecode = gdal_array.NumericTypeCodeToGDALTypeCode(array.dtype.type)
30+
datatype = gdal.GetDataTypeName(datatypecode)
31+
bandoffset, lineoffset, pixeloffset = array.strides
32+
33+
dataset_name_template = (
34+
"MEM:::"
35+
"DATAPOINTER={datapointer},"
36+
"PIXELS={pixels},"
37+
"LINES={lines},"
38+
"BANDS={bands},"
39+
"DATATYPE={datatype},"
40+
"PIXELOFFSET={pixeloffset},"
41+
"LINEOFFSET={lineoffset},"
42+
"BANDOFFSET={bandoffset}"
43+
)
44+
dataset_name = dataset_name_template.format(
45+
datapointer=datapointer,
46+
pixels=pixels,
47+
lines=lines,
48+
bands=bands,
49+
datatype=datatype,
50+
pixeloffset=pixeloffset,
51+
lineoffset=lineoffset,
52+
bandoffset=bandoffset,
53+
)
54+
55+
# access the array memory as gdal dataset
56+
dataset = gdal.Open(dataset_name, gdal.GA_Update)
57+
58+
# set additional properties from kwargs
59+
if geo_transform is not None:
60+
dataset.SetGeoTransform(geo_transform)
61+
if projection is not None:
62+
dataset.SetProjection(projection)
63+
if no_data_value is not None:
64+
for i in range(len(array)):
65+
dataset.GetRasterBand(i + 1).SetNoDataValue(no_data_value)
66+
67+
return dataset
68+
69+
70+
class Dataset(object):
71+
"""
72+
Usage:
73+
>>> with Dataset(array) as dataset:
74+
... # do gdal things.
75+
"""
76+
77+
def __init__(self, array, **kwargs):
78+
self.array = array
79+
self.dataset = create_dataset(array, **kwargs)
80+
81+
def __enter__(self):
82+
return self.dataset
83+
84+
def __exit__(self, exc_type, exc_value, traceback):
85+
self.close()
86+
87+
def close(self):
88+
self.dataset.FlushCache()

0 commit comments

Comments
 (0)