Skip to content

Commit

Permalink
Preserve integer types
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Dec 12, 2023
1 parent cdbfbc7 commit e596f21
Show file tree
Hide file tree
Showing 19 changed files with 307 additions and 124 deletions.
35 changes: 27 additions & 8 deletions python/src/pybindings/raster_source_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,20 @@ class PyRasterSourceBase : public RasterSource
{

public:
std::unique_ptr<AbstractRaster<double>> read_box(const Box& box) override
template<typename T>
static std::unique_ptr<NumPyRaster<T>>
make_raster(const Grid<bounded_extent>& grid, const py::array& values, const py::object& nodata)
{
auto rast = std::make_unique<NumPyRaster<T>>(values, grid);

if (!nodata.is_none()) {
rast->set_nodata(nodata.cast<T>());
}

return rast;
}

RasterVariant read_box(const Box& box) override
{
auto cropped_grid = grid().crop(box);

Expand All @@ -62,16 +75,22 @@ class PyRasterSourceBase : public RasterSource
py::array rast_values = read_window(x0, y0, nx, ny);
py::object nodata = nodata_value();

auto rast = std::make_unique<NumPyRaster<double>>(rast_values, cropped_grid);

if (!nodata.is_none()) {
rast->set_nodata(nodata.cast<double>());
if (py::isinstance<py::array_t<std::int8_t>>(rast_values)) {
return make_raster<std::int8_t>(cropped_grid, rast_values, nodata);
} else if (py::isinstance<py::array_t<std::int16_t>>(rast_values)) {
return make_raster<std::int16_t>(cropped_grid, rast_values, nodata);
} else if (py::isinstance<py::array_t<std::int32_t>>(rast_values)) {
return make_raster<std::int32_t>(cropped_grid, rast_values, nodata);
} else if (py::isinstance<py::array_t<std::int64_t>>(rast_values)) {
return make_raster<std::int64_t>(cropped_grid, rast_values, nodata);
} else if (py::isinstance<py::array_t<float>>(rast_values)) {
return make_raster<float>(cropped_grid, rast_values, nodata);
} else {
return make_raster<double>(cropped_grid, rast_values, nodata);
}

return rast;
}

return nullptr;
return std::make_unique<Raster<double>>(Raster<double>::make_empty());
}

const Grid<bounded_extent>& grid() const override
Expand Down
15 changes: 15 additions & 0 deletions python/tests/test_exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,18 @@ def test_error_rotated_inputs(tmp_path, rast_lib):

with pytest.raises(ValueError, match="Rotated raster"):
exact_extract(rast, square, ["count"])


@pytest.mark.parametrize("dtype", (np.float64, np.float32, np.int32, np.int64))
def test_types_preserved(dtype):

rast = NumPyRasterSource(np.full((3, 3), 1, dtype))

square = make_rect(0, 0, 3, 3)

result = exact_extract(rast, square, "mode")[0]['properties']['mode']

if np.issubdtype(dtype, np.integer):
assert isinstance(result, int)
else:
assert isinstance(result, float)
18 changes: 3 additions & 15 deletions src/coverage_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,9 @@ class CoverageOperation : public Operation
return std::make_unique<CoverageOperation>(*this);
}

/// Method which which a CoverateProcessor can save a value to be applied
/// the next time set_result is called. A bit of a kludge.
void save_coverage(const CoverageValue<double, double>& last_coverage)
template<typename T>
void set_coverage_result(const T& loc, Feature& f_out) const
{
m_last_coverage = last_coverage;
}

void set_result(const StatsRegistry& reg, const Feature& fid, Feature& f_out) const override
{
(void)reg;
(void)fid;

const auto& loc = m_last_coverage;

f_out.set("coverage_fraction", loc.coverage);

if (m_coverage_opts.include_cell) {
Expand All @@ -110,7 +99,7 @@ class CoverageOperation : public Operation
}

if (m_coverage_opts.include_weights && weights != nullptr) {
f_out.set(weights->name(), loc.value);
f_out.set(weights->name(), loc.weight);

Check warning on line 102 in src/coverage_operation.h

View check run for this annotation

Codecov / codecov/patch

src/coverage_operation.h#L102

Added line #L102 was not covered by tests
}
}

Expand All @@ -121,7 +110,6 @@ class CoverageOperation : public Operation

private:
Options m_coverage_opts;
CoverageValue<double, double> m_last_coverage;
};

}
43 changes: 27 additions & 16 deletions src/coverage_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ CoverageProcessor::process()
const auto& opts = op.options();

