Skip to content

Commit

Permalink
Merge pull request #283 from mrc-ide/mrc-4106
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz authored Apr 19, 2023
2 parents 8ad23bf + 705c77f commit fc13c9f
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.4.5
Version: 1.4.6
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
44 changes: 42 additions & 2 deletions R/ir_parse_arrays.R
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,6 @@ ir_parse_arrays_check_rhs <- function(rhs, rank, int_arrays, include, eq,
invisible(NULL) # never return anything at all.
}


ir_parse_expr_lhs_check_index <- function(x) {
seen <- counter()
err <- collector()
Expand Down Expand Up @@ -850,9 +849,50 @@ ir_parse_expr_lhs_check_index <- function(x) {
}

value_max <- f(x, TRUE)
if (seen$get() > 0) { # check minimum branch

# If errors have already been spotted, don't check further;
# results/suggestions will be confusing.

if ((length(err$get()) == 0) && (seen$get() > 0)) { # check minimum branch
seen$reset()
value_min <- f(x, FALSE)
if (!is_call(x, ":")) {

if (is_call(x[[2]], ":")) {

# Handle 1:n+1, which is:-
# `+` (':', 1, n) 1 - so x[[2]][[1]] is ":" and...

lhs <- x[[2]][[2]]
rhs <- x[[2]][[3]]

# Suggest either 1:(n+1) or (1:n)+1

fix <- paste0(deparse_str(call(":", lhs,
call("(", call(as.character(x[[1]]), rhs, x[[3]])))), " or ",
deparse_str(call(as.character(x[[1]]),
call("(", call(":", lhs, rhs)), x[[3]])))

} else if ((length(x) > 2) && (is_call(x[[3]], ":"))) {

# Handle a+1:n, which is:-
# `+` a (`:` 1 n) - so x[[3]][[1]] is ":" and...

lhs <- x[[3]][[2]]
rhs <- x[[3]][[3]]

# Suggest either a+(1:n) or (a+1):n

fix <- paste0(deparse_str(call(as.character(x[[1]]), x[[2]],
call("(", call(":", lhs, rhs)))), " or ",
deparse_str(call(":",
call("(", call(as.character(x[[1]]), x[[2]], lhs)), rhs)))

} else {
fix <- "using parentheses"
}
err$add(sprintf("You are writing an ambiguous range, consider %s", fix))
}
} else {
value_min <- NULL
}
Expand Down
51 changes: 40 additions & 11 deletions tests/testthat/test-parse2-general.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,24 @@ test_that("expression parsing", {
expect_error(odin_parse_(quote(x[i] <- y[i])),
"Special index variable i may not be used on array lhs",
class = "odin_error")
})

expect_error(odin_parse_(quote(y[1:n + 1] <- 1)),
paste0("Invalid array use on lhs:\n",
"\t\tYou are writing an ambiguous range, ",
"consider 1:\\(n \\+ 1) or \\(1:n) \\+ 1*",
collapse = ""), class = "odin_error")

expect_error(odin_parse_(quote(y[1:n + 1 + 2] <- 1)),
paste0("Invalid array use on lhs:\n\t\tYou are ",
"writing an ambiguous range, consider using parentheses*",
collapse = ""), class = "odin_error")

expect_error(odin_parse_(quote(y[a + 1:n] <- 1)),
paste0("Invalid array use on lhs:\n\t\tYou are ",
"writing an ambiguous range, ",
"consider a \\+ \\(1:n) or \\(a \\+ 1):n*",
collapse = ""), class = "odin_error")
})

test_that("parse array indices", {
expect_error(odin_parse(
Expand Down Expand Up @@ -286,21 +302,34 @@ test_that("custom functions ignore arrays", {
})

test_that("lhs array checking", {
res <- ir_parse_expr_lhs_check_index(quote(a + (2:(n - 3) - 4) + z))
expect_true(res)
expect_equal(attr(res, "value_max"), quote(a + ((n - 3) - 4) + z))
expect_equal(attr(res, "value_min"), quote(a + (2 - 4) + z))

res <- ir_parse_expr_lhs_check_index(quote(a))
expect_true(res)
expect_equal(attr(res, "value_max"), quote(a))
expect_null(attr(res, "value_min"))

expect_false(ir_parse_expr_lhs_check_index(quote(a:b + c:d)))
expect_false(ir_parse_expr_lhs_check_index(quote(-(a:b)))) # nolint
expect_false(ir_parse_expr_lhs_check_index(quote((a:b):c)))
expect_false(ir_parse_expr_lhs_check_index(quote(c:(a:b))))
expect_false(ir_parse_expr_lhs_check_index(quote((-a))))
expect_single_error <- function(err, expected) {
expect_false(err)
count <- length(attr(err, "message"))
expect_equal(count, 1)
if (count == 1) {
expect_equal(expected, attr(err, "message"))
}
}

expect_single_error(ir_parse_expr_lhs_check_index(quote(a:b + c:d)),
"Multiple calls to ':' are not allowed")

expect_single_error(ir_parse_expr_lhs_check_index(quote(-(a:b))),
"Unary minus invalid in array calculation") # nolint

expect_single_error(ir_parse_expr_lhs_check_index(quote((a:b):c)),
"Multiple calls to ':' are not allowed")

expect_single_error(ir_parse_expr_lhs_check_index(quote(c:(a:b))),
"Multiple calls to ':' are not allowed")

expect_single_error(ir_parse_expr_lhs_check_index(quote((-a))),
"Unary minus invalid in array calculation")
})

test_that("sum rewriting", {
Expand Down

0 comments on commit fc13c9f

Please sign in to comment.