From 7c7c5f72c81518fd169e4f93dca30929d8626ee7 Mon Sep 17 00:00:00 2001 From: Bai-Li-NOAA Date: Wed, 30 Oct 2024 12:01:26 -0400 Subject: [PATCH] fix: recruitment log_dev setup --- R/better_initialize_modules.R | 114 +++++++++++------- R/utils.R | 72 +++++------ .../testthat/helper-integration-tests-setup.R | 9 +- 3 files changed, 112 insertions(+), 83 deletions(-) diff --git a/R/better_initialize_modules.R b/R/better_initialize_modules.R index ce432c5f..d9ac04f8 100644 --- a/R/better_initialize_modules.R +++ b/R/better_initialize_modules.R @@ -15,25 +15,25 @@ initialize_module <- function(input, data, module_name) { } # Define module class and fields - if (module_name == "population") { - module_class_name <- "Population" - } else if (!(module_name %in% names(input[["module_list"]])) & - (names(module_name) == "selectivity")){ - module_class_name <- + module_class_name <- + if (module_name == "population") { + "Population" + } else if (!(module_name %in% names(input[["module_list"]])) & + (names(module_name) == "selectivity")) { input[["module_list"]][["fleets"]][[module_name]][[names(module_name)]][["form"]] - } else if (!(module_name %in% names(input[["module_list"]])) & - names(module_name) == "Fleet"){ - module_class_name <- "Fleet" - } else { - module_class_name <- input[["module_list"]][[module_name]]$form - } + } else if (!(module_name %in% names(input[["module_list"]])) & + names(module_name) == "Fleet") { + "Fleet" + } else { + input[["module_list"]][[module_name]]$form + } module_class <- get(module_class_name) module_fields <- names(module_class@fields) module <- new(module_class) module_input <- input[["parameter_input_list"]][[module_name]] - if (module_class_name == "Fleet"){ + if (module_class_name == "Fleet") { module_fields <- setdiff(module_fields, c( "log_expected_index", "proportion_catch_numbers_at_age" @@ -59,23 +59,50 @@ initialize_module <- function(input, data, module_name) { # - Reconsider exposing `log_expected_index` and # `proportion_catch_numbers_at_age` to users. Their IDs are linked with # index and agecomp distributions. No input values are required. + + non_standard_field <- c( + "ages", "nages", "proportion_female", "estimate_prop_female", + "nyears", "nseasons", "nfleets", "estimate_log_devs", "weights", + "is_survey", "estimate_q", "random_q" + ) for (field in module_fields) { - module[[field]] <- switch( - field, - "ages" = get_data_slot(field, data), - "nages" = get_data_slot(field, data), - "proportion_female" = numeric(0), - "estimate_prop_female" = FALSE, - "nyears" = get_data_slot(field, data), - "nseasons" = 1, - "nfleets" = length(input[["module_list"]][["fleets"]]), - "estimate_log_devs" = module_input[[paste0(module_class_name, ".estimate_log_devs")]], - "weights" = m_weight_at_age(data), - "is_survey" = module_input[[paste0(module_class_name, ".is_survey")]], - "estimate_q" = module_input[[paste0(module_class_name, ".log_q.estimated")]], - "random_q" = FALSE, - set_param_vector(param_vector_name = field, module_input = module_input) - ) + if (field %in% non_standard_field) { + # TODO: reorder the list alphabetically + module[[field]] <- switch(field, + "ages" = get_data_slot(field, data), + "nages" = get_data_slot(field, data), + "proportion_female" = numeric(0), + "estimate_prop_female" = FALSE, + "nyears" = get_data_slot(field, data), + "nseasons" = 1, + "nfleets" = length(input[["module_list"]][["fleets"]]), + "estimate_log_devs" = module_input[[paste0(module_class_name, ".estimate_log_devs")]], + "weights" = m_weight_at_age(data), + "is_survey" = module_input[[paste0(module_class_name, ".is_survey")]], + "estimate_q" = module_input[[paste0(module_class_name, ".log_q.estimated")]], + "random_q" = FALSE + ) + } else { + field_value_name <- grep(paste0(field, ".value"), names(module_input), value = TRUE) + field_estimated_name <- grep(paste0(field, ".estimated"), names(module_input), value = TRUE) + + # Check for the presence of value and estimation information + if (length(field_value_name) == 0 || length(field_estimated_name) == 0) { + cli::cli_abort(c("Missing value or estimation information for field {field}.")) + } + + # Extract the value of the parameter vector + field_value <- module_input[[field_value_name]] + + if (length(field_value) > 1) module[[field]]$resize(length(field_value)) + + for (i in seq_along(field_value)) { + module[[field]][i][["value"]] <- field_value[i] + } + + # Set the estimation information for the parameter vector + module[[field]]$set_all_estimable(module_input[[field_estimated_name]]) + } } return(module) @@ -88,15 +115,16 @@ initialize_module <- function(input, data, module_name) { #' @export initialize_distribution <- function(module_input, distribution_name, distribution_type, linked_ids) { - # Check if distribution_name is provided - if (is.null(distribution_name)) return(NULL) + if (is.null(distribution_name)) { + return(NULL) + } # Get distribution value and initialize the module distribution_value <- get(distribution_name) distribution_module <- new(distribution_value) distribution_fields <- names(distribution_value@fields) - if (distribution_type == "data"){ + if (distribution_type == "data") { distribution_fields <- setdiff(distribution_fields, c( "expected_values", "x", @@ -111,18 +139,18 @@ initialize_distribution <- function(module_input, distribution_name, } switch(distribution_type, - "data" = { - # Data distribution initialization - distribution_module$set_observed_data(linked_ids["data_link"]) - distribution_module$set_distribution_links(distribution_type, linked_ids["fleet_link"]) - }, - "process" = { - # Process distribution initialization - distribution_module$set_distribution_links("random_effects", linked_ids) - }, - { - cli::cli_abort("Unsupported distribution type: {.val {distribution_type}}. Please use 'data' or 'process'.") - } + "data" = { + # Data distribution initialization + distribution_module$set_observed_data(linked_ids["data_link"]) + distribution_module$set_distribution_links(distribution_type, linked_ids["fleet_link"]) + }, + "process" = { + # Process distribution initialization + distribution_module$set_distribution_links("random_effects", linked_ids) + }, + { + cli::cli_abort("Unsupported distribution type: {.val {distribution_type}}. Please use 'data' or 'process'.") + } ) # Final message to confirm success diff --git a/R/utils.R b/R/utils.R index 5ca9d135..ccf93148 100644 --- a/R/utils.R +++ b/R/utils.R @@ -21,39 +21,39 @@ get_data_slot <- function(field_name, data) { return(output) } -#' Set Parameter Vector Values from Module Input. -#' @export -set_param_vector <- function(param_vector_name, module_input) { - # Retrieve the names of the value and estimated slots for the parameter vector - param_vector_value_name <- grep(paste0(param_vector_name, ".value"), names(module_input), value = TRUE) - param_vector_estimated_name <- grep(paste0(param_vector_name, ".estimated"), names(module_input), value = TRUE) - - # Check for the presence of value and estimation information - if (length(param_vector_value_name) == 0 || length(param_vector_estimated_name) == 0) { - cli::cli_abort(c("Missing value or estimation information for field {param_vector_name}.")) - } - - # Extract the value of the parameter vector - param_vector_value <- module_input[[param_vector_value_name]] - - # Create a new ParameterVector object - param_vector_module <- methods::new(ParameterVector, param_vector_value, length(param_vector_value)) - - # Set the estimation information for the parameter vector - param_vector_module$set_all_estimable(module_input[[param_vector_estimated_name]]) - - return(param_vector_module) -} - -#' Get Rcpp Modules -#' @export -get_rcpp_modules <- function(objs){ - - # Filter objects to find those that inherit from classes starting with "Rcpp_" - rcpp_names <- names(Filter(function(i) { - any(startsWith(class(i), "Rcpp_")) - }, objs)) - - rcpp_objs <- objs[rcpp_names] - -} +#' #' Set Parameter Vector Values from Module Input. +#' #' @export +#' set_param_vector <- function(param_vector_name, module_input) { +#' # Retrieve the names of the value and estimated slots for the parameter vector +#' param_vector_value_name <- grep(paste0(param_vector_name, ".value"), names(module_input), value = TRUE) +#' param_vector_estimated_name <- grep(paste0(param_vector_name, ".estimated"), names(module_input), value = TRUE) +#' +#' # Check for the presence of value and estimation information +#' if (length(param_vector_value_name) == 0 || length(param_vector_estimated_name) == 0) { +#' cli::cli_abort(c("Missing value or estimation information for field {param_vector_name}.")) +#' } +#' +#' # Extract the value of the parameter vector +#' param_vector_value <- module_input[[param_vector_value_name]] +#' +#' # Create a new ParameterVector object +#' # param_vector_module <- methods::new(ParameterVector, param_vector_value, length(param_vector_value)) +#' +#' # Set the estimation information for the parameter vector +#' param_vector_module$set_all_estimable(module_input[[param_vector_estimated_name]]) +#' +#' return(param_vector_module) +#' } + +#' #' Get Rcpp Modules +#' #' @export +#' get_rcpp_modules <- function(objs){ +#' +#' # Filter objects to find those that inherit from classes starting with "Rcpp_" +#' rcpp_names <- names(Filter(function(i) { +#' any(startsWith(class(i), "Rcpp_")) +#' }, objs)) +#' +#' rcpp_objs <- objs[rcpp_names] +#' +#' } diff --git a/tests/testthat/helper-integration-tests-setup.R b/tests/testthat/helper-integration-tests-setup.R index d0aab26d..dc91cb3b 100644 --- a/tests/testthat/helper-integration-tests-setup.R +++ b/tests/testthat/helper-integration-tests-setup.R @@ -81,7 +81,7 @@ setup_and_run_FIMS_without_wrappers <- function(iter_id, fishing_fleet <- new(Fleet) fishing_fleet$nages <- om_input$nages fishing_fleet$nyears <- om_input$nyr - fishing_fleet$log_Fmort$resize(om_input$nyr) + fishing_fleet$log_Fmort$resize(om_input$nyr) for(y in 1:om_input$nyr){ fishing_fleet$log_Fmort[y]$value <- log(om_output$f[y]) } @@ -184,12 +184,13 @@ setup_and_run_FIMS_without_wrappers <- function(iter_id, # alternative setting: recruitment$log_devs <- rep(0, length(om_input$logR.resid)) recruitment$log_devs$resize(om_input$nyr-1) for(y in 1:(om_input$nyr-1)){ - recruitment$log_devs[y]$value <- om_input$logR.resid[y] + recruitment$log_devs[y]$value <- om_input$logR.resid[y+1] } recruitment_distribution <- new(TMBDnormDistribution) # set up logR_sd using the normal log_sd parameter # logR_sd is NOT logged. It needs to enter the model logged b/c the exp() is # taken before the likelihood calculation + recruitment_distribution$log_sd <- new(ParameterVector, 1) recruitment_distribution$log_sd[1]$value <- log(om_input$logR_sd) recruitment_distribution$log_sd[1]$estimated <- FALSE recruitment_distribution$x$resize(om_input$nyr - 1) @@ -354,8 +355,8 @@ setup_and_run_FIMS_with_wrappers <- function(iter_id, # Fleet # Create the fishing fleet fishing_fleet_selectivity <- new(LogisticSelectivity) - fishing_fleet_selectivity$inflection_point[1]$value <- 2.0 - fishing_fleet_selectivity$slope[1]$value <- om_input$sel_fleet$fleet1$slope.sel1 + fishing_fleet_selectivity$inflection_point[1]$value <- 2.0 + fishing_fleet_selectivity$slope[1]$value <- om_input$sel_fleet$fleet1$slope.sel1 # fishing_fleet_selectivity$inflection_point[1]$value <- om_input$sel_fleet$fleet1$A50.sel1 # fishing_fleet_selectivity$inflection_point[1]$is_random_effect <- FALSE