auto grid = common_grid(m_operations.begin(), m_operations.end());
StatsRegistry dummy; // TODO remove need for this;

RasterSource* values = opts.include_values ? op.values : nullptr;
RasterSource* weights = opts.include_weights ? op.weights : nullptr;

RasterVariant vx = values ? values->read_box(Box::make_empty()) : std::make_unique<Raster<double>>(Raster<double>::make_empty());
RasterVariant wx = weights ? weights->read_box(Box::make_empty()) : std::make_unique<Raster<double>>(Raster<double>::make_empty());

while (m_shp.next()) {
const Feature& f_in = m_shp.feature();
Expand All @@ -54,29 +59,35 @@ CoverageProcessor::process()
for (const auto& subgrid : subdivide(cropped_grid, m_max_cells_in_memory)) {
auto coverage_fractions = raster_cell_intersection(subgrid, m_geos_context, geom);

RasterSource* values = opts.include_values ? op.values : nullptr;
RasterSource* weights = opts.include_weights ? op.weights : nullptr;

std::unique_ptr<AbstractRaster<double>> areas;
if (opts.area_method == CoverageOperation::AreaMethod::CARTESIAN) {
areas = std::make_unique<CartesianAreaRaster<double>>(coverage_fractions.grid());
} else if (opts.area_method == CoverageOperation::AreaMethod::SPHERICAL) {
areas = std::make_unique<SphericalAreaRaster<double>>(coverage_fractions.grid());
}

for (const auto& loc : RasterCoverageIteration<double, double>(coverage_fractions, values, weights, grid, areas.get())) {
auto f_out = m_output.create_feature();
if (m_shp.id_field() != "") {
f_out->set(m_shp.id_field(), f_in);
}
for (const auto& col : m_include_cols) {
f_out->set(col, f_in);
}
std::visit([this, &op, &f_in, &coverage_fractions, &values, &weights, &grid, &areas](const auto& _v, const auto& _w) {
using ValueType = typename std::remove_reference_t<decltype(*_v)>::value_type;
using WeightType = typename std::remove_reference_t<decltype(*_w)>::value_type;

op.save_coverage(loc);
op.set_result(dummy, f_in, *f_out);
m_output.write(*f_out);
}
for (const auto& loc : RasterCoverageIteration<ValueType, WeightType>(coverage_fractions, values, weights, grid, areas.get())) {

auto f_out = m_output.create_feature();
if (m_shp.id_field() != "") {
f_out->set(m_shp.id_field(), f_in);
}
for (const auto& col : m_include_cols) {
f_out->set(col, f_in);
}

(void)loc;

op.set_coverage_result(loc, *f_out);
m_output.write(*f_out);
}
},
vx,
wx);

progress();
}
Expand Down
21 changes: 18 additions & 3 deletions src/feature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,14 @@ Feature::set(const std::string& name, const Feature& f)
set(name, f.get_string(name));
} else if (type == typeid(double)) {
set(name, f.get_double(name));
} else if (type == typeid(std::int8_t)) {
set(name, f.get_int(name));

Check warning on line 31 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L31

Added line #L31 was not covered by tests
} else if (type == typeid(std::int16_t)) {
set(name, f.get_int(name));

Check warning on line 33 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L33

Added line #L33 was not covered by tests
} else if (type == typeid(std::int32_t)) {
set(name, f.get_int(name));
} else if (type == typeid(std::int64_t)) {
set(name, f.get_int(name));

Check warning on line 37 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L36-L37

Added lines #L36 - L37 were not covered by tests
} else if (type == typeid(std::size_t)) {
set(name, f.get_int(name));
} else {
Expand All @@ -37,14 +43,23 @@ Feature::set(const std::string& name, const Feature& f)
}

