Skip to content

Commit

Permalink
draft ggplot based sankey plot (#367)
Browse files Browse the repository at this point in the history
* draft ggplot based sankey plot

* remove unused code and make new sankey more robust

* prettify sankey plot

* improve colors

* replace old sankey figure with ggalluvial version

* clearer names for sankey columns

* update documentation

* rm dead code

* update image of README page
  • Loading branch information
jacobvjk authored Jan 8, 2025
1 parent d1b3a22 commit c3689e1
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 249 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ Imports:
cli (>= 3.2.0),
config,
dplyr,
ggalluvial,
ggplot2,
ggrepel,
glue,
networkD3,
r2dii.analysis (>= 0.3.0),
r2dii.data (>= 0.5.0),
r2dii.match (>= 0.3.0),
Expand All @@ -52,7 +53,6 @@ Imports:
rlang,
scales,
tidyr,
webshot,
yaml,
yesno
Depends:
Expand Down
55 changes: 16 additions & 39 deletions R/plot_aggregate_loanbooks.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,50 +171,27 @@ plot_aggregate_loanbooks <- function(config) {
na = ""
)

plot_sankey(
data_sankey_sector,
group_var = by_group,
save_png_to = path.expand(analysis_aggregated_dir),
png_name = glue::glue("plot_{output_file_sankey_sector}.png"),
nodes_order_from_data = TRUE
p_sankey <- plot_sankey(
data = data_sankey_sector,
y_axis = "loan_size_outstanding",
initial_node = by_group,
middle_node = "sector",
end_node = "is_aligned",
stratum = "is_aligned"
)
} else {
cli::cli_warn("Sankey plot (by sector) cannot be generated. Skipping!")
}

if (!is.null(company_aggregated_alignment_net)) {
data_sankey_company_sector <- prep_sankey(
company_aggregated_alignment_net,
region = "global",
year = start_year + time_frame_select,
group_var = by_group,
middle_node = "name_abcd",
middle_node2 = "sector"
ggplot2::ggsave(
plot = p_sankey,
filename = glue::glue("plot_{output_file_sankey_sector}.png"),
path = analysis_aggregated_dir,
width = 8,
height = 5,
dpi = 300,
units = "in",
)

if (is.null(by_group)) {
output_file_sankey_company_sector <- "sankey_company_sector"
} else {
output_file_sankey_company_sector <- glue::glue("sankey_company_sector_{by_group}")
}

data_sankey_company_sector %>%
readr::write_csv(
file = file.path(
analysis_aggregated_dir,
glue::glue("data_{output_file_sankey_company_sector}.csv")
),
na = ""
)

plot_sankey(
data_sankey_company_sector,
group_var = by_group,
save_png_to = path.expand(analysis_aggregated_dir),
png_name = glue::glue("plot_{output_file_sankey_company_sector}.png")
)
} else {
cli::cli_warn("Sankey plot (by sector and company) cannot be generated. Skipping!")
cli::cli_warn("Sankey plot (by sector) cannot be generated. Skipping!")
}

### scatter plot alignment by exposure and sector comparison----
Expand Down
211 changes: 52 additions & 159 deletions R/plot_sankey.R
Original file line number Diff line number Diff line change
@@ -1,176 +1,69 @@
#' Make a sankey plot
#'
#' @param data data.frame. Should have the same format as output of
#' `prep_sankey()` and contain columns: `"middle_node"`, optionally
#' `"middle_node2"`, `"is_aligned"`, `"loan_size_outstanding"`, and any column
#' implied by `group_var`.
#' @param group_var Character. Vector of length 1. Variable to group by.
#' @param capitalise_node_labels Logical. Flag indicating if node labels should
#' be converted into better looking capitalised form.
#' @param save_png_to Character. Path where the output in png format should be
#' saved
#' @param png_name Character. File name of the output.
#' @param nodes_order_from_data Logical. Flag indicating if nodes order should
#' be determined by an algorithm (in case of big datasets often results in a
#' better looking plot) or should they be ordered based on data.
#' `prep_sankey()` and contain columns: `"y_axis"`, `"initial_node"`,
#' `"middle_node"`, `"end_node"`, `"stratum"`, `"currency"`.
#' @param y_axis Character. Vector of length 1. Variable to determine the
#' vertical size of the ribbons, e.g. `"loan_size_outstanding"`.
#' @param initial_node Character. Vector of length 1. Variable to determine the
#' initial node of the sankey chart. Usually, this will be the groups by which
#' the loan books are aggregated.
#' @param middle_node Character. Vector of length 1. Variable to determine the
#' middle node of the sankey chart. Usually, this will be the PACTA sectors.
#' @param end_node Character. Vector of length 1. Variable to determine the
#' end node of the sankey chart. Usually, this will be a binary indicator of
#' alignment.
#' @param stratum Character. Vector of length 1. Variable to determine the
#' grouping and fill of the ribbons of the sankey chart. Usually, this will be
#' a binary indicator of alignment.
#'
#' @return NULL
#'
#' @noRd

plot_sankey <- function(data,
group_var,
capitalise_node_labels = TRUE,
save_png_to = NULL,
png_name = "sankey.png",
nodes_order_from_data = FALSE) {
if (!is.null(group_var)) {
if (!inherits(group_var, "character")) {
cli::cli_abort("{.arg group_var} must be of class {.cls character}")
}
if (!length(group_var) == 1) {
cli::cli_abort("{.arg group_var} must be of length 1")
}
} else {
data <- data %>%
dplyr::mutate(aggregate_loan_book = "Aggregate loan book")
group_var <- "aggregate_loan_book"
y_axis = "loan_size_outstanding",
initial_node,
middle_node = "sector",
end_node = "is_aligned",
stratum = "is_aligned") {
# since the initial node is the loan book aggregation, NULL grouping corresponds to the aggregate loan book
if (is.null(initial_node)) {
initial_node <- "aggregate_loan_book"
}

check_plot_sankey(
data = data,
group_var = group_var,
capitalise_node_labels = capitalise_node_labels
)

if (capitalise_node_labels) {
data_links <- data %>%
dplyr::mutate(
group_var = r2dii.plot::to_title(!!rlang::sym(group_var)),
middle_node = r2dii.plot::to_title(.data[["middle_node"]])
)
if ("middle_node2" %in% names(data_links)) {
data_links <- data_links %>%
dplyr::mutate(
middle_node2 = r2dii.plot::to_title(.data[["middle_node2"]])
)
}
} else {
data_links <- data
}
currency <- unique(data[["currency"]])

links_1 <- data_links %>%
dplyr::select(
source = .env[["group_var"]],
target = "middle_node",
value = "loan_size_outstanding",
group = "is_aligned"
p <- ggplot2::ggplot(
data = data,
ggplot2::aes(
axis1 = .data[["initial_node"]],
axis2 = .data[["middle_node"]],
axis3 = .data[["end_node"]],
y = .data[["loan_size_outstanding"]]
)

if ("middle_node2" %in% names(data_links)) {
links_2 <- data_links %>%
dplyr::select(
.env[["group_var"]],
source = "middle_node",
target = "middle_node2",
value = "loan_size_outstanding",
group = "is_aligned"
)

links_3 <- data_links %>%
dplyr::select(
.env[["group_var"]],
source = "middle_node2",
target = "is_aligned",
value = "loan_size_outstanding",
group = "is_aligned"
)

links <- dplyr::bind_rows(links_1, links_2, links_3)
} else {
links_2 <- data_links %>%
dplyr::select(
.env[["group_var"]],
source = "middle_node",
target = "is_aligned",
value = "loan_size_outstanding",
group = "is_aligned"
)

links <- dplyr::bind_rows(links_1, links_2)
}

links <- links %>%
dplyr::group_by(.data[["source"]], .data[["target"]], .data[["group"]]) %>%
dplyr::summarise(value = sum(.data[["value"]], na.rm = TRUE)) %>%
dplyr::ungroup() %>%
dplyr::arrange(.data[["source"]], .data[["group"]]) %>%
as.data.frame()

nodes <- data.frame(
name = unique(c(as.character(links$source), as.character(links$target)))
) %>%
dplyr::mutate(
group = dplyr::case_when(
.data[["name"]] %in% c("Aligned", "Not aligned", "Unknown") ~ .data[["name"]],
TRUE ~ "other"
)
) +
ggplot2::scale_y_continuous(labels = scales::comma) +
ggplot2::ylab(glue::glue("Financial exposure (in {currency})")) +
ggalluvial::geom_alluvium(ggplot2::aes(fill = .data[["is_aligned"]])) +
ggplot2::scale_fill_manual(
values = c("Aligned" = "green4", "Not aligned" = "red3", "Unknown" = "gray30")
) +
ggalluvial::geom_stratum(fill = "gray90", color = "gray50") +
ggrepel::geom_text_repel(
ggplot2::aes(label = ggplot2::after_stat(stratum)),
stat = ggalluvial::StatStratum, size = 4, direction = "y", nudge_x = .3
) +
r2dii.plot::theme_2dii() +
ggplot2::theme(
axis.title.x = ggplot2::element_blank(),
axis.text.x = ggplot2::element_blank(),
axis.ticks.x = ggplot2::element_blank()
) +
ggplot2::ggtitle(
"Sankey chart of counterparty alignment by financial exposure",
paste0("stratified by counterpaty alignment and ", middle_node)
)

my_color <- 'd3.scaleOrdinal() .domain(["Not aligned", "Aligned", "Unknown", "other"]) .range(["#e10000","#3d8c40", "#808080", "#808080"])'

links$IDsource <- match(links$source, nodes$name) - 1
links$IDtarget <- match(links$target, nodes$name) - 1

if (nodes_order_from_data) {
n_iter <- 0
} else {
n_iter <- 32 # sankeyNetwork() default
}

p <- networkD3::sankeyNetwork(
Links = links,
Nodes = nodes,
Source = "IDsource",
Target = "IDtarget",
Value = "value",
NodeID = "name",
colourScale = my_color,
LinkGroup = "group",
NodeGroup = "group",
fontSize = 14,
iterations = n_iter
)

if (!is.null(save_png_to)) {
# you save it as an html
temp_html <- tempfile(fileext = ".html")
networkD3::saveNetwork(p, temp_html)

if (webshot::is_phantomjs_installed()) {
file_name <- file.path(save_png_to, png_name)
# you convert it as png
webshot::webshot(temp_html, path.expand(file_name), vwidth = 1000, vheight = 900)
} else {
cli::cli_warn(
"In order to save the plot as PNG, you need to have {.pkg phantomjs}
installed. Please run {.run webshot::install_phantomjs()} if you don't
and try running the function again."
)
}
}
p
}

check_plot_sankey <- function(data,
group_var,
capitalise_node_labels) {
crucial_names <- c(group_var, "middle_node", "is_aligned", "loan_size_outstanding")
assert_no_missing_names(data, crucial_names)
if (!is.logical(capitalise_node_labels)) {
cli::cli_abort(c(
x = "`capitalise_node_labels` must have a {.cls logical} value.",
i = "capitalise_node_labels` contains the value{?s}: {.val {capitalise_node_labels}}."
))
}
}
Loading

0 comments on commit c3689e1

Please sign in to comment.