Skip to content

Commit

Permalink
Update LinearModelValidation. Make it inherit from MetaModelValidation.
Browse files Browse the repository at this point in the history
  • Loading branch information
mbaudin47 committed Jul 2, 2023
1 parent ac163ed commit 193bc7e
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ String FunctionalChaosValidation::__repr__() const
{
OSS oss;
oss << "class=" << FunctionalChaosValidation::GetClassName()
<< " functional chaos result=" << functionalChaosResult_;
<< " functional chaos result=" << functionalChaosResult_
<< " kParameter_=" << kParameter_
<< " cvMethod_=" << cvMethod_;
return oss;
}

Expand All @@ -83,6 +85,11 @@ FunctionalChaosResult FunctionalChaosValidation::getResult() const
return functionalChaosResult_;
}

/* Get the K parameter */
UnsignedInteger FunctionalChaosValidation::getKParameter() const
{
return kParameter_;
}

/* Compute cross-validation metamodel predictions */
Sample FunctionalChaosValidation::ComputeMetamodelCrossValidationPredictions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ public:
const CrossValidationMethod method,
const UnsignedInteger & kParameter = ResourceMap::GetAsUnsignedInteger("FunctionalChaosValidation-DefaultKFoldParameter"));

/* Compute cross-validation metamodel predictions */
static Sample ComputeMetamodelCrossValidationPredictions(
const FunctionalChaosResult & functionalChaosResult, const CrossValidationMethod cvMethod,
const UnsignedInteger & kParameter);

/** Virtual constructor */
FunctionalChaosValidation * clone() const override;

Expand Down Expand Up @@ -86,6 +81,11 @@ private:
/** Cross-validation method */
CrossValidationMethod cvMethod_;

/* Compute cross-validation metamodel predictions */
static Sample ComputeMetamodelCrossValidationPredictions(
const FunctionalChaosResult & functionalChaosResult, const CrossValidationMethod cvMethod,
const UnsignedInteger & kParameter);

} ; /* class FunctionalChaosValidation */


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,22 @@ static const Factory<LinearModelValidation> Factory_LinearModelValidation;

/* Default constructor */
LinearModelValidation::LinearModelValidation()
: PersistentObject()
: MetaModelValidation()
{
// Nothing to do
}

