Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R/c-wrapper.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ make_c_bridge <- function(fsub, strict = TRUE, headers = TRUE) {

closure <- fsub@closure
scope <- fsub@scope
uses_rng <- isTRUE(attr(scope, "uses_rng", TRUE))
uses_errors <- isTRUE(attr(scope, "uses_errors", TRUE))
uses_rng <- scope_uses_rng(scope)
uses_errors <- scope_uses_errors_flag(scope)

fsub_arg_names <- fsub@signature # arg names
closure_arg_names <- names(formals(closure)) %||% character()
Expand Down
21 changes: 21 additions & 0 deletions R/classes.R
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,12 @@ Fortran := new_class(
properties = list(
value = NULL | Variable,

# Metadata flags used during compilation/lowering. Keep these as explicit
# properties rather than ad-hoc attributes so they are discoverable and
# consistently propagated with the Fortran object.
logical_booleanized = prop_bool(default = FALSE),
writes_to_dest = prop_bool(default = FALSE),

r = new_property(
# custom setter only to workaround https://github.com/RConsortium/S7/issues/511
NULL | class_language | class_atomic,
Expand Down Expand Up @@ -411,6 +417,21 @@ FortranSubroutine := new_class(
)
)

R2FHandler := new_class(
class_function,
properties = list(
dest_supported = prop_bool(default = FALSE),
dest_infer = new_property(NULL | class_function),
dest_infer_name = prop_string(default = NULL, allow_null = TRUE),
# When NULL, r2f will resolve the callable by name and use match.call().
# When FALSE, r2f will not attempt match.call().
match_fun = new_property(
NULL | class_function | class_logical,
default = NULL
)
)
)

try_prop <- function(object, name) S7::prop(object, name) %error% NULL

emit <- function(..., sep = "", end = "\n") cat(..., end, sep = sep)
Expand Down
9 changes: 6 additions & 3 deletions R/error-handling.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ scope_root_for_errors <- function(scope) {
return(scope)
}
while (
!identical(attr(scope, "kind", exact = TRUE), "subroutine") &&
!identical(scope_kind(scope), "subroutine") &&
inherits(parent.env(scope), "quickr_scope")
) {
scope <- parent.env(scope)
Expand All @@ -28,14 +28,17 @@ scope_root_for_errors <- function(scope) {
mark_scope_uses_errors <- function(scope) {
root <- scope_root_for_errors(scope)
if (inherits(root, "quickr_scope")) {
attr(root, "uses_errors") <- TRUE
scope_mark_uses_errors_flag(root)
}
invisible(TRUE)
}

scope_uses_errors <- function(scope) {
root <- scope_root_for_errors(scope)
isTRUE(attr(root, "uses_errors", TRUE))
if (!inherits(root, "quickr_scope")) {
return(FALSE)
}
scope_uses_errors_flag(root)
}

fortran_string_literal <- function(x) {
Expand Down
14 changes: 8 additions & 6 deletions R/manifest.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ block_tmp_allocatable <- function(
max_stack_elements = block_tmp_allocatable_threshold
) {
stopifnot(inherits(var, Variable))
if (!inherits(scope, "quickr_scope") || !identical(scope@kind, "block")) {
if (
!inherits(scope, "quickr_scope") || !identical(scope_kind(scope), "block")
) {
return(FALSE)
}
if (passes_as_scalar(var) || is.null(var@dims)) {
Expand Down Expand Up @@ -347,7 +349,7 @@ r2f.scope <- function(scope, include_errors = FALSE) {
local_allocs <- character()
vars <- lapply(vars, function(var) {
r_name <- var@r_name %||% var@name
intent_in <- r_name %in% names(formals(scope@closure))
intent_in <- r_name %in% names(formals(scope_closure(scope)))
intent_out <-
(r_name %in% return_var_names) ||
(intent_in && var@modified)
Expand Down Expand Up @@ -428,7 +430,7 @@ r2f.scope <- function(scope, include_errors = FALSE) {

# vars that will be visible in the C bridge, either as an input or output
non_local_var_names <- unique(c(
names(formals(scope@closure)),
names(formals(scope_closure(scope))),
return_var_names
))

Expand All @@ -437,9 +439,9 @@ r2f.scope <- function(scope, include_errors = FALSE) {
var <- scope[[name]]
lapply(var@dims, all.names, functions = FALSE, unique = TRUE)
}))) |>
setdiff(names(formals(scope@closure)))
if (length(names(formals(scope@closure)))) {
formal_vars <- mget(names(formals(scope@closure)), scope)
setdiff(names(formals(scope_closure(scope))))
if (length(names(formals(scope_closure(scope))))) {
formal_vars <- mget(names(formals(scope_closure(scope))), scope)
formal_fortran_names <- unique(map_chr(formal_vars, \(var) {
var@name %||% ""
}))
Expand Down
20 changes: 10 additions & 10 deletions R/parallel.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,38 @@ get_pending_parallel <- function(scope) {
if (is.null(scope) || !inherits(scope, "quickr_scope")) {
return(NULL)
}
scope@pending_parallel
scope_get(scope, "pending_parallel")
}

has_pending_parallel <- function(scope) !is.null(get_pending_parallel(scope))

set_pending_parallel <- function(scope, decl) {
stopifnot(inherits(scope, "quickr_scope"), is.list(decl))
scope@pending_parallel <- decl
scope_set(scope, "pending_parallel", decl)
invisible(scope)
}

take_pending_parallel <- function(scope) {
if (is.null(scope) || !inherits(scope, "quickr_scope")) {
return(NULL)
}
decl <- scope@pending_parallel
scope@pending_parallel <- NULL
decl <- scope_get(scope, "pending_parallel")
scope_set(scope, "pending_parallel", NULL)
decl
}

mark_openmp_used <- function(scope) {
stopifnot(inherits(scope, "quickr_scope"))
root <- scope_root(scope)
attr(root, "uses_openmp") <- TRUE
scope_mark_uses_openmp_flag(root)
invisible(root)
}

scope_openmp_depth <- function(scope) {
if (!inherits(scope, "quickr_scope")) {
return(0L)
}
depth <- attr(scope, "openmp_depth", exact = TRUE)
depth <- scope_get(scope, "openmp_depth")
if (is.null(depth)) {
0L
} else {
Expand All @@ -51,9 +51,9 @@ enter_openmp_scope <- function(scope) {
if (!inherits(scope, "quickr_scope")) {
return(NULL)
}
previous_depth <- attr(scope, "openmp_depth", exact = TRUE)
previous_depth <- scope_get(scope, "openmp_depth")
depth <- scope_openmp_depth(scope)
attr(scope, "openmp_depth") <- depth + 1L
scope_set(scope, "openmp_depth", depth + 1L)
previous_depth
}

Expand All @@ -62,9 +62,9 @@ exit_openmp_scope <- function(scope, previous_depth) {
return(invisible(NULL))
}
if (is.null(previous_depth)) {
attr(scope, "openmp_depth") <- NULL
scope_set(scope, "openmp_depth", NULL)
} else {
attr(scope, "openmp_depth") <- as.integer(previous_depth)
scope_set(scope, "openmp_depth", as.integer(previous_depth))
}
invisible(TRUE)
}
Expand Down
2 changes: 1 addition & 1 deletion R/quick.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ compile <- function(fsub, build_dir = tempfile(paste0(fsub@name, "-build-"))) {
FLIBS <- FLIBS[nzchar(FLIBS)]
link_flags <- c(LAPACK_LIBS, BLAS_LIBS, FLIBS)

use_openmp <- isTRUE(attr(fsub@scope, "uses_openmp", exact = TRUE))
use_openmp <- scope_uses_openmp_flag(fsub@scope)
suppressWarnings({
env <- quickr_fcompiler_env(
build_dir = build_dir,
Expand Down
29 changes: 22 additions & 7 deletions R/r2f-aaa-registry.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,40 @@ register_r2f_handler <- function(
dest_infer = NULL,
match_fun = TRUE
) {
stopifnot(is.function(fun))

handler <- if (inherits(fun, R2FHandler)) fun else R2FHandler(fun)

if (!is.null(dest_supported)) {
attr(fun, "dest_supported") <- dest_supported
handler@dest_supported <- isTRUE(dest_supported)
}

if (!is.null(dest_infer)) {
attr(fun, "dest_infer") <- dest_infer
handler@dest_infer <- dest_infer
# covr rewrites function bindings in the namespace; resolving by name at call
# time ensures instrumented/rebound functions are respected. We keep the
# function object for robustness (e.g., anonymous functions) and additionally
# store the name when `dest_infer` is passed as a symbol.
dest_infer_expr <- substitute(dest_infer)
if (is.symbol(dest_infer_expr)) {
attr(fun, "dest_infer_name") <- as.character(dest_infer_expr)
handler@dest_infer_name <- as.character(dest_infer_expr)
} else {
handler@dest_infer_name <- NULL
}
}
if (!is.null(match_fun) && !isTRUE(match_fun)) {
attr(fun, "match.fun") <- match_fun

if (isTRUE(match_fun)) {
handler@match_fun <- NULL
} else if (
is.null(match_fun) || isFALSE(match_fun) || is.function(match_fun)
) {
handler@match_fun <- match_fun
} else {
stop("match_fun must be TRUE, FALSE, NULL, or a function")
}

for (nm in name) {
r2f_handlers[[nm]] <- fun
r2f_handlers[[nm]] <- handler
}
invisible(fun)
invisible(handler)
}
32 changes: 24 additions & 8 deletions R/r2f-aab-core.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ new_hoist <- function(scope) {

ensure_block_scope <- function() {
if (is.null(block_scope)) {
block_scope <<- scope@new_child("block")
block_scope <<- scope_new_child(scope, "block")
}
block_scope
}
Expand Down Expand Up @@ -81,7 +81,7 @@ logical_as_int_symbol <- function(var) {
}

scope_is_closure <- function(scope) {
inherits(scope, "quickr_scope") && identical(scope@kind, "closure")
inherits(scope, "quickr_scope") && identical(scope_kind(scope), "closure")
}

scope_fortran_names <- function(scope) {
Expand Down Expand Up @@ -154,7 +154,11 @@ lang2fortran <- r2f <- function(
{
handler <- get_r2f_handler(callable_unwrapped)

match.fun <- attr(handler, "match.fun", TRUE)
match.fun <- if (inherits(handler, R2FHandler)) {
handler@match_fun
} else {
attr(handler, "match.fun", TRUE)
}
if (is.null(match.fun)) {
match.fun <- get0(
callable_unwrapped,
Expand Down Expand Up @@ -232,7 +236,7 @@ lang2fortran <- r2f <- function(
val <- NULL
}
if (is.null(val) && inherits(scope, "quickr_scope")) {
closure <- scope@closure
closure <- scope_closure(scope)
arg_names <- if (is.null(closure)) NULL else names(formals(closure))
if (!is.null(arg_names) && r_name %in% arg_names) {
stop(
Expand All @@ -255,7 +259,7 @@ lang2fortran <- r2f <- function(
# and must be "booleanized" for Fortran logical operations.
s <- paste0("(", s, "/=0)")
out <- Fortran(s, value = if (inherits(val, Variable)) val else NULL)
attr(out, "logical_booleanized") <- TRUE
out@logical_booleanized <- TRUE
out
} else {
Fortran(s, value = if (inherits(val, Variable)) val else NULL)
Expand Down Expand Up @@ -375,7 +379,11 @@ dest_supported_for_call <- function(call) {
return(FALSE)
}
handler <- get0(as.character(unwrapped[[1L]]), r2f_handlers, inherits = FALSE)
isTRUE(attr(handler, "dest_supported", exact = TRUE))
if (inherits(handler, R2FHandler)) {
isTRUE(handler@dest_supported)
} else {
isTRUE(attr(handler, "dest_supported", exact = TRUE))
}
}

dest_infer_for_call <- function(call, scope) {
Expand All @@ -390,8 +398,16 @@ dest_infer_for_call <- function(call, scope) {
return(NULL)
}
handler <- get0(as.character(unwrapped[[1L]]), r2f_handlers, inherits = FALSE)
infer <- attr(handler, "dest_infer", exact = TRUE)
infer_name <- attr(handler, "dest_infer_name", exact = TRUE)
infer <- if (inherits(handler, R2FHandler)) {
handler@dest_infer
} else {
attr(handler, "dest_infer", exact = TRUE)
}
infer_name <- if (inherits(handler, R2FHandler)) {
handler@dest_infer_name
} else {
attr(handler, "dest_infer_name", exact = TRUE)
}

infer_fun <- NULL
if (is_string(infer_name)) {
Expand Down
Loading
Loading