Skip to content

Commit

Permalink
added possibiltiy for k-fold cross-validation to regression package.
Browse files Browse the repository at this point in the history
bodirsky committed Nov 11, 2019
1 parent a42dac4 commit a4c9f7c
Showing 6 changed files with 93 additions and 20 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: mrregression
Type: Package
Title: Regression analysis for model parametrization
Version: 3.12.3
Date: 2019-05-29
Version: 3.16.0
Date: 2019-11-11
Author: Benjamin Leon Bodirsky, Antonia Walther, Xiaoxi Wang, Abhijeet Mishra, Eleonora Martinelli
Maintainer: Benjamin Leon Bodirsky <bodirsky@pik-potsdam.de>
Description: Model estimates parameters of model functions.
@@ -24,5 +24,5 @@ Imports:
License: LGPL-3 | file LICENSE
LazyData: no
RoxygenNote: 6.1.1
ValidationKey: 56354535
ValidationKey: 57546760
Encoding: UTF-8
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@ importFrom(stats,nls)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,resid)
importFrom(stats,runif)
importFrom(stats,shapiro.test)
importFrom(stats,var)
importFrom(stats,weighted.mean)
26 changes: 22 additions & 4 deletions R/calcCollectRegressionData.R
Original file line number Diff line number Diff line change
@@ -15,13 +15,14 @@
#' @importFrom magpiesets findset
#' @importFrom moinput toolFAOcombine
#' @importFrom stats quantile
#' @importFrom stats runif
#' @import magclass
#' @import madrat
#' @export

