Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve integer types #57

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions python/src/pybindings/raster_source_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,20 @@ class PyRasterSourceBase : public RasterSource
{

public:
std::unique_ptr<AbstractRaster<double>> read_box(const Box& box) override
template<typename T>
static std::unique_ptr<NumPyRaster<T>>
make_raster(const Grid<bounded_extent>& grid, const py::array& values, const py::object& nodata)
{
auto rast = std::make_unique<NumPyRaster<T>>(values, grid);

if (!nodata.is_none()) {
rast->set_nodata(nodata.cast<T>());
}

return rast;
}

RasterVariant read_box(const Box& box) override
{
auto cropped_grid = grid().crop(box);

Expand All @@ -62,16 +75,22 @@ class PyRasterSourceBase : public RasterSource
py::array rast_values = read_window(x0, y0, nx, ny);
py::object nodata = nodata_value();

auto rast = std::make_unique<NumPyRaster<double>>(rast_values, cropped_grid);

if (!nodata.is_none()) {
rast->set_nodata(nodata.cast<double>());
if (py::isinstance<py::array_t<std::int8_t>>(rast_values)) {
return make_raster<std::int8_t>(cropped_grid, rast_values, nodata);
} else if (py::isinstance<py::array_t<std::int16_t>>(rast_values)) {
return make_raster<std::int16_t>(cropped_grid, rast_values, nodata);
} else if (py::isinstance<py::array_t<std::int32_t>>(rast_values)) {
return make_raster<std::int32_t>(cropped_grid, rast_values, nodata);
} else if (py::isinstance<py::array_t<std::int64_t>>(rast_values)) {
return make_raster<std::int64_t>(cropped_grid, rast_values, nodata);
} else if (py::isinstance<py::array_t<float>>(rast_values)) {
return make_raster<float>(cropped_grid, rast_values, nodata);
} else {
return make_raster<double>(cropped_grid, rast_values, nodata);
}

return rast;
}

return nullptr;
return std::make_unique<Raster<double>>(Raster<double>::make_empty());
}

const Grid<bounded_extent>& grid() const override
Expand Down
15 changes: 15 additions & 0 deletions python/tests/test_exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,18 @@ def test_error_rotated_inputs(tmp_path, rast_lib):

with pytest.raises(ValueError, match="Rotated raster"):
exact_extract(rast, square, ["count"])


@pytest.mark.parametrize("dtype", (np.float64, np.float32, np.int32, np.int64))
def test_types_preserved(dtype):

rast = NumPyRasterSource(np.full((3, 3), 1, dtype))

square = make_rect(0, 0, 3, 3)

result = exact_extract(rast, square, "mode")[0]['properties']['mode']

if np.issubdtype(dtype, np.integer):
assert isinstance(result, int)
else:
assert isinstance(result, float)
18 changes: 3 additions & 15 deletions src/coverage_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,9 @@
return std::make_unique<CoverageOperation>(*this);
}

/// Method which which a CoverateProcessor can save a value to be applied
/// the next time set_result is called. A bit of a kludge.
void save_coverage(const CoverageValue<double, double>& last_coverage)
template<typename T>
void set_coverage_result(const T& loc, Feature& f_out) const
{
m_last_coverage = last_coverage;
}

void set_result(const StatsRegistry& reg, const Feature& fid, Feature& f_out) const override
{
(void)reg;
(void)fid;

const auto& loc = m_last_coverage;

f_out.set("coverage_fraction", loc.coverage);

if (m_coverage_opts.include_cell) {
Expand All @@ -110,7 +99,7 @@
}

if (m_coverage_opts.include_weights && weights != nullptr) {
f_out.set(weights->name(), loc.value);
f_out.set(weights->name(), loc.weight);

Check warning on line 102 in src/coverage_operation.h

View check run for this annotation

Codecov / codecov/patch

src/coverage_operation.h#L102

Added line #L102 was not covered by tests
}
}

Expand All @@ -121,7 +110,6 @@

private:
Options m_coverage_opts;
CoverageValue<double, double> m_last_coverage;
};

}
43 changes: 27 additions & 16 deletions src/coverage_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ CoverageProcessor::process()
const auto& opts = op.options();

auto grid = common_grid(m_operations.begin(), m_operations.end());
StatsRegistry dummy; // TODO remove need for this;