/* Parameter constructor */
LinearModelValidation::LinearModelValidation(const LinearModelResult & linearModelResult,
const CrossValidationMethod method,
const CrossValidationMethod cvMethod,
const UnsignedInteger & kParameter)
: PersistentObject()
: MetaModelValidation(linearModelResult.getOutputSample(),
ComputeMetamodelCrossValidationPredictions(linearModelResult, cvMethod, kParameter))
, linearModelResult_(linearModelResult)
, isInitialized_(false)
{
if ((method != LEAVEONEOUT) and (method != KFOLD))
throw InvalidArgumentException(HERE) << "The method " << method << " is not available.";
cvMethod_ = method;
if ((cvMethod != LEAVEONEOUT) and (cvMethod != KFOLD))
throw InvalidArgumentException(HERE) << "The method " << cvMethod << " is not available.";
cvMethod_ = cvMethod;
if (kParameter < 1)
throw InvalidArgumentException(HERE) << "Cannot set k parameter of K-Fold method to " << kParameter << " which is lower than 1";
const UnsignedInteger sampleSize = linearModelResult_.getSampleResiduals().getSize();
Expand All @@ -80,40 +80,44 @@ String LinearModelValidation::__repr__() const
{
return OSS(true)
<< "class=" << getClassName()
<< ", linearModelResult=" << linearModelResult_;
<< ", linearModelResult=" << linearModelResult_
<< ", kParameter_=" << kParameter_
<< ", cvMethod_=" << cvMethod_;
}

/* Compute cross-validation predictions */
Sample LinearModelValidation::computeMetamodelCrossValidationPredictions() const
Sample LinearModelValidation::ComputeMetamodelCrossValidationPredictions(
const LinearModelResult & linearModelResult, const CrossValidationMethod cvMethod,
const UnsignedInteger & kParameter)
{
// The residuals is ri = g(xi) - tilde{g}(xi) where g is the model
// and tilde(g) is the metamodel.
// Hence the metamodel prediction is tilde{g}(xi) = yi - ri.
const Sample residualsSample(linearModelResult_.getSampleResiduals());
const Sample residualsSample(linearModelResult.getSampleResiduals());
const UnsignedInteger sampleSize = residualsSample.getSize();
const UnsignedInteger basisSize = linearModelResult_.getBasis().getSize();
const UnsignedInteger basisSize = linearModelResult.getBasis().getSize();
if (basisSize > sampleSize)
throw InvalidArgumentException(HERE) << "Error: the sample size is: " << sampleSize <<
" which is lower than the basis size: " << basisSize;
const Sample outputSample(linearModelResult_.getOutputSample());
const Sample outputSample(linearModelResult.getOutputSample());
Sample cvPredictions(sampleSize, 1);
if (cvMethod_ == LEAVEONEOUT)
if (cvMethod == LEAVEONEOUT)
{
const Point diagonalH(linearModelResult_.getLeverages());
const Point diagonalH(linearModelResult.getLeverages());
for (UnsignedInteger i = 0; i < sampleSize; ++i)
cvPredictions(i, 0) = outputSample(i, 0) - residualsSample(i, 0) / (1.0 - diagonalH[i]);
}
else if (cvMethod_ == KFOLD)
else if (cvMethod == KFOLD)
{
SymmetricMatrix projectionMatrix(linearModelResult_.computeProjectionMatrix());
KFoldSplitter splitter(sampleSize, kParameter_);
for (UnsignedInteger foldIndex = 0; foldIndex < kParameter_; ++foldIndex)
SymmetricMatrix projectionMatrix(linearModelResult.computeProjectionMatrix());
KFoldSplitter splitter(sampleSize, kParameter);
for (UnsignedInteger foldIndex = 0; foldIndex < kParameter; ++foldIndex)
{
Indices indicesTest;
const Indices indicesTrain(splitter.generate(indicesTest));
const UnsignedInteger foldSize = indicesTest.getSize();
// Take into account for different fold sizes
const Scalar foldCorrection = std::sqrt(float(sampleSize) / (kParameter_ * foldSize));
const Scalar foldCorrection = std::sqrt(float(sampleSize) / (kParameter * foldSize));
SymmetricMatrix projectionKFoldMatrix(foldSize);
for (UnsignedInteger i1 = 0; i1 < foldSize; ++i1)
for (UnsignedInteger i2 = 0; i2 < 1 + i1; ++i2)
Expand All @@ -126,58 +130,10 @@ Sample LinearModelValidation::computeMetamodelCrossValidationPredictions() const
} // Loop over the folds
}
else
throw InvalidArgumentException(HERE) << "The method " << cvMethod_ << " is not available.";
throw InvalidArgumentException(HERE) << "The method " << cvMethod << " is not available.";
return cvPredictions;
}

/* Initialize */
void LinearModelValidation::initialize()
{
if (!isInitialized_)
{
const Sample metamodelPredictions(computeMetamodelCrossValidationPredictions());
Sample outputSample(linearModelResult_.getOutputSample());
validation_ = MetaModelValidation(outputSample, metamodelPredictions);
isInitialized_ = true;
}
}

/* Compute mean squared error */
Point LinearModelValidation::computeMeanSquaredError()
{
initialize();
return validation_.computeMeanSquaredError();
}

/* Compute residuals */
Sample LinearModelValidation::getResidualSample()
{
initialize();
return validation_.getResidualSample();
}

/* Compute R2 score */
Point LinearModelValidation::computeR2Score()
{
initialize();
Log::Debug(OSS() << "validation_ = " << validation_);
return validation_.computeR2Score();
}

/* Draw */
GridLayout LinearModelValidation::drawValidation()
{
initialize();
return validation_.drawValidation();
}

/* Get residual distribution */
Distribution LinearModelValidation::getResidualDistribution(const Bool smooth)
{
initialize();
return validation_.getResidualDistribution(smooth);
}

/* Get the K parameter */
UnsignedInteger LinearModelValidation::getKParameter() const
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ BEGIN_NAMESPACE_OPENTURNS
*/

class OT_API LinearModelValidation :
public PersistentObject
public MetaModelValidation
{
CLASSNAME

Expand All @@ -62,21 +62,6 @@ public:
/** Linear model accessor */
LinearModelResult getLinearModelResult() const;

/** Mean squared error accessor */
Point computeMeanSquaredError();

/** Get R2 score */
Point computeR2Score();

/** Get residuals */
Sample getResidualSample();

/** Get residual distribution */
Distribution getResidualDistribution(const Bool smooth = true);

/** Draw */
GridLayout drawValidation();

/** Get the K parameter */
UnsignedInteger getKParameter() const;

Expand All @@ -91,24 +76,19 @@ private:
/** Initialize the object*/
void initialize();

/** Compute cross-validation metamodel predictions */
Sample computeMetamodelCrossValidationPredictions() const;

/** linear model result */
LinearModelResult linearModelResult_;

/** Initialized ? */
Bool isInitialized_;

/** K-parameter */
UnsignedInteger kParameter_;

/** Cross-validation method */
CrossValidationMethod cvMethod_;

/** MetaModelValidation */
MetaModelValidation validation_;


/** Compute cross-validation predictions */
static Sample ComputeMetamodelCrossValidationPredictions(const LinearModelResult & linearModelResult,
const CrossValidationMethod cvMethod,
const UnsignedInteger & kParameter);
}; /* class LinearModelValidation */

END_NAMESPACE_OPENTURNS
Expand Down
3 changes: 3 additions & 0 deletions lib/test/t_FunctionalChaosValidation_ishigami.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ int main(int, char *[])
// Analytical K-Fold
FunctionalChaosValidation validationKFold(chaosResult, FunctionalChaosValidation::KFOLD, kFoldParameter);
fullprint << "KFold with K = " << kFoldParameter << std::endl;
assert_equal(validationKFold.getKParameter(), kFoldParameter);

// Compute mean squared error
const Point mseKFoldAnalytical(validationKFold.computeMeanSquaredError());
fullprint << "Analytical KFold MSE = " << mseKFoldAnalytical << std::endl;

Expand Down
1 change: 1 addition & 0 deletions lib/test/t_LinearModelValidation_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ int main(int, char *[])
// Create KFold validation
LinearModelValidation validationKFold(result, LinearModelValidation::KFOLD, kFoldParameter);
fullprint << validationKFold.__str__() << std::endl;
assert_equal(validationKFold.getKParameter(), kFoldParameter);

// Compute analytical KFold MSE
const Point mseKFoldAnalytical(validationKFold.computeMeanSquaredError());
Expand Down

0 comments on commit 193bc7e

Please sign in to comment.