Saving and Reloading Fitted Kerasnip Workflows

Overview

Keras models are backed by Python objects managed by TensorFlow/JAX. These objects live in the current R session and are represented as external pointers (externalptr) that become invalid as soon as the session ends, or even within the same session after saveRDS() / readRDS().
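The failure mode is easy to see with a bare keras3 model (a sketch; it requires a working keras3/Python installation, and the check relies on reticulate::py_validate_xptr(), the same call kerasnip uses internally):

```r
library(keras3)

model <- keras_model_sequential(input_shape = 4) |>
  layer_dense(units = 1)

path <- tempfile(fileext = ".rds")
saveRDS(model, path)
reloaded <- readRDS(path)

# The reloaded object is just an external pointer; if the underlying
# Python object did not survive serialisation, validation fails.
ok <- tryCatch(
  { reticulate::py_validate_xptr(reloaded); TRUE },
  error = function(e) FALSE
)
ok  # FALSE when the pointer is stale
```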

kerasnip handles this transparently so that fitted workflows can be saved, reloaded, and used for prediction without any manual restoration steps.

Quick workflow example

Before discussing the details, here is the full persistence workflow:

library(kerasnip)
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────────────────────────────────── tidymodels 1.5.0 ──
#> ✔ broom        1.0.12     ✔ recipes      1.3.2 
#> ✔ dials        1.4.3      ✔ rsample      1.3.2 
#> ✔ dplyr        1.2.1      ✔ tailor       0.1.0 
#> ✔ ggplot2      4.0.3      ✔ tidyr        1.3.2 
#> ✔ infer        1.1.0      ✔ tune         2.1.0 
#> ✔ modeldata    1.5.1      ✔ workflows    1.3.0 
#> ✔ parsnip      1.5.0      ✔ workflowsets 1.1.1 
#> ✔ purrr        1.2.2      ✔ yardstick    1.4.0
#> ── Conflicts ───────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter()  masks stats::filter()
#> ✖ dplyr::lag()     masks stats::lag()
#> ✖ recipes::step()  masks stats::step()
library(keras3)
#> 
#> Attaching package: 'keras3'
#> The following object is masked from 'package:yardstick':
#> 
#>     get_weights
#> The following object is masked from 'package:infer':
#> 
#>     generate

# 1. Define Layer Blocks (Required by kerasnip)
# The first block must initialize the sequential model
input_block <- function(model, input_shape) {
  keras_model_sequential(input_shape = input_shape)
}

# Hidden layer block
dense_block <- function(model, units = 32) {
  model |> layer_dense(units = units, activation = "relu")
}

# Output layer block (num_classes is part of the block signature but
# unused here; a single unit for regression)
output_block <- function(model, num_classes) {
  model |> layer_dense(units = 1)
}

# 2. Generate the parsnip specification
create_keras_sequential_spec(
  model_name = "my_mlp",
  layer_blocks = list(
    input = input_block,
    hidden = dense_block,
    output = output_block
  ),
  mode = "regression"
)

# 3. Use the newly created 'my_mlp' function
mod_spec <- my_mlp(fit_epochs = 10) |> 
  set_engine("keras")

# 4. Standard tidymodels workflow
rec_spec <- recipe(mpg ~ ., data = mtcars) |> 
  step_normalize(all_predictors())

fit_wf <- workflow() |> 
  add_recipe(rec_spec) |> 
  add_model(mod_spec) |> 
  fit(data = mtcars)

# Predict
new_data <- mtcars[1:3, ]
predict(fit_wf, new_data)
#> 1/1 - 0s - 116ms/step
#> # A tibble: 3 × 1
#>   .pred
#>   <dbl>
#> 1 1.00 
#> 2 0.956
#> 3 3.03

The first call to predict() detects that the Python pointer is invalid and restores the model from the stored bytes automatically.

What kerasnip does behind the scenes

kerasnip handles persistence automatically. At fit time, the fitted Keras model is serialised to a raw vector of bytes and stored inside the fitted object; at predict time, kerasnip checks whether the Python pointer is still valid and, if not, rebuilds the model from those bytes.

This means you can use whichever persistence strategy best suits your workflow, without any extra boilerplate.

Strategy 1: Plain saveRDS() / readRDS()

For most use cases (sharing a model file with a colleague, caching a fit between R sessions, or checkpointing during development), plain RDS is the simplest approach.

library(kerasnip)
library(workflows)
library(parsnip)
library(recipes)

# --- Save ---
saveRDS(fit_wf, "my_model.rds")

# --- Reload in the same or a new R session ---
library(kerasnip)
fit_wf <- readRDS("my_model.rds")

# predict() restores the Keras model from bytes automatically
predictions <- predict(fit_wf, new_data = new_data)
#> 1/1 - 0s - 118ms/step
predictions
#> # A tibble: 3 × 1
#>   .pred
#>   <dbl>
#> 1 1.00 
#> 2 0.956
#> 3 3.03

There is nothing special to do after readRDS(). The first call to predict() detects the invalid pointer, restores the model from the stored bytes, and then proceeds normally.
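One way to sanity-check the round trip is to compare predictions before and after reloading. A sketch using a temporary file, assuming the fitted workflow `fit_wf` and `new_data` from the example above:

```r
orig_preds <- predict(fit_wf, new_data = new_data)

path <- tempfile(fileext = ".rds")
saveRDS(fit_wf, path)
reloaded_wf <- readRDS(path)

# The restored model should reproduce the original predictions
stopifnot(isTRUE(all.equal(orig_preds, predict(reloaded_wf, new_data = new_data))))
```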

Strategy 2: bundle / unbundle

The bundle package provides a standardized serialization interface used by vetiver, plumber, and other MLOps tools. It is the right choice when you plan to deploy the model with vetiver, serve it behind an API, or move it between machines and Docker containers.

library(kerasnip)
library(bundle)
library(workflows)

# --- Save ---
bundled <- bundle(fit_wf)
saveRDS(bundled, "my_model_bundle.rds")

# --- Reload in any R session ---
library(kerasnip)
library(bundle)
bundled <- readRDS("my_model_bundle.rds")
fit_wf <- unbundle(bundled)
predictions <- predict(fit_wf, new_data = new_data)
#> 1/1 - 0s - 110ms/step
predictions
#> # A tibble: 3 × 1
#>   .pred
#>   <dbl>
#> 1 1.00 
#> 2 0.956
#> 3 3.03
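Because bundling is the mechanism vetiver uses internally, a fitted kerasnip workflow can be handed straight to it. A sketch, assuming the vetiver and pins packages are installed (the model name and temporary board here are illustrative):

```r
library(vetiver)
library(pins)

# vetiver_model() bundles the fitted workflow internally
v <- vetiver_model(fit_wf, model_name = "my_mlp")

# Pin it to a (temporary) board for versioned storage
board <- board_temp()
vetiver_pin_write(board, v)
```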

Comparison

                              saveRDS / readRDS      bundle / unbundle
Works across sessions         ✅                     ✅
Works across machines         ✅ (same R library)    ✅
vetiver / Docker compatible   ❌                     ✅
Extra dependency needed       none                   bundle package
Code complexity               minimal                minimal

What happens under the hood

When kerasnip fits a model, the generic fit function calls keras_model_to_bytes(), which writes the model to a temporary .keras file using keras3::save_model() and reads the bytes back into R:

# Simplified version of what happens inside generic_sequential_fit()
keras_bytes <- keras_model_to_bytes(model)
# keras_bytes is a raw vector stored in object$fit$keras_bytes
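A standalone version of this round trip, under the assumption that keras3::save_model() and keras3::load_model() are used as described, might look like:

```r
keras_model_to_bytes <- function(model) {
  tmp <- tempfile(fileext = ".keras")
  on.exit(unlink(tmp), add = TRUE)
  keras3::save_model(model, tmp)
  # Read the .keras archive back into R as a raw vector
  readBin(tmp, what = "raw", n = file.size(tmp))
}

keras_model_from_bytes <- function(bytes) {
  tmp <- tempfile(fileext = ".keras")
  on.exit(unlink(tmp), add = TRUE)
  writeBin(bytes, tmp)
  keras3::load_model(tmp)
}
```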

When predict() is called on a reloaded object, predict.kerasnip_model_fit() runs:

# Simplified version of predict.kerasnip_model_fit()
if (!is.null(object$fit$keras_bytes)) {
  is_valid <- tryCatch(
    {
      reticulate::py_validate_xptr(object$fit$fit)
      TRUE
    },
    error = function(e) FALSE
  )
  if (!is_valid) {
    object$fit$fit <- keras_model_from_bytes(object$fit$keras_bytes)
  }
}

If keras_model_to_bytes() fails (e.g. if the model was compiled with a non-serialisable custom object), a warning is issued at fit time and keras_bytes is set to NULL. In that case, predict() after reload will fail with a clear error.
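The fit-time guard described above might be sketched as follows (illustrative only; the warning text and structure are assumptions, not the actual kerasnip source):

```r
# Degrade gracefully if serialisation fails at fit time
keras_bytes <- tryCatch(
  keras_model_to_bytes(model),
  error = function(e) {
    warning(
      "Keras model could not be serialised; the fitted workflow will not ",
      "survive saveRDS()/readRDS(): ", conditionMessage(e)
    )
    NULL
  }
)
```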