RasterSource* values = opts.include_values ? op.values : nullptr;
RasterSource* weights = opts.include_weights ? op.weights : nullptr;

RasterVariant vx = values ? values->read_box(Box::make_empty()) : std::make_unique<Raster<double>>(Raster<double>::make_empty());
RasterVariant wx = weights ? weights->read_box(Box::make_empty()) : std::make_unique<Raster<double>>(Raster<double>::make_empty());

while (m_shp.next()) {
const Feature& f_in = m_shp.feature();
Expand All @@ -54,29 +59,35 @@ CoverageProcessor::process()
for (const auto& subgrid : subdivide(cropped_grid, m_max_cells_in_memory)) {
auto coverage_fractions = raster_cell_intersection(subgrid, m_geos_context, geom);

RasterSource* values = opts.include_values ? op.values : nullptr;
RasterSource* weights = opts.include_weights ? op.weights : nullptr;

std::unique_ptr<AbstractRaster<double>> areas;
if (opts.area_method == CoverageOperation::AreaMethod::CARTESIAN) {
areas = std::make_unique<CartesianAreaRaster<double>>(coverage_fractions.grid());
} else if (opts.area_method == CoverageOperation::AreaMethod::SPHERICAL) {
areas = std::make_unique<SphericalAreaRaster<double>>(coverage_fractions.grid());
}

for (const auto& loc : RasterCoverageIteration<double, double>(coverage_fractions, values, weights, grid, areas.get())) {
auto f_out = m_output.create_feature();
if (m_shp.id_field() != "") {
f_out->set(m_shp.id_field(), f_in);
}
for (const auto& col : m_include_cols) {
f_out->set(col, f_in);
}
std::visit([this, &op, &f_in, &coverage_fractions, &values, &weights, &grid, &areas](const auto& _v, const auto& _w) {
using ValueType = typename std::remove_reference_t<decltype(*_v)>::value_type;
using WeightType = typename std::remove_reference_t<decltype(*_w)>::value_type;

op.save_coverage(loc);
op.set_result(dummy, f_in, *f_out);
m_output.write(*f_out);
}
for (const auto& loc : RasterCoverageIteration<ValueType, WeightType>(coverage_fractions, values, weights, grid, areas.get())) {

auto f_out = m_output.create_feature();
if (m_shp.id_field() != "") {
f_out->set(m_shp.id_field(), f_in);
}
for (const auto& col : m_include_cols) {
f_out->set(col, f_in);
}

(void)loc;

op.set_coverage_result(loc, *f_out);
m_output.write(*f_out);
}
},
vx,
wx);

progress();
}
Expand Down
21 changes: 18 additions & 3 deletions src/feature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@
set(name, f.get_string(name));
} else if (type == typeid(double)) {
set(name, f.get_double(name));
} else if (type == typeid(std::int8_t)) {
set(name, f.get_int(name));

Check warning on line 32 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L32

Added line #L32 was not covered by tests
} else if (type == typeid(std::int16_t)) {
set(name, f.get_int(name));

Check warning on line 34 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L34

Added line #L34 was not covered by tests
} else if (type == typeid(std::int32_t)) {
set(name, f.get_int(name));
} else if (type == typeid(std::int64_t)) {
set(name, f.get_int(name));

Check warning on line 38 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L37-L38

Added lines #L37 - L38 were not covered by tests
} else if (type == typeid(std::size_t)) {
set(name, f.get_int(name));
} else {
Expand All @@ -38,14 +44,23 @@
}

