Skip to content

Commit

Permalink
MIR- eckit::geo ProxyWeightedMethod
Browse files Browse the repository at this point in the history
  • Loading branch information
pmaciel committed Nov 23, 2024
1 parent c9e0db9 commit 837dbd3
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 56 deletions.
116 changes: 65 additions & 51 deletions src/mir/method/ProxyWeightedMethod.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
#include "mir/method/ProxyWeightedMethod.h"

#include <ostream>
#include <vector>

#include "eckit/log/JSON.h"
#include "eckit/utils/MD5.h"

#include "mir/param/MIRParametrisation.h"
#include "mir/repres/Representation.h"
#include "mir/util/Exceptions.h"
#include "mir/util/Log.h"
#include "mir/util/allocator/InPlaceAllocator.h"


Expand All @@ -31,26 +29,72 @@ namespace mir::method {
namespace {


struct StructuredBicubic final : public ProxyWeightedMethod {
explicit StructuredBicubic(const param::MIRParametrisation& param) :
ProxyWeightedMethod(param, "structured-bicubic", "serial-halo to serial") {}
};
#define ATLAS_METHOD(Type, Name) \
struct Type final : ProxyWeightedMethod { \
explicit Type(const param::MIRParametrisation& param) : ProxyWeightedMethod(param, Name) {} \
};

ATLAS_METHOD(StructuredBilinear, "structured-bilinear");
ATLAS_METHOD(StructuredBiquasicubic, "structured-biquasicubic");
ATLAS_METHOD(StructuredBicubic, "structured-bicubic");
ATLAS_METHOD(FiniteElement, "finite-element");
ATLAS_METHOD(ConservativeSphericalPolygon, "conservative-spherical-polygon");
ATLAS_METHOD(GridBoxAverage, "grid-box-average");
ATLAS_METHOD(GridBoxMaximum, "grid-box-maximum");

} // namespace
// "nearest-neighbour" (knn)
// "k-nearest-neighbours" (knn)
// "cubedsphere-bilinear"
// "regional-linear-2d" (structured)
// "spherical-vector"


#undef ATLAS_INTERPOL


static const MethodFactory* METHODS[]{
new MethodBuilder<StructuredBicubic>("structured-bicubic"),
const MethodFactory* MIR_METHODS[]{
new MethodBuilder<StructuredBicubic>("atlas-structured-bicubic"),
new MethodBuilder<StructuredBilinear>("atlas-structured-bilinear"),
new MethodBuilder<StructuredBiquasicubic>("atlas-structured-biquasicubic"),
new MethodBuilder<FiniteElement>("atlas-finite-element"),
new MethodBuilder<ConservativeSphericalPolygon>("atlas-conservative-spherical-polygon"),
new MethodBuilder<GridBoxAverage>("atlas-grid-box-average"),
new MethodBuilder<GridBoxMaximum>("atlas-grid-box-maximum"),
};


ProxyWeightedMethod::ProxyWeightedMethod(const param::MIRParametrisation& param, const std::string& interpolation_type,
const std::string& renumber_type) :
} // namespace


ProxyWeightedMethod::ProxyWeightedMethod(const param::MIRParametrisation& param,
const std::string& interpolation_type) :
MethodWeighted(param), type_(interpolation_type) {
interpol_.set("matrix_free", false);
interpol_.set("type", interpolation_type);
renumber_.set("type", renumber_type);
}


void ProxyWeightedMethod::foldSourceHalo(const atlas::Interpolation& interpol, size_t Nr, size_t Nc,
WeightMatrix& W) const {
ASSERT(Nr == W.rows());
ASSERT(Nc < W.cols());

const auto& fs = interpol.source();
const auto global_index{atlas::array::make_view<atlas::gidx_t, 1>(fs.global_index())};
ASSERT(global_index.size() == W.cols());

auto* a = const_cast<eckit::linalg::Scalar*>(W.data());
for (auto c = Nc; c < W.cols(); ++c) {
ASSERT(1 <= global_index[c] && global_index[c] < Nc + 1); // (global indexing is 1-based)
a[global_index[c] - 1] += a[c];
a[c] = 0.;
}

W.prune();
WeightMatrix M(new util::allocator::InPlaceAllocator{
Nr, Nc, W.nonZeros(), const_cast<eckit::linalg::Index*>(W.outer()),
const_cast<eckit::linalg::Index*>(W.inner()), const_cast<eckit::linalg::Scalar*>(W.data())});
W.swap(M);
}


Expand All @@ -60,7 +104,7 @@ const char* ProxyWeightedMethod::name() const {


int ProxyWeightedMethod::version() const {
return -1;
return 1;
}


Expand All @@ -79,64 +123,34 @@ bool ProxyWeightedMethod::sameAs(const Method& other) const {

const auto* o = dynamic_cast<const ProxyWeightedMethod*>(&other);
return (o != nullptr) && name() == std::string{o->name()} && digest(interpol_) == digest(o->interpol_) &&
digest(renumber_) == digest(o->renumber_) && MethodWeighted::sameAs(*o);
MethodWeighted::sameAs(*o);
}


void ProxyWeightedMethod::print(std::ostream& out) const {
out << "ProxyWeightedMethod[interpolation=" << interpol_ << ",renumber=" << renumber_ << ",";
out << "ProxyWeightedMethod[interpolation=" << interpol_ << ",";
MethodWeighted::print(out);
out << "]";
}


void ProxyWeightedMethod::json(eckit::JSON& j) const {
j.startObject();
j << "options" << interpol_.json(eckit::JSON::Formatting::compact());
j << "interpolation" << interpol_.json(eckit::JSON::Formatting::compact());
MethodWeighted::json(j);
j.endObject();
}


void ProxyWeightedMethod::assemble(util::MIRStatistics&, WeightMatrix& W, const repres::Representation& in,
const repres::Representation& out) const {
// interpolation matrix build and assign (with a halo on the source grid)
// build matrix (with a halo on the source grid), move out of cache
atlas::Interpolation interpol{interpol_, in.atlasGrid(), out.atlasGrid()};
{
atlas::interpolation::MatrixCache cache{interpol};
const auto& M = cache.matrix();
W.swap(const_cast<eckit::linalg::SparseMatrix&>(M));
}

// fold serial+halo into serial
const auto Nr = out.numberOfPoints();
const auto Nc = in.numberOfPoints();
ASSERT(Nr == W.rows());
ASSERT(Nc < W.cols());

{
const auto& fs = interpol.source();
const auto gidx{atlas::array::make_view<atlas::gidx_t, 1>(fs.global_index())};

auto* a = const_cast<eckit::linalg::Scalar*>(W.data());
for (auto c = Nc; c < W.cols(); ++c) {
ASSERT(1 <= gidx[c] && gidx[c] < Nc + 1); // (global indexing is 1-based)
a[gidx[c] - 1] += a[c];
a[c] = 0.;
}
W.prune();
}

// reshape matrix
WeightMatrix M(new util::allocator::InPlaceAllocator{
Nr, Nc, W.nonZeros(), const_cast<eckit::linalg::Index*>(W.outer()),
const_cast<eckit::linalg::Index*>(W.inner()), const_cast<eckit::linalg::Scalar*>(W.data())});
W.swap(M);
}

atlas::interpolation::MatrixCache cache{interpol};
W.swap(const_cast<eckit::linalg::SparseMatrix&>(cache.matrix()));

bool ProxyWeightedMethod::validateMatrixWeights() const {
return true;
// fold source grid halo (from serial + halo into serial)
foldSourceHalo(interpol, out.numberOfPoints(), in.numberOfPoints(), W);
}


Expand Down
11 changes: 6 additions & 5 deletions src/mir/method/ProxyWeightedMethod.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,22 @@ class ProxyWeightedMethod : public MethodWeighted {
protected:
// -- Constructor

ProxyWeightedMethod(const param::MIRParametrisation&, const std::string& interpolation_type,
const std::string& renumber_type = "");
ProxyWeightedMethod(const param::MIRParametrisation&, const std::string& interpolation_type);

// -- Destructor

~ProxyWeightedMethod() override = default;

protected:
// -- Methods

void foldSourceHalo(const atlas::Interpolation&, size_t Nr, size_t Nc, WeightMatrix&) const;

private:
// -- Members

const std::string type_;
atlas::util::Config interpol_;
atlas::util::Config renumber_;

// -- Overridden methods

Expand All @@ -51,8 +54,6 @@ class ProxyWeightedMethod : public MethodWeighted {

void assemble(util::MIRStatistics&, WeightMatrix&, const repres::Representation& in,
const repres::Representation& out) const override;

bool validateMatrixWeights() const override;
};


Expand Down

0 comments on commit 837dbd3

Please sign in to comment.