From 193bc7e881b91a70037c687f15155c2958a0eb1b Mon Sep 17 00:00:00 2001 From: Michael BAUDIN Date: Sun, 2 Jul 2023 11:52:00 +0200 Subject: [PATCH] Update LinearModelValidation. Make it inherit from MetaModelValidation. --- .../FunctionalChaosValidation.cxx | 9 +- .../openturns/FunctionalChaosValidation.hxx | 10 +- .../LinearModel/LinearModelValidation.cxx | 92 +++++-------------- .../openturns/LinearModelValidation.hxx | 32 ++----- .../t_FunctionalChaosValidation_ishigami.cxx | 3 + lib/test/t_LinearModelValidation_std.cxx | 1 + 6 files changed, 47 insertions(+), 100 deletions(-) diff --git a/lib/src/Uncertainty/Algorithm/MetaModel/FunctionalChaos/FunctionalChaosValidation.cxx b/lib/src/Uncertainty/Algorithm/MetaModel/FunctionalChaos/FunctionalChaosValidation.cxx index cb99aa93eb4..370850c45c5 100644 --- a/lib/src/Uncertainty/Algorithm/MetaModel/FunctionalChaos/FunctionalChaosValidation.cxx +++ b/lib/src/Uncertainty/Algorithm/MetaModel/FunctionalChaos/FunctionalChaosValidation.cxx @@ -73,7 +73,9 @@ String FunctionalChaosValidation::__repr__() const { OSS oss; oss << "class=" << FunctionalChaosValidation::GetClassName() - << " functional chaos result=" << functionalChaosResult_; + << " functional chaos result=" << functionalChaosResult_ + << " kParameter_=" << kParameter_ + << " cvMethod_=" << cvMethod_; return oss; } @@ -83,6 +85,11 @@ FunctionalChaosResult FunctionalChaosValidation::getResult() const return functionalChaosResult_; } +/* Get the K parameter */ +UnsignedInteger FunctionalChaosValidation::getKParameter() const +{ + return kParameter_; +} /* Compute cross-validation metamodel predictions */ Sample FunctionalChaosValidation::ComputeMetamodelCrossValidationPredictions( diff --git a/lib/src/Uncertainty/Algorithm/MetaModel/FunctionalChaos/openturns/FunctionalChaosValidation.hxx b/lib/src/Uncertainty/Algorithm/MetaModel/FunctionalChaos/openturns/FunctionalChaosValidation.hxx index 746cdac10ec..39771350446 100644 --- a/lib/src/Uncertainty/Algorithm/MetaModel/FunctionalChaos/openturns/FunctionalChaosValidation.hxx +++ b/lib/src/Uncertainty/Algorithm/MetaModel/FunctionalChaos/openturns/FunctionalChaosValidation.hxx @@ -50,11 +50,6 @@ public: const CrossValidationMethod method, const UnsignedInteger & kParameter = ResourceMap::GetAsUnsignedInteger("FunctionalChaosValidation-DefaultKFoldParameter")); - /* Compute cross-validation metamodel predictions */ - static Sample ComputeMetamodelCrossValidationPredictions( - const FunctionalChaosResult & functionalChaosResult, const CrossValidationMethod cvMethod, - const UnsignedInteger & kParameter); - /** Virtual constructor */ FunctionalChaosValidation * clone() const override; @@ -86,6 +81,11 @@ private: /** Cross-validation method */ CrossValidationMethod cvMethod_; + /* Compute cross-validation metamodel predictions */ + static Sample ComputeMetamodelCrossValidationPredictions( + const FunctionalChaosResult & functionalChaosResult, const CrossValidationMethod cvMethod, + const UnsignedInteger & kParameter); + } ; /* class FunctionalChaosValidation */ diff --git a/lib/src/Uncertainty/Algorithm/MetaModel/LinearModel/LinearModelValidation.cxx b/lib/src/Uncertainty/Algorithm/MetaModel/LinearModel/LinearModelValidation.cxx index fb7c4adc7e9..64e19e294da 100644 --- a/lib/src/Uncertainty/Algorithm/MetaModel/LinearModel/LinearModelValidation.cxx +++ b/lib/src/Uncertainty/Algorithm/MetaModel/LinearModel/LinearModelValidation.cxx @@ -43,22 +43,22 @@ static const Factory Factory_LinearModelValidation; /* Default constructor */ LinearModelValidation::LinearModelValidation() - : PersistentObject() + : MetaModelValidation() { // Nothing to do } /* Parameter constructor */ LinearModelValidation::LinearModelValidation(const LinearModelResult & linearModelResult, - const CrossValidationMethod method, + const CrossValidationMethod cvMethod, const UnsignedInteger & kParameter) - : PersistentObject() + : MetaModelValidation(linearModelResult.getOutputSample(), + ComputeMetamodelCrossValidationPredictions(linearModelResult, cvMethod, kParameter)) , linearModelResult_(linearModelResult) - , isInitialized_(false) { - if ((method != LEAVEONEOUT) and (method != KFOLD)) - throw InvalidArgumentException(HERE) << "The method " << method << " is not available."; - cvMethod_ = method; + if ((cvMethod != LEAVEONEOUT) and (cvMethod != KFOLD)) + throw InvalidArgumentException(HERE) << "The method " << cvMethod << " is not available."; + cvMethod_ = cvMethod; if (kParameter < 1) throw InvalidArgumentException(HERE) << "Cannot set k parameter of K-Fold method to " << kParameter << " which is lower than 1"; const UnsignedInteger sampleSize = linearModelResult_.getSampleResiduals().getSize(); @@ -80,40 +80,44 @@ String LinearModelValidation::__repr__() const { return OSS(true) << "class=" << getClassName() - << ", linearModelResult=" << linearModelResult_; + << ", linearModelResult=" << linearModelResult_ + << ", kParameter_=" << kParameter_ + << ", cvMethod_=" << cvMethod_; } /* Compute cross-validation predictions */ -Sample LinearModelValidation::computeMetamodelCrossValidationPredictions() const +Sample LinearModelValidation::ComputeMetamodelCrossValidationPredictions( + const LinearModelResult & linearModelResult, const CrossValidationMethod cvMethod, + const UnsignedInteger & kParameter) { // The residuals is ri = g(xi) - tilde{g}(xi) where g is the model // and tilde(g) is the metamodel. // Hence the metamodel prediction is tilde{g}(xi) = yi - ri. - const Sample residualsSample(linearModelResult_.getSampleResiduals()); + const Sample residualsSample(linearModelResult.getSampleResiduals()); const UnsignedInteger sampleSize = residualsSample.getSize(); - const UnsignedInteger basisSize = linearModelResult_.getBasis().getSize(); + const UnsignedInteger basisSize = linearModelResult.getBasis().getSize(); if (basisSize > sampleSize) throw InvalidArgumentException(HERE) << "Error: the sample size is: " << sampleSize << " which is lower than the basis size: " << basisSize; - const Sample outputSample(linearModelResult_.getOutputSample()); + const Sample outputSample(linearModelResult.getOutputSample()); Sample cvPredictions(sampleSize, 1); - if (cvMethod_ == LEAVEONEOUT) + if (cvMethod == LEAVEONEOUT) { - const Point diagonalH(linearModelResult_.getLeverages()); + const Point diagonalH(linearModelResult.getLeverages()); for (UnsignedInteger i = 0; i < sampleSize; ++i) cvPredictions(i, 0) = outputSample(i, 0) - residualsSample(i, 0) / (1.0 - diagonalH[i]); } - else if (cvMethod_ == KFOLD) + else if (cvMethod == KFOLD) { - SymmetricMatrix projectionMatrix(linearModelResult_.computeProjectionMatrix()); - KFoldSplitter splitter(sampleSize, kParameter_); - for (UnsignedInteger foldIndex = 0; foldIndex < kParameter_; ++foldIndex) + SymmetricMatrix projectionMatrix(linearModelResult.computeProjectionMatrix()); + KFoldSplitter splitter(sampleSize, kParameter); + for (UnsignedInteger foldIndex = 0; foldIndex < kParameter; ++foldIndex) { Indices indicesTest; const Indices indicesTrain(splitter.generate(indicesTest)); const UnsignedInteger foldSize = indicesTest.getSize(); // Take into account for different fold sizes - const Scalar foldCorrection = std::sqrt(float(sampleSize) / (kParameter_ * foldSize)); + const Scalar foldCorrection = std::sqrt(float(sampleSize) / (kParameter * foldSize)); SymmetricMatrix projectionKFoldMatrix(foldSize); for (UnsignedInteger i1 = 0; i1 < foldSize; ++i1) for (UnsignedInteger i2 = 0; i2 < 1 + i1; ++i2) @@ -126,58 +130,10 @@ Sample LinearModelValidation::computeMetamodelCrossValidationPredictions() const } // Loop over the folds } else - throw InvalidArgumentException(HERE) << "The method " << cvMethod_ << " is not available."; + throw InvalidArgumentException(HERE) << "The method " << cvMethod << " is not available."; return cvPredictions; } -/* Initialize */ -void LinearModelValidation::initialize() -{ - if (!isInitialized_) - { - const Sample metamodelPredictions(computeMetamodelCrossValidationPredictions()); - Sample outputSample(linearModelResult_.getOutputSample()); - validation_ = MetaModelValidation(outputSample, metamodelPredictions); - isInitialized_ = true; - } -} - -/* Compute mean squared error */ -Point LinearModelValidation::computeMeanSquaredError() -{ - initialize(); - return validation_.computeMeanSquaredError(); -} - -/* Compute residuals */ -Sample LinearModelValidation::getResidualSample() -{ - initialize(); - return validation_.getResidualSample(); -} - -/* Compute R2 score */ -Point LinearModelValidation::computeR2Score() -{ - initialize(); - Log::Debug(OSS() << "validation_ = " << validation_); - return validation_.computeR2Score(); -} - -/* Draw */ -GridLayout LinearModelValidation::drawValidation() -{ - initialize(); - return validation_.drawValidation(); -} - -/* Get residual distribution */ -Distribution LinearModelValidation::getResidualDistribution(const Bool smooth) -{ - initialize(); - return validation_.getResidualDistribution(smooth); -} - /* Get the K parameter */ UnsignedInteger LinearModelValidation::getKParameter() const { diff --git a/lib/src/Uncertainty/Algorithm/MetaModel/LinearModel/openturns/LinearModelValidation.hxx b/lib/src/Uncertainty/Algorithm/MetaModel/LinearModel/openturns/LinearModelValidation.hxx index b619c1d2465..099dbc9ba6c 100644 --- a/lib/src/Uncertainty/Algorithm/MetaModel/LinearModel/openturns/LinearModelValidation.hxx +++ b/lib/src/Uncertainty/Algorithm/MetaModel/LinearModel/openturns/LinearModelValidation.hxx @@ -38,7 +38,7 @@ BEGIN_NAMESPACE_OPENTURNS */ class OT_API LinearModelValidation : - public PersistentObject + public MetaModelValidation { CLASSNAME @@ -62,21 +62,6 @@ public: /** Linear model accessor */ LinearModelResult getLinearModelResult() const; - /** Mean squared error accessor */ - Point computeMeanSquaredError(); - - /** Get R2 score */ - Point computeR2Score(); - - /** Get residuals */ - Sample getResidualSample(); - - /** Get residual distribution */ - Distribution getResidualDistribution(const Bool smooth = true); - - /** Draw */ - GridLayout drawValidation(); - /** Get the K parameter */ UnsignedInteger getKParameter() const; @@ -91,24 +76,19 @@ private: /** Initialize the object*/ void initialize(); - /** Compute cross-validation metamodel predictions */ - Sample computeMetamodelCrossValidationPredictions() const; - /** linear model result */ LinearModelResult linearModelResult_; - /** Initialized ? */ - Bool isInitialized_; - /** K-parameter */ UnsignedInteger kParameter_; /** Cross-validation method */ CrossValidationMethod cvMethod_; - - /** MetaModelValidation */ - MetaModelValidation validation_; - + + /** Compute cross-validation predictions */ + static Sample ComputeMetamodelCrossValidationPredictions(const LinearModelResult & linearModelResult, + const CrossValidationMethod cvMethod, + const UnsignedInteger & kParameter); }; /* class LinearModelValidation */ END_NAMESPACE_OPENTURNS diff --git a/lib/test/t_FunctionalChaosValidation_ishigami.cxx b/lib/test/t_FunctionalChaosValidation_ishigami.cxx index 71d9318213d..e0d0567de0b 100644 --- a/lib/test/t_FunctionalChaosValidation_ishigami.cxx +++ b/lib/test/t_FunctionalChaosValidation_ishigami.cxx @@ -125,6 +125,9 @@ int main(int, char *[]) // Analytical K-Fold FunctionalChaosValidation validationKFold(chaosResult, FunctionalChaosValidation::KFOLD, kFoldParameter); fullprint << "KFold with K = " << kFoldParameter << std::endl; + assert_equal(validationKFold.getKParameter(), kFoldParameter); + + // Compute mean squared error const Point mseKFoldAnalytical(validationKFold.computeMeanSquaredError()); fullprint << "Analytical KFold MSE = " << mseKFoldAnalytical << std::endl; diff --git a/lib/test/t_LinearModelValidation_std.cxx b/lib/test/t_LinearModelValidation_std.cxx index 338434d0d37..bcb51049605 100644 --- a/lib/test/t_LinearModelValidation_std.cxx +++ b/lib/test/t_LinearModelValidation_std.cxx @@ -110,6 +110,7 @@ int main(int, char *[]) // Create KFold validation LinearModelValidation validationKFold(result, LinearModelValidation::KFOLD, kFoldParameter); fullprint << validationKFold.__str__() << std::endl; + assert_equal(validationKFold.getKParameter(), kFoldParameter); // Compute analytical KFold MSE const Point mseKFoldAnalytical(validationKFold.computeMeanSquaredError());