void
Feature::set(const std::string& name, std::size_t value)
Feature::set(const std::string& name, std::int64_t value)
{
if (value > std::numeric_limits<std::int32_t>::max()) {
throw std::runtime_error("Value is too large to store as 32-bit integer.");
if (value > std::numeric_limits<std::int32_t>::max() || value < std::numeric_limits<std::int32_t>::min()) {
throw std::runtime_error("Value is too small/large to store as 32-bit integer.");

Check warning on line 50 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L50

Added line #L50 was not covered by tests
}
set(name, static_cast<std::int32_t>(value));
}

void
Feature::set(const std::string& name, std::size_t value)
{
if (value > std::numeric_limits<std::int64_t>::max()) {
throw std::runtime_error("Value is too large to store as 64-bit integer.");

Check warning on line 59 in src/feature.cpp

View check run for this annotation

Codecov / codecov/patch

src/feature.cpp#L59

Added line #L59 was not covered by tests
}
set(name, static_cast<std::int64_t>(value));
}

void
Feature::set(const std::string& name, float value)
{
Expand Down
1 change: 1 addition & 0 deletions src/feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Feature
virtual void set(const std::string& name, std::int32_t value) = 0;

virtual void set(const std::string& name, float value);
virtual void set(const std::string& name, std::int64_t value);
virtual void set(const std::string& name, std::size_t value);
virtual void set(const std::string& name, const Feature& other);

Expand Down
4 changes: 2 additions & 2 deletions src/feature_sequential_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ FeatureSequentialProcessor::process()
if (op->weighted()) {
auto weights = op->weights->read_box(subgrid.extent().intersection(op->weights->grid().extent()));

m_reg.stats(f_in, *op, store_values).process(*coverage, *values, *weights);
m_reg.update_stats(f_in, *op, *coverage, values, weights, store_values);
} else {
m_reg.stats(f_in, *op, store_values).process(*coverage, *values);
m_reg.update_stats(f_in, *op, *coverage, values, store_values);
}

progress();
Expand Down
5 changes: 5 additions & 0 deletions src/gdal_feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@
OGR_F_SetFieldInteger(m_feature, field_index(name), value);
}

void set(const std::string& name, std::int64_t value) override

Check warning on line 114 in src/gdal_feature.h

View check run for this annotation

Codecov / codecov/patch

src/gdal_feature.h#L114

Added line #L114 was not covered by tests
{
OGR_F_SetFieldInteger64(m_feature, field_index(name), value);
}

Check warning on line 117 in src/gdal_feature.h

View check run for this annotation

Codecov / codecov/patch

src/gdal_feature.h#L116-L117

Added lines #L116 - L117 were not covered by tests

void set(const std::string& name, std::size_t value) override
{
if (value > std::numeric_limits<std::int64_t>::max()) {
Expand Down
28 changes: 20 additions & 8 deletions src/gdal_raster_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,26 @@ GDALRasterWrapper::cartesian() const
return srs == nullptr || !OSRIsGeographic(srs);
}

std::unique_ptr<AbstractRaster<double>>
RasterVariant
GDALRasterWrapper::read_box(const Box& box)
{
auto cropped_grid = m_grid.shrink_to_fit(box);
auto vals = std::make_unique<Raster<double>>(cropped_grid);

if (m_has_nodata) {
vals->set_nodata(m_nodata_value);
RasterVariant ret;

auto band_type = GDALGetRasterDataType(m_band);
void* buffer;
GDALDataType read_type;

if (band_type == GDT_Int32) {
auto rast = make_raster<std::int32_t>(cropped_grid);
buffer = rast->data().data();
ret = std::move(rast);
read_type = GDT_Int32;
} else {
auto rast = make_raster<double>(cropped_grid);
buffer = rast->data().data();
ret = std::move(rast);
read_type = GDT_Float64;
}

auto error = GDALRasterIO(m_band,
Expand All @@ -71,18 +83,18 @@ GDALRasterWrapper::read_box(const Box& box)
(int)cropped_grid.row_offset(m_grid),
(int)cropped_grid.cols(),
(int)cropped_grid.rows(),
vals->data().data(),
buffer,
(int)cropped_grid.cols(),
(int)cropped_grid.rows(),
GDT_Float64,
read_type,
0,
0);

if (error) {
throw std::runtime_error("Error reading from raster.");
}

return vals;
return ret;
}

void
Expand Down
12 changes: 11 additions & 1 deletion src/gdal_raster_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class GDALRasterWrapper : public RasterSource
return m_grid;
}

std::unique_ptr<AbstractRaster<double>> read_box(const Box& box) override;
RasterVariant read_box(const Box& box) override;

~GDALRasterWrapper() override;

Expand All @@ -53,6 +53,16 @@ class GDALRasterWrapper : public RasterSource
Grid<bounded_extent> m_grid;

void compute_raster_grid();

template<typename T>
std::unique_ptr<Raster<T>> make_raster(const Grid<bounded_extent>& grid)
{
auto ret = std::make_unique<Raster<T>>(grid);
if (m_has_nodata) {
ret->set_nodata(m_nodata_value);
}
return ret;
}
};
}

Expand Down
Loading
Loading