Skip to content

Commit

Permalink
fix: recruitment log_dev setup
Browse files Browse the repository at this point in the history
  • Loading branch information
Bai-Li-NOAA committed Oct 30, 2024
1 parent 393cda1 commit 7c7c5f7
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 83 deletions.
114 changes: 71 additions & 43 deletions R/better_initialize_modules.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,25 @@ initialize_module <- function(input, data, module_name) {
}

# Define module class and fields
if (module_name == "population") {
module_class_name <- "Population"
} else if (!(module_name %in% names(input[["module_list"]])) &
(names(module_name) == "selectivity")){
module_class_name <-
module_class_name <-
if (module_name == "population") {
"Population"
} else if (!(module_name %in% names(input[["module_list"]])) &
(names(module_name) == "selectivity")) {
input[["module_list"]][["fleets"]][[module_name]][[names(module_name)]][["form"]]
} else if (!(module_name %in% names(input[["module_list"]])) &
names(module_name) == "Fleet"){
module_class_name <- "Fleet"
} else {
module_class_name <- input[["module_list"]][[module_name]]$form
}
} else if (!(module_name %in% names(input[["module_list"]])) &
names(module_name) == "Fleet") {
"Fleet"
} else {
input[["module_list"]][[module_name]]$form
}

module_class <- get(module_class_name)
module_fields <- names(module_class@fields)
module <- new(module_class)
module_input <- input[["parameter_input_list"]][[module_name]]