void
Feature::set(const std::string& name, std::size_t value)
Feature::set(const std::string& name, std::int64_t value)
{
if (value > std::numeric_limits<std::int32_t>::max()) {
throw std::runtime_error("Value is too large to store as 32-bit integer.");
if (value > std::numeric_limits<std::int32_t>::max() || value < std::numeric_limits<std::int32_t>::min()) {
throw std::runtime_error("Value is too small/large to store as 32-bit integer.");

Check warning on line 49 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L49

Added line #L49 was not covered by tests
}
set(name, static_cast<std::int32_t>(value));
}

void
Feature::set(const std::string& name, std::size_t value)
{
if (value > std::numeric_limits<std::int64_t>::max()) {
throw std::runtime_error("Value is too large to store as 64-bit integer.");

Check warning on line 58 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L58

Added line #L58 was not covered by tests
}
set(name, static_cast<std::int64_t>(value));
}

void
Feature::set(const std::string& name, float value)
{
Expand Down
1 change: 1 addition & 0 deletions src/feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Feature
virtual void set(const std::string& name, std::int32_t value) = 0;

virtual void set(const std::string& name, float value);
virtual void set(const std::string& name, std::int64_t value);
virtual void set(const std::string& name, std::size_t value);
virtual void set(const std::string& name, const Feature& other);

Expand Down
4 changes: 2 additions & 2 deletions src/feature_sequential_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ FeatureSequentialProcessor::process()
if (op->weighted()) {
auto weights = op->weights->read_box(subgrid.extent().intersection(op->weights->grid().extent()));

m_reg.stats(f_in, *op, store_values).process(*coverage, *values, *weights);
m_reg.update_stats(f_in, *op, *coverage, values, weights, store_values);
} else {
m_reg.stats(f_in, *op, store_values).process(*coverage, *values);
m_reg.update_stats(f_in, *op, *coverage, values, store_values);
}

progress();
Expand Down
5 changes: 5 additions & 0 deletions src/gdal_feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ class GDALFeature : public Feature
OGR_F_SetFieldInteger(m_feature, field_index(name), value);
}

void set(const std::string& name, std::int64_t value) override

Check warning on line 114 in src/gdal_feature.h

View check run for this annotation

Codecov / codecov/patch

src/gdal_feature.h#L114

Added line #L114 was not covered by tests
{
OGR_F_SetFieldInteger64(m_feature, field_index(name), value);
}

Check warning on line 117 in src/gdal_feature.h

View check run for this annotation

Codecov / codecov/patch

src/gdal_feature.h#L116-L117

Added lines #L116 - L117 were not covered by tests

void set(const std::string& name, std::size_t value) override
{
if (value > std::numeric_limits<std::int64_t>::max()) {
Expand Down
28 changes: 20 additions & 8 deletions src/gdal_raster_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,26 @@ GDALRasterWrapper::cartesian() const
return srs == nullptr || !OSRIsGeographic(srs);
}

std::unique_ptr<AbstractRaster<double>>
RasterVariant
GDALRasterWrapper::read_box(const Box& box)
{
auto cropped_grid = m_grid.shrink_to_fit(box);
auto vals = std::make_unique<Raster<double>>(cropped_grid);

if (m_has_nodata) {
vals->set_nodata(m_nodata_value);
RasterVariant ret;

auto band_type = GDALGetRasterDataType(m_band);
void* buffer;
GDALDataType read_type;

if (band_type == GDT_Int32) {
auto rast = make_raster<std::int32_t>(cropped_grid);
buffer = rast->data().data();
ret = std::move(rast);
read_type = GDT_Int32;
} else {
auto rast = make_raster<double>(cropped_grid);
buffer = rast->data().data();
ret = std::move(rast);
read_type = GDT_Float64;
}

auto error = GDALRasterIO(m_band,
Expand All @@ -71,18 +83,18 @@ GDALRasterWrapper::read_box(const Box& box)
(int)cropped_grid.row_offset(m_grid),
(int)cropped_grid.cols(),
(int)cropped_grid.rows(),
vals->data().data(),
buffer,
(int)cropped_grid.cols(),
(int)cropped_grid.rows(),
GDT_Float64,
read_type,
0,
0);

if (error) {
throw std::runtime_error("Error reading from raster.");
}

return vals;
return ret;
}

void
Expand Down
12 changes: 11 additions & 1 deletion src/gdal_raster_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class GDALRasterWrapper : public RasterSource
return m_grid;
}

std::unique_ptr<AbstractRaster<double>> read_box(const Box& box) override;
RasterVariant read_box(const Box& box) override;

~GDALRasterWrapper() override;

Expand All @@ -53,6 +53,16 @@ class GDALRasterWrapper : public RasterSource
Grid<bounded_extent> m_grid;

void compute_raster_grid();

template<typename T>
std::unique_ptr<Raster<T>> make_raster(const Grid<bounded_extent>& grid)
{
auto ret = std::make_unique<Raster<T>>(grid);
if (m_has_nodata) {
ret->set_nodata(m_nodata_value);

Check warning on line 62 in src/gdal_raster_wrapper.h

View check run for this annotation

Codecov / codecov/patch

src/gdal_raster_wrapper.h#L62

Added line #L62 was not covered by tests
}
return ret;
}
};
}

Expand Down
Loading

0 comments on commit e596f21

Please sign in to comment.