calcCollectRegressionData <- function(datasources)
{
combined<-list()
calcCollectRegressionData <- function(datasources){

combined<-list()

if ("wooddemand" %in% datasources) {
wooddemand <- calcOutput("FAOForestryDemand",aggregate = FALSE)
@@ -130,7 +131,7 @@ calcCollectRegressionData <- function(datasources)
}

if ("intake_demography" %in% datasources) {
intake <- calcOutput("Intake",convert=FALSE, modelinput=FALSE, standardize=FALSE, method="Froehle", aggregate=FALSE)
intake <- calcOutput("Intake",convert=FALSE, modelinput=FALSE, standardize=FALSE, method="schofield", aggregate=FALSE)
intake<-collapseNames(intake[,,"SSP2"][,,c("F","M")])
getNames(intake)<-paste0("intake_",sub(x = getNames(intake),pattern = "\\.",replacement = "_"))
getSets(intake)<-c("region","year","intake")
@@ -370,6 +371,23 @@ calcCollectRegressionData <- function(datasources)
getYears(combined[[1]])
combined$climate<-CZ
}

if (any(grepl("crossvalid",datasources))) {
code=datasources[grep("crossvalid",datasources)]
code2=strsplit(code,"_")
randomseed = as.integer(substring(code2[[1]][2],5))
k = as.integer(substring(code2[[1]][3],2))
# format: crossvalid_seedX_kY
# X is the random seed,
# Y is the number of drawings

countries = toolGetMapping("iso_country.csv", where = "moinput")
years=paste0("y",1961:2020)
sampleset=new.magpie(cells_and_regions = countries$x,years = years,names = code)
set.seed(42); sampleset[,,] <- round(runif(length(sampleset))*k+0.5)

combined$crossvalid<-sampleset
}


mbindCommonDimensions <- function(magpielist){
58 changes: 46 additions & 12 deletions R/nlsregression.R
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@
#' @param toPlot "all", "frame" (axis etc), "observations" (points), "regressionline" (line), "infos" (parameters, R2)
#' @param regressioncolor color of regression line and paramter text
#' @param weight_threshold if numeric, all countries below this threshold will be excluded (e.g. to exclude minor islands)
#' @param crossvalid vector with boolean values, indicating which data should be excluded from sampling and rather be used for validation
#' @param ... will be passed on to function nls
#' @return A nice picture and regression parameters or eventually some errors.
#' @author Benjamin Leon Bodirsky, Susanne Rolinski, Xiaoxi Wang
@@ -86,6 +87,7 @@ nlsregression <- function(func, # y ~ a*x/b+x
plot_x_function="ignore",
regressioncolor="blue",
weight_threshold=NULL,
crossvalid=NULL,
...)
{
rounding_helper<-function(x) {
@@ -281,6 +283,7 @@ nlsregression <- function(func, # y ~ a*x/b+x
if(is.magpie(y)){y<-as.vector(y)}
if(is.magpie(z)){z<-as.vector(z)}
if(is.magpie(weight)){weight<-as.vector(weight)}
if(is.magpie(crossvalid)){crossvalid<-as.vector(crossvalid)}


if(plot_x_function!="ignore"){stop("argument plot_x_function is depreciated, please remove")}
@@ -301,34 +304,53 @@ nlsregression <- function(func, # y ~ a*x/b+x

# remove all NA
naVec = y * x
if (!is.null(z))
{
if (!is.null(z)){
naVec = naVec * z
}
if (!is.null(weight))
{
if (!is.null(weight)){
if(!is.null(weight_threshold)){
weight[weight<weight_threshold]<-NA
}
naVec = naVec * weight
}
if (is.null(crossvalid)){
crossvalid = naVec
crossvalid[] = 0
}
for(i in length(naVec):1) {
if(is.na(naVec[i]))
{
if(is.na(naVec[i])) {
y = y[-i]
x = x[-i]
if (!is.null(z))
{
if (!is.null(z)) {
z = z[-i]
}
if (!is.null(weight))
{
if (!is.null(weight)) {
weight = weight[-i]
}
if (!is.null(crossvalid)) {
crossvalid = crossvalid[-i]
}
}
}

#
# split dataset in regression and validation data
if(any(crossvalid>0)){
y_valid = y[which(crossvalid==1)]
y = y[which(crossvalid==0)]
x_valid = x[which(crossvalid==1)]
x = x[which(crossvalid==0)]

if (!is.null(z)) {
z_valid = z[which(crossvalid==1)]
z = z[which(crossvalid==0)]
}

if (!is.null(weight)) {
weight_valid = weight[which(crossvalid==1)]
weight = weight[which(crossvalid==0)]
}

}


if (is.null(weight)) {
@@ -443,7 +465,17 @@ nlsregression <- function(func, # y ~ a*x/b+x

standarderror=(sum((prediction-observation)^2)/length(observation))^0.5

### Out of sample cross-validation

if(any(crossvalid>0)){
y_valid_predict = predict(opt,data.frame('x'=x_valid))
out_of_sample_R2_unweighted = max(cor(y_valid_predict, y_valid),0)^2
out_of_sample_R2_weighted = max(corr(matrix(data = c(y_valid,y_valid_predict),ncol = 2),w = weight_valid),0)^2
} else {
out_of_sample_R2_unweighted=NULL
out_of_sample_R2_weighted=NULL
}

### transforming formulas into expression or functions

formula2<-gsub(" ", "", format(func)[[3]], fixed = TRUE)
@@ -470,7 +502,9 @@ nlsregression <- function(func, # y ~ a*x/b+x
loglik = loglik,
norm_test = norm_test,
# bp_test = bp_test,
robust_out = robust_out
robust_out = robust_out,
out_of_sample_R2_unweighted,
out_of_sample_R2_weighted
)

if(any(c("all","regression")%in%toPlot)) {
5 changes: 5 additions & 0 deletions R/toolCollectRegressionVariables.R
Original file line number Diff line number Diff line change
@@ -158,6 +158,11 @@ toolCollectRegressionVariables<-function(indicators){
datasources=c(datasources,"intake_standardized_demography")
}

if (any(grepl("crossvalid",indicators))) {
code=indicators[grep("crossvalid",indicators)]
datasources=c(datasources,code)
}


data<-calcOutput("CollectRegressionData",datasources=datasources,aggregate = FALSE)[,,indicators]
#data<-calcOutput("CollectRegressionData",aggregate = FALSE)
17 changes: 16 additions & 1 deletion R/toolRegression.R
Original file line number Diff line number Diff line change
@@ -12,6 +12,11 @@
#' @param countries_nlsAddLines the number of weightiest countries or the name of countries that shall be plotted by lines in the plot
#' @param weight the weight
#' @param x_log10 passed on to nlsregression()
#' @param crossvalid_sample sample name from moinput used for crossvalidation. Name is built as follows:
#' crossvalid_seedX_kY
#' X is the random seed,
#' Y is the number of drawings. The combination of all drawings is the full sample.
#' @param crossvalid_drawing selected drawing of k in crossvalidsample
#' @param ... further attributes that will be handed on to nlsregression():
#'
#' An additional explanatory variable z can be added.
@@ -58,11 +63,13 @@ toolRegression<-function(denominator,
countries_nlsAddLines=NULL,
weight="pop",
x_log10=FALSE,
crossvalid_sample = NULL,
crossvalid_drawing=1,
...
)
{
if (is.null(data)){
data<-toolCollectRegressionVariables(indicators=c(denominator,quotient,x,z,weight))
data<-toolCollectRegressionVariables(indicators=c(denominator,quotient,x,z,weight,crossvalid_sample))
}

if(is.null(xlab)){
@@ -103,6 +110,13 @@ toolRegression<-function(denominator,
weight = dimSums(data[,,weight],dim=3)
}

if(is.null(crossvalid_sample)){
crossvalid=NULL
} else {
crossvalid=data[,,crossvalid_sample]
crossvalid[,,] = (crossvalid==crossvalid_drawing)
}

denom = dimSums(data[,,denom],dim=3)

#gdp per capita ausrechnen und z(urban oder education shr) ausrechnen
@@ -136,6 +150,7 @@ toolRegression<-function(denominator,
xlab=xlab,
ylab=ylab,
x_log10=x_log10,
crossvalid=as.vector(crossvalid),
...
)

0 comments on commit a4c9f7c

Please sign in to comment.