if (module_class_name == "Fleet"){
if (module_class_name == "Fleet") {
module_fields <- setdiff(module_fields, c(
"log_expected_index",
"proportion_catch_numbers_at_age"
Expand All @@ -59,23 +59,50 @@ initialize_module <- function(input, data, module_name) {
# - Reconsider exposing `log_expected_index` and
# `proportion_catch_numbers_at_age` to users. Their IDs are linked with
# index and agecomp distributions. No input values are required.

non_standard_field <- c(
"ages", "nages", "proportion_female", "estimate_prop_female",
"nyears", "nseasons", "nfleets", "estimate_log_devs", "weights",
"is_survey", "estimate_q", "random_q"
)
for (field in module_fields) {
module[[field]] <- switch(
field,
"ages" = get_data_slot(field, data),
"nages" = get_data_slot(field, data),
"proportion_female" = numeric(0),
"estimate_prop_female" = FALSE,
"nyears" = get_data_slot(field, data),
"nseasons" = 1,
"nfleets" = length(input[["module_list"]][["fleets"]]),
"estimate_log_devs" = module_input[[paste0(module_class_name, ".estimate_log_devs")]],
"weights" = m_weight_at_age(data),
"is_survey" = module_input[[paste0(module_class_name, ".is_survey")]],
"estimate_q" = module_input[[paste0(module_class_name, ".log_q.estimated")]],
"random_q" = FALSE,
set_param_vector(param_vector_name = field, module_input = module_input)
)
if (field %in% non_standard_field) {
# TODO: reorder the list alphabetically
module[[field]] <- switch(field,
"ages" = get_data_slot(field, data),
"nages" = get_data_slot(field, data),
"proportion_female" = numeric(0),
"estimate_prop_female" = FALSE,
"nyears" = get_data_slot(field, data),
"nseasons" = 1,
"nfleets" = length(input[["module_list"]][["fleets"]]),
"estimate_log_devs" = module_input[[paste0(module_class_name, ".estimate_log_devs")]],
"weights" = m_weight_at_age(data),
"is_survey" = module_input[[paste0(module_class_name, ".is_survey")]],
"estimate_q" = module_input[[paste0(module_class_name, ".log_q.estimated")]],
"random_q" = FALSE
)
} else {
field_value_name <- grep(paste0(field, ".value"), names(module_input), value = TRUE)
field_estimated_name <- grep(paste0(field, ".estimated"), names(module_input), value = TRUE)

# Check for the presence of value and estimation information
if (length(field_value_name) == 0 || length(field_estimated_name) == 0) {
cli::cli_abort(c("Missing value or estimation information for field {field}."))
}

# Extract the value of the parameter vector
field_value <- module_input[[field_value_name]]

if (length(field_value) > 1) module[[field]]$resize(length(field_value))

for (i in seq_along(field_value)) {
module[[field]][i][["value"]] <- field_value[i]
}

# Set the estimation information for the parameter vector
module[[field]]$set_all_estimable(module_input[[field_estimated_name]])
}
}

return(module)
Expand All @@ -88,15 +115,16 @@ initialize_module <- function(input, data, module_name) {
#' @export
initialize_distribution <- function(module_input, distribution_name,
distribution_type, linked_ids) {

# Check if distribution_name is provided
if (is.null(distribution_name)) return(NULL)
if (is.null(distribution_name)) {
return(NULL)
}

# Get distribution value and initialize the module
distribution_value <- get(distribution_name)
distribution_module <- new(distribution_value)
distribution_fields <- names(distribution_value@fields)
if (distribution_type == "data"){
if (distribution_type == "data") {
distribution_fields <- setdiff(distribution_fields, c(
"expected_values",
"x",
Expand All @@ -111,18 +139,18 @@ initialize_distribution <- function(module_input, distribution_name,
}

switch(distribution_type,
"data" = {
# Data distribution initialization
distribution_module$set_observed_data(linked_ids["data_link"])
distribution_module$set_distribution_links(distribution_type, linked_ids["fleet_link"])
},
"process" = {
# Process distribution initialization
distribution_module$set_distribution_links("random_effects", linked_ids)
},
{
cli::cli_abort("Unsupported distribution type: {.val {distribution_type}}. Please use 'data' or 'process'.")
}
"data" = {
# Data distribution initialization
distribution_module$set_observed_data(linked_ids["data_link"])
distribution_module$set_distribution_links(distribution_type, linked_ids["fleet_link"])
},
"process" = {
# Process distribution initialization
distribution_module$set_distribution_links("random_effects", linked_ids)
},
{
cli::cli_abort("Unsupported distribution type: {.val {distribution_type}}. Please use 'data' or 'process'.")
}
)

# Final message to confirm success
Expand Down
72 changes: 36 additions & 36 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,39 @@ get_data_slot <- function(field_name, data) {
return(output)
}

#' Set Parameter Vector Values from Module Input.
#' @export
set_param_vector <- function(param_vector_name, module_input) {
# Retrieve the names of the value and estimated slots for the parameter vector
param_vector_value_name <- grep(paste0(param_vector_name, ".value"), names(module_input), value = TRUE)
param_vector_estimated_name <- grep(paste0(param_vector_name, ".estimated"), names(module_input), value = TRUE)

# Check for the presence of value and estimation information
if (length(param_vector_value_name) == 0 || length(param_vector_estimated_name) == 0) {
cli::cli_abort(c("Missing value or estimation information for field {param_vector_name}."))
}

# Extract the value of the parameter vector
param_vector_value <- module_input[[param_vector_value_name]]

# Create a new ParameterVector object
param_vector_module <- methods::new(ParameterVector, param_vector_value, length(param_vector_value))

# Set the estimation information for the parameter vector
param_vector_module$set_all_estimable(module_input[[param_vector_estimated_name]])

return(param_vector_module)
}

#' Get Rcpp Modules
#' @export
get_rcpp_modules <- function(objs){

# Filter objects to find those that inherit from classes starting with "Rcpp_"
rcpp_names <- names(Filter(function(i) {
any(startsWith(class(i), "Rcpp_"))
}, objs))

rcpp_objs <- objs[rcpp_names]

}
#' #' Set Parameter Vector Values from Module Input.
#' #' @export
#' set_param_vector <- function(param_vector_name, module_input) {
#' # Retrieve the names of the value and estimated slots for the parameter vector
#' param_vector_value_name <- grep(paste0(param_vector_name, ".value"), names(module_input), value = TRUE)
#' param_vector_estimated_name <- grep(paste0(param_vector_name, ".estimated"), names(module_input), value = TRUE)
#'
#' # Check for the presence of value and estimation information
#' if (length(param_vector_value_name) == 0 || length(param_vector_estimated_name) == 0) {
#' cli::cli_abort(c("Missing value or estimation information for field {param_vector_name}."))
#' }
#'
#' # Extract the value of the parameter vector
#' param_vector_value <- module_input[[param_vector_value_name]]
#'
#' # Create a new ParameterVector object
#' # param_vector_module <- methods::new(ParameterVector, param_vector_value, length(param_vector_value))
#'
#' # Set the estimation information for the parameter vector
#' param_vector_module$set_all_estimable(module_input[[param_vector_estimated_name]])
#'
#' return(param_vector_module)
#' }

#' #' Get Rcpp Modules
#' #' @export
#' get_rcpp_modules <- function(objs){
#'
#' # Filter objects to find those that inherit from classes starting with "Rcpp_"
#' rcpp_names <- names(Filter(function(i) {
#' any(startsWith(class(i), "Rcpp_"))
#' }, objs))
#'
#' rcpp_objs <- objs[rcpp_names]
#'
#' }
9 changes: 5 additions & 4 deletions tests/testthat/helper-integration-tests-setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ setup_and_run_FIMS_without_wrappers <- function(iter_id,
fishing_fleet <- new(Fleet)
fishing_fleet$nages <- om_input$nages
fishing_fleet$nyears <- om_input$nyr
fishing_fleet$log_Fmort$resize(om_input$nyr)
fishing_fleet$log_Fmort$resize(om_input$nyr)
for(y in 1:om_input$nyr){
fishing_fleet$log_Fmort[y]$value <- log(om_output$f[y])
}
Expand Down Expand Up @@ -184,12 +184,13 @@ setup_and_run_FIMS_without_wrappers <- function(iter_id,
# alternative setting: recruitment$log_devs <- rep(0, length(om_input$logR.resid))
recruitment$log_devs$resize(om_input$nyr-1)
for(y in 1:(om_input$nyr-1)){
recruitment$log_devs[y]$value <- om_input$logR.resid[y]
recruitment$log_devs[y]$value <- om_input$logR.resid[y+1]
}
recruitment_distribution <- new(TMBDnormDistribution)
# set up logR_sd using the normal log_sd parameter
# logR_sd is NOT logged. It needs to enter the model logged b/c the exp() is
# taken before the likelihood calculation
recruitment_distribution$log_sd <- new(ParameterVector, 1)
recruitment_distribution$log_sd[1]$value <- log(om_input$logR_sd)
recruitment_distribution$log_sd[1]$estimated <- FALSE
recruitment_distribution$x$resize(om_input$nyr - 1)
Expand Down Expand Up @@ -354,8 +355,8 @@ setup_and_run_FIMS_with_wrappers <- function(iter_id,
# Fleet
# Create the fishing fleet
fishing_fleet_selectivity <- new(LogisticSelectivity)
fishing_fleet_selectivity$inflection_point[1]$value <- 2.0
fishing_fleet_selectivity$slope[1]$value <- om_input$sel_fleet$fleet1$slope.sel1
fishing_fleet_selectivity$inflection_point[1]$value <- 2.0
fishing_fleet_selectivity$slope[1]$value <- om_input$sel_fleet$fleet1$slope.sel1

# fishing_fleet_selectivity$inflection_point[1]$value <- om_input$sel_fleet$fleet1$A50.sel1
# fishing_fleet_selectivity$inflection_point[1]$is_random_effect <- FALSE
Expand Down

0 comments on commit 7c7c5f7

Please sign in to comment.