Skip to content

Commit

Permalink
Improve denoising AE performance
Browse files Browse the repository at this point in the history
  • Loading branch information
fdavidcl committed Feb 18, 2019
1 parent 945e50a commit cceb17c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
42 changes: 29 additions & 13 deletions R/autoencoder.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,24 +222,40 @@ train.ruta_autoencoder <- function(
metrics = metrics
)

input_data <- if (is.null(learner$filter)) {
data
} else {
apply_filter(learner$filter, data)
}
# input_data <- if (is.null(learner$filter)) {
# data
# } else {
# apply_filter(learner$filter, data)
# }

if (!is.null(validation_data)) {
validation_data <- list(validation_data, validation_data)
}

keras::fit(
learner$models$autoencoder,
x = input_data,
y = data,
epochs = epochs,
validation_data = validation_data,
...
)
if (is.null(learner$filter)) {
keras::fit(
learner$models$autoencoder,
x = data,
y = data,
epochs = epochs,
validation_data = validation_data,
...
)
} else {
batch_size <- list(...)$batch_size
if (is.null(batch_size)) batch_size <- 32

keras::fit_generator(
learner$models$autoencoder,
to_keras(learner$filter, data, batch_size),
steps_per_epoch = ceiling(dim(data)[1] / batch_size),
epochs = epochs,
validation_data = validation_data,
...
)
}



invisible(learner)
}
Expand Down
22 changes: 22 additions & 0 deletions R/filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,25 @@ apply_filter.ruta_noise_cauchy <- function(filter, data, ...) {

data + term
}

#' @import R.utils
#' @param data
#' @param batch_size
to_keras.ruta_filter <- function(x, data, batch_size, ...) {
limit <- dim(data)[1]
order <- sample.int(limit)
start <- 1
function() {
if (start + batch_size > limit) {
idx <- order[start:limit]
order <- sample.int(limit)
start <- 1
} else {
idx <- order[start:(start + batch_size - 1)]
start <- start + batch_size
}
original <- R.utils::extract(data, "1" = idx)
noisy <- apply_filter(x, original)
list(noisy, original)
}
}

0 comments on commit cceb17c

Please sign in to comment.