| Title: | Cellular Dynamics Inference Using Neural ODE |
|---|---|
| Description: | An R implementation for single-cell trajectory inference using Variational Autoencoder (VAE) and Neural Ordinary Differential Equations (Neural ODE). CellODE automatically infers cellular dynamics from single-cell RNA sequencing data, providing pseudotime estimation, latent space representation, and vector field analysis. The package is designed for seamless integration with Seurat objects and supports both Seurat V4 and V5. |
| Authors: | Zaoqu Liu [aut, cre] (ORCID: <https://orcid.org/0000-0002-0000-0000>) |
| Maintainer: | Zaoqu Liu |
| License: | MIT + file LICENSE |
| Version: | 1.0.0 |
| Built: | 2026-05-26 05:56:07 UTC |
| Source: | https://github.com/Zaoqu-Liu/CellODE |
CellODE is an R package for single-cell trajectory inference using Variational Autoencoder (VAE) and Neural Ordinary Differential Equations (Neural ODE). The package automatically infers cellular dynamics from single-cell RNA sequencing data, providing:
- Pseudotime estimation
- Latent space representation
- Vector field analysis
- Trajectory visualization
CellODE is designed for seamless integration with Seurat objects (supporting both V4 and V5). The core model (TNODE) combines a VAE for dimensionality reduction with a Neural ODE for modeling continuous cellular dynamics.
The main workflow is:
1. Create a Trainer object from your Seurat object with Trainer$new()
2. Train the model with trainer$train()
3. Extract pseudotime with trainer$get_time()
4. Get latent representations with trainer$get_latentsp()
5. Visualize results with plot_pseudotime() and plot_vector_field()
Zaoqu Liu
Li, S. et al. (2023). scTour: A deep learning architecture for robust inference and accurate prediction of cellular dynamics. bioRxiv. https://doi.org/10.1101/2023.01.13.523988
Useful links:
Report bugs at https://github.com/Zaoqu-Liu/CellODE/issues
Calculate the cosine similarity between each cell's vector field and the latent-state differences between that cell and its neighbors. This function matches scTour's cosine_similarity exactly; the calculation borrows ideas from scVelo.
It uses an optimized C++ implementation so that it scales to large datasets.

    cosine_similarity(zs, vf, reverse = FALSE, n_neigh = 20, t = NULL, var_transform = FALSE, neighbor_indices = NULL)

| Argument | Description |
|---|---|
| zs | Latent space matrix (cells x latent_dim) |
| vf | Vector field matrix (cells x latent_dim) |
| reverse | Whether to reverse the vector field direction (default: FALSE) |
| n_neigh | Number of neighbors (default: 20) |
| t | Time vector for time-aware neighbors (default: NULL) |
| var_transform | Apply a variance-stabilizing transformation (default: FALSE) |
| neighbor_indices | Pre-computed neighbor indices matrix (optional) |

Value: Sparse matrix of cosine similarities
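The calculation described above can be illustrated with a simplified, dense base-R sketch. It is not CellODE's C++ implementation: `cosine_similarity_sketch` is a hypothetical name, neighbors are plain Euclidean k-nearest neighbors in latent space (no time-aware neighbors, no pre-computed indices), the result is a dense matrix, and the sign-preserving square root used for `var_transform` is an assumption modeled on scVelo.

```r
# Hypothetical sketch (not CellODE's C++ code): cosine similarity between each
# cell's vector field and the displacements to its k nearest latent neighbors.
# Returns a dense cells x cells matrix; row i holds values for i's neighbors.
cosine_similarity_sketch <- function(zs, vf, reverse = FALSE, n_neigh = 20,
                                     var_transform = FALSE) {
  if (reverse) vf <- -vf
  # Assumed variance-stabilizing transform: sign(x) * sqrt(|x|), as in scVelo
  sqrt_tf <- function(x) sign(x) * sqrt(abs(x))
  if (var_transform) vf <- sqrt_tf(vf)
  n <- nrow(zs)
  k <- min(n_neigh, n - 1)
  d <- as.matrix(dist(zs))                       # Euclidean distances between cells
  sim <- matrix(0, n, n)
  for (i in seq_len(n)) {
    nbrs <- setdiff(order(d[i, ]), i)[seq_len(k)] # k nearest, excluding the cell itself
    dz <- sweep(zs[nbrs, , drop = FALSE], 2, zs[i, ])  # neighbor minus cell
    if (var_transform) dz <- sqrt_tf(dz)
    v <- vf[i, ]
    cs <- as.vector(dz %*% v) / (sqrt(rowSums(dz^2)) * sqrt(sum(v^2)))
    cs[!is.finite(cs)] <- 0                       # zero-length vectors give 0
    sim[i, nbrs] <- cs
  }
  sim
}

zs <- rbind(c(0, 0), c(1, 0), c(0, 1))
vf <- rbind(c(1, 0), c(1, 0), c(1, 0))           # every cell points along +x
cosine_similarity_sketch(zs, vf, n_neigh = 2)
```

Cell 1's vector points straight at cell 2 (similarity 1) and perpendicular to cell 3 (similarity 0); setting `reverse = TRUE` flips every sign.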
Extract the expression matrix from a Seurat object, supporting both Seurat V4 and V5.

    extract_expression(seurat_obj, assay = "RNA", slot = "counts")

| Argument | Description |
|---|---|
| seurat_obj | Seurat object |
| assay | Assay name (default: "RNA") |
| slot | Slot/layer name (default: "counts") |

Value: Dense matrix (cells x genes)
Calculate the L2 norm along a specified axis. Handles both dense and sparse matrices.

    l2_norm(x, axis = -1)

| Argument | Description |
|---|---|
| x | Input matrix |
| axis | Axis along which to compute the norm (default: -1, the last axis) |

Value: Vector of L2 norms
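A dense-only base-R sketch of the same idea, assuming NumPy-style axis numbering for a 2-D matrix (axis = -1 or 1 gives one norm per row; axis = 0 gives one norm per column). `l2_norm_sketch` is a hypothetical name; unlike the package function, it converts its input to a dense matrix.

```r
# Hypothetical sketch: L2 norm of a dense 2-D matrix along one axis,
# assuming NumPy-style axis numbering.
l2_norm_sketch <- function(x, axis = -1) {
  x <- as.matrix(x)
  if (axis %in% c(-1, 1)) {
    sqrt(rowSums(x^2))   # reduce over columns: one value per row
  } else if (axis == 0) {
    sqrt(colSums(x^2))   # reduce over rows: one value per column
  } else {
    stop("axis must be -1, 0 or 1 for a 2-D matrix")
  }
}

x <- rbind(c(3, 4), c(6, 8))
l2_norm_sketch(x)            # 5 10
l2_norm_sketch(x, axis = 0)  # sqrt(45) sqrt(80)
```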
Load a trained CellODE model from file.

    load_model(path, seurat_obj)

| Argument | Description |
|---|---|
| path | Path to the saved model (without extension) |
| seurat_obj | Seurat object for training data |

Value: Trainer object with loaded model
Calculate the log probability of observed counts under a negative binomial distribution. Adapted from scvi-tools.

    log_nb(x, mu, theta, eps = 1e-08)

| Argument | Description |
|---|---|
| x | Observed counts |
| mu | Mean parameter |
| theta | Inverse dispersion parameter |
| eps | Small constant for numerical stability |

Value: Tensor of log probabilities
Gayoso et al. scvi-tools. https://github.com/YosefLab/scvi-tools
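For reference, a base-R sketch of the scvi-tools negative binomial log-likelihood (mean `mu`, inverse dispersion `theta`). It is written for illustration only (the package function works on torch tensors), and `log_nb_sketch` is a hypothetical name. Up to `eps` effects it coincides with R's `dnbinom(x, size = theta, mu = mu, log = TRUE)`, which gives a convenient check.

```r
# Hypothetical sketch of the scvi-tools NB log-likelihood:
# log NB(x; mu, theta) = theta*log(theta/(theta+mu)) + x*log(mu/(theta+mu))
#                        + lgamma(x+theta) - lgamma(theta) - lgamma(x+1)
log_nb_sketch <- function(x, mu, theta, eps = 1e-8) {
  log_theta_mu_eps <- log(theta + mu + eps)
  theta * (log(theta + eps) - log_theta_mu_eps) +
    x * (log(mu + eps) - log_theta_mu_eps) +
    lgamma(x + theta) - lgamma(theta) - lgamma(x + 1)
}

x <- 0:10
ll <- log_nb_sketch(x, mu = 2.5, theta = 1.7)
# Difference from base R's negative binomial (size = theta) is tiny
max(abs(ll - dnbinom(x, size = 1.7, mu = 2.5, log = TRUE)))
```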
Calculate the log probability of observed counts under a zero-inflated negative binomial (ZINB) distribution. Adapted from scvi-tools.

    log_zinb(x, mu, theta, pi, eps = 1e-08)

| Argument | Description |
|---|---|
| x | Observed counts |
| mu | Mean parameter |
| theta | Inverse dispersion parameter |
| pi | Dropout probability (logit scale) |
| eps | Small constant for numerical stability |

Value: Tensor of log probabilities
Gayoso et al. scvi-tools. https://github.com/YosefLab/scvi-tools
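The zero-inflated case mixes a point mass at zero (probability p = plogis(pi)) with the negative binomial. A base-R sketch following the scvi-tools formulation is shown below for illustration; `log_zinb_sketch` and the `softplus` helper are hypothetical names, and the package function operates on torch tensors instead.

```r
# Numerically stable softplus: log(1 + exp(x))
softplus <- function(x) pmax(x, 0) + log1p(exp(-abs(x)))

# Hypothetical sketch of the scvi-tools ZINB log-likelihood.
# `pi` is the logit of the dropout probability, so p_drop = plogis(pi).
log_zinb_sketch <- function(x, mu, theta, pi, eps = 1e-8) {
  softplus_pi <- softplus(-pi)                     # equals -log(p_drop)
  log_theta_mu_eps <- log(theta + mu + eps)
  pi_theta_log <- -pi + theta * (log(theta + eps) - log_theta_mu_eps)
  # x == 0: log(p_drop + (1 - p_drop) * NB(0))
  case_zero <- softplus(pi_theta_log) - softplus_pi
  # x > 0: log((1 - p_drop) * NB(x))
  case_non_zero <- -softplus_pi + pi_theta_log +
    x * (log(mu + eps) - log_theta_mu_eps) +
    lgamma(x + theta) - lgamma(theta) - lgamma(x + 1)
  ifelse(x < eps, case_zero, case_non_zero)
}

log_zinb_sketch(c(0, 4), mu = 2.5, theta = 1.7, pi = 0.3)
```

Computing both branches in log space and selecting with `ifelse()` mirrors the masked sum used in scvi-tools and avoids taking the log of a small probability directly.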
Create a torch dataset from an expression matrix.

    MakeDataset(X, loss_mode = "nb")

| Argument | Description |
|---|---|
| X | Expression matrix (cells x genes) |
| loss_mode | Loss mode: "mse", "nb", or "zinb" |

Value: torch::dataset object
VAE components and the latent ODE function, implemented as torch nn_module objects.
Calculate the KL divergence between two normal distributions. This is the standard formulation from torchdiffeq.

    normal_kl(mu1, lv1, mu2, lv2)

| Argument | Description |
|---|---|
| mu1 | Mean of the first distribution (posterior) |
| lv1 | Log variance of the first distribution |
| mu2 | Mean of the second distribution (prior) |
| lv2 | Log variance of the second distribution |

Value: Tensor of KL divergence values

Chen, R. T. Q. torchdiffeq. https://github.com/rtqichen/torchdiffeq
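With log variances lv1 and lv2, the elementwise KL divergence KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2))) equals lv2/2 - lv1/2 + (exp(lv1) + (mu1 - mu2)^2) / (2 exp(lv2)) - 1/2. A base-R transcription (hypothetical name `normal_kl_sketch`, operating on plain numeric vectors rather than tensors) makes the formula easy to check:

```r
# Hypothetical sketch: elementwise KL( N(mu1, exp(lv1)) || N(mu2, exp(lv2)) ),
# where lv1 and lv2 are log variances.
normal_kl_sketch <- function(mu1, lv1, mu2, lv2) {
  v1 <- exp(lv1)
  v2 <- exp(lv2)
  lstd1 <- lv1 / 2   # log standard deviations
  lstd2 <- lv2 / 2
  lstd2 - lstd1 + (v1 + (mu1 - mu2)^2) / (2 * v2) - 0.5
}

normal_kl_sketch(0, 0, 0, 0)   # 0: identical distributions
normal_kl_sketch(1, 0, 0, 0)   # 0.5: unit-variance Gaussians one unit apart
```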
Solve an ODE forward in time. This is the main integration function and matches scTour's odeint behavior.

    odeint(func, z0, t, method = "euler", options = list())

| Argument | Description |
|---|---|
| func | ODE function (nn_module with a forward(t, z) method) |
| z0 | Initial state tensor |
| t | Time points tensor (must be sorted) |
| method | Integration method: "euler" (default, matches scTour) or "rk4" |
| options | List with optional step_size |

Value: Tensor of states at each time point (n_times x n_latent)

    ## Not run:
    states <- odeint(ode_func, z0, t, method = "euler")
    ## End(Not run)
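To make the two solver options concrete, here is a base-R sketch of fixed-step Euler and classical RK4 integration over a sorted time grid. It is an illustration, not CellODE's torch implementation: `odeint_sketch` is a hypothetical name, `func(t, z)` is a plain R function, and `step_size` is treated as an absolute step (how the package's step-size multiplier maps onto this is not described here). When `step_size` is NULL, each interval between consecutive time points is covered in a single step, which is the torchdiffeq convention for fixed-grid solvers.

```r
# Hypothetical sketch of fixed-step ODE integration (Euler or RK4).
# Returns a matrix with one row per time point and one column per state.
odeint_sketch <- function(func, z0, t, method = c("euler", "rk4"),
                          step_size = NULL) {
  method <- match.arg(method)
  if (is.unsorted(t)) stop("t must be sorted in increasing order")
  out <- matrix(NA_real_, nrow = length(t), ncol = length(z0))
  out[1, ] <- z0
  z <- z0
  for (k in seq_len(length(t) - 1)) {
    dt <- t[k + 1] - t[k]
    # One step per interval by default; otherwise split into steps of ~step_size
    n_steps <- if (is.null(step_size)) 1L else max(1L, ceiling(dt / step_size - 1e-9))
    h <- dt / n_steps
    tc <- t[k]
    for (s in seq_len(n_steps)) {
      if (method == "euler") {
        z <- z + h * func(tc, z)
      } else {
        k1 <- func(tc, z)
        k2 <- func(tc + h / 2, z + h / 2 * k1)
        k3 <- func(tc + h / 2, z + h / 2 * k2)
        k4 <- func(tc + h, z + h * k3)
        z <- z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
      }
      tc <- tc + h
    }
    out[k + 1, ] <- z
  }
  out
}

# Exponential decay dz/dt = -z with z(0) = 1
decay <- function(t, z) -z
odeint_sketch(decay, 1, c(0, 1), method = "euler")             # rows: 1, 0 (one Euler step)
odeint_sketch(decay, 1, c(0, 1), "rk4", step_size = 0.01)[2, 1]  # close to exp(-1)
```

The coarse default Euler step is cheap but crude, as the decay example shows; a smaller step or RK4 trades extra function evaluations for accuracy.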
Visualize training and validation loss curves from CellODE training.

    plot_loss(trainer, smooth = FALSE, title = "Training Progress", colors = c(Training = "#2E86AB", Validation = "#E94F37"))

| Argument | Description |
|---|---|
| trainer | Trainer object with training log |
| smooth | Apply loess smoothing (default: FALSE) |
| title | Plot title (default: "Training Progress") |
| colors | Colors for the training and validation lines |

Value: ggplot2 object

    ## Not run:
    plot_loss(trainer)
    ## End(Not run)
Visualize the pseudotime trajectory with direction arrows computed by a grid-based vector field algorithm. Cells are colored by a pseudotime gradient, and overlaid arrows show the developmental direction.

    plot_pseudotime(seurat_obj, time_key = "pseudotime", embedding_key = "umap", dims = c(1, 2), point_size = 0.8, point_palette = "Spectral", point_palcolor = NULL, point_alpha = 0.8, arrow_show = TRUE, arrow_density = 1.5, arrow_length = 0.12, arrow_width = 0.6, arrow_color = "grey20", arrow_angle = 25, arrow_type = c("open", "closed"), show_axes = TRUE, aspect_ratio = 1, title = "Pseudotime", subtitle = NULL, xlab = NULL, ylab = NULL, legend_position = "right", return_layer = FALSE, seed = 11)

| Argument | Description |
|---|---|
| seurat_obj | Seurat object with dimensionality reduction and pseudotime |
| time_key | Column name in metadata containing pseudotime (default: "pseudotime") |
| embedding_key | Embedding to use (default: "umap") |
| dims | Dimensions to plot (default: c(1, 2)) |
| point_size | Point size (default: 0.8) |
| point_palette | Color palette (default: "Spectral") |
| point_palcolor | Custom color vector (overrides palette) |
| point_alpha | Point transparency (default: 0.8) |
| arrow_show | Show direction arrows (default: TRUE) |
| arrow_density | Arrow density (default: 1.5) |
| arrow_length | Arrow length multiplier (default: 0.12) |
| arrow_width | Arrow line width (default: 0.6) |
| arrow_color | Arrow color (default: "grey20") |
| arrow_angle | Arrow head angle (default: 25) |
| arrow_type | Arrow type: "open" or "closed" (default: "open") |
| show_axes | Show axes (default: TRUE) |
| aspect_ratio | Aspect ratio (default: 1) |
| title | Plot title (default: "Pseudotime") |
| subtitle | Plot subtitle (default: NULL) |
| xlab | X-axis label (default: auto) |
| ylab | Y-axis label (default: auto) |
| legend_position | Legend position (default: "right") |
| return_layer | Return layers for custom assembly (default: FALSE) |
| seed | Random seed (default: 11) |

Value: ggplot2 object, or a list of layers if return_layer = TRUE

    ## Not run:
    # Basic pseudotime visualization
    plot_pseudotime(seurat_obj, time_key = "pseudotime")
    # Custom styling
    plot_pseudotime(seurat_obj, point_palette = "viridis", arrow_color = "darkred", arrow_density = 0.5)
    ## End(Not run)
Visualize the cellular dynamics vector field on a 2D embedding using distance-weighted directional arrows. Direction vectors at grid centers are computed by combining distance-based weights with pseudotime gradients from the latent space dynamics. Three rendering modes are supported: raw (per-cell arrows), grid (smoothed arrows), and stream (streamlines for flow visualization).

    plot_vector_field(seurat_obj, zs_key = "X_zs", vf_key = "X_VF", embedding_key = "umap", t_key = NULL, plot_type = c("grid", "raw", "stream"), reverse = FALSE, n_neigh = 20, var_transform = FALSE, scale = 10, self_transition = FALSE, point_size = 0.5, point_palette = "Spectral", point_palcolor = NULL, point_alpha = 0.8, arrow_density = 0.5, arrow_length = 0.15, arrow_width = 0.6, arrow_color = "grey20", arrow_angle = 25, arrow_head_size = 50, arrow_type = c("open", "closed"), smooth = 0.5, grid_density = 1, min_mass = 1, stream_n = 15, stream_L = 5, stream_color = NULL, stream_width = c(0.2, 0.8), color_by = NULL, show_axes = TRUE, aspect_ratio = 1, title = "Vector Field", subtitle = NULL, xlab = NULL, ylab = NULL, legend_position = "right", return_layer = FALSE, seed = 11)

| Argument | Description |
|---|---|
| seurat_obj | Seurat object containing embeddings and CellODE results |
| zs_key | Key for latent space in misc or reductions (default: "X_zs") |
| vf_key | Key for vector field in misc (default: "X_VF") |
| embedding_key | Embedding for visualization (default: "umap") |
| t_key | Key in metadata for pseudotime (default: NULL) |
| plot_type | Rendering mode: "raw", "grid", or "stream" (default: "grid") |
| reverse | Reverse vector field direction (default: FALSE) |
| n_neigh | Number of neighbors for similarity calculation (default: 20) |
| var_transform | Variance-stabilizing transformation (default: FALSE) |
| scale | Scale factor for cosine similarity (default: 10) |
| self_transition | Include self-transition (default: FALSE) |
| point_size | Size of background points (default: 0.5) |
| point_palette | Color palette for points (default: "Spectral") |
| point_palcolor | Custom color vector (overrides palette) |
| point_alpha | Point transparency (default: 0.8) |
| arrow_density | Proportion of arrows to display (default: 0.5) |
| arrow_length | Arrow length multiplier (default: 0.15) |
| arrow_width | Arrow line width (default: 0.6) |
| arrow_color | Arrow color (default: "grey20") |
| arrow_angle | Arrow head angle in degrees (default: 25) |
| arrow_head_size | Arrow head size scaling (default: 50) |
| arrow_type | Arrow head type: "open" or "closed" (default: "open") |
| smooth | Smoothing factor for grid (default: 0.5) |
| grid_density | Grid density factor (default: 1.0) |
| min_mass | Minimum mass threshold for grid points (default: 1) |
| stream_n | Number of streamlines (default: 15) |
| stream_L | Streamline length parameter (default: 5) |
| stream_color | Streamline color (default: NULL, uses gradient) |
| stream_width | Streamline width range (default: c(0.2, 0.8)) |
| color_by | Variable to color points by (default: NULL) |
| show_axes | Show axes and labels (default: TRUE) |
| aspect_ratio | Plot aspect ratio (default: 1) |
| title | Plot title (default: "Vector Field") |
| subtitle | Plot subtitle (default: NULL) |
| xlab | X-axis label (default: auto) |
| ylab | Y-axis label (default: auto) |
| legend_position | Legend position (default: "right") |
| return_layer | Return layers for custom assembly (default: FALSE) |
| seed | Random seed (default: 11) |

Value: ggplot2 object, or a list of layers if return_layer = TRUE

    ## Not run:
    # Basic usage after CellODE training
    plot_vector_field(seurat_obj, plot_type = "grid")
    # Raw mode with custom colors
    plot_vector_field(seurat_obj, plot_type = "raw", arrow_color = "darkred", arrow_density = 0.3)
    # Color points by pseudotime
    plot_vector_field(seurat_obj, color_by = "pseudotime", point_palette = "viridis")
    ## End(Not run)
High-quality visualization for pseudotime and vector field
Zaoqu Liu
Functions for predicting pseudotime, latent space, and vector field
Predict latent representations for query cells.

    predict_latentsp(trainer, query_seurat, mode = "fine", alpha_z = 0.5, alpha_predz = 0.5, step_size = NULL, step_wise = FALSE, batch_size = NULL, assay = "RNA")

| Argument | Description |
|---|---|
| trainer | Trained Trainer object |
| query_seurat | Query Seurat object |
| mode | Prediction mode: "fine" or "coarse" (default: "fine") |
| alpha_z | Weight for encoder-derived latent (default: 0.5) |
| alpha_predz | Weight for ODE-derived latent (default: 0.5) |
| step_size | Step size for integration (default: NULL) |
| step_wise | Use step-wise integration (default: FALSE) |
| batch_size | Batch size (default: NULL) |
| assay | Assay name (default: "RNA") |

Value: List with mix_zs, zs, and pred_zs matrices
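The `alpha_z` and `alpha_predz` weights control how `mix_zs` combines the two latent estimates. Assuming, as in scTour, that `mix_zs` is a weighted sum of the encoder-derived latent (`zs`) and the ODE-derived latent (`pred_zs`), the combination amounts to the following base-R sketch; `mix_latent_sketch` is a hypothetical helper, not part of CellODE.

```r
# Hypothetical sketch: weighted mix of encoder- and ODE-derived latent states,
# assuming mix_zs = alpha_z * zs + alpha_predz * pred_zs.
mix_latent_sketch <- function(zs, pred_zs, alpha_z = 0.5, alpha_predz = 0.5) {
  stopifnot(identical(dim(zs), dim(pred_zs)))
  list(mix_zs = alpha_z * zs + alpha_predz * pred_zs,
       zs = zs, pred_zs = pred_zs)
}

zs <- matrix(c(1, 2, 3, 4), nrow = 2)       # 2 cells x 2 latent dimensions
pred_zs <- matrix(c(5, 6, 7, 8), nrow = 2)
mix_latent_sketch(zs, pred_zs)$mix_zs      # equal weights: element-wise mean
```

Setting `alpha_z = 1, alpha_predz = 0` returns the encoder latent unchanged, and the reverse setting returns the ODE-derived latent.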
Predict the latent space at query (unobserved) time points. Matches scTour's predict_ltsp_from_time function exactly.

    predict_ltsp_from_time(trainer, t, reverse = FALSE, step_wise = TRUE, step_size = NULL, alpha_z = 0.5, alpha_predz = 0.5, k = 20, assay = "RNA")

| Argument | Description |
|---|---|
| trainer | Trained Trainer object |
| t | Vector of query time points (values between 0 and 1) |
| reverse | Whether pseudotime was reversed (default: FALSE) |
| step_wise | Use step-wise integration (default: TRUE) |
| step_size | Step size for integration (default: NULL) |
| alpha_z | Weight for encoder-derived latent (default: 0.5) |
| alpha_predz | Weight for ODE-derived latent (default: 0.5) |
| k | Number of nearest neighbors in time space (default: 20) |
| assay | Assay name (default: "RNA") |

Value: Matrix of predicted latent space
Predict developmental pseudotime for query cells.

    predict_time(trainer, query_seurat, reverse = FALSE, assay = "RNA")

| Argument | Description |
|---|---|
| trainer | Trained Trainer object |
| query_seurat | Query Seurat object |
| reverse | Whether to reverse pseudotime (default: FALSE) |
| assay | Assay name (default: "RNA") |

Value: Numeric vector of pseudotime values
Predict the vector field for query cells.

    predict_vector_field(trainer, t, z)

| Argument | Description |
|---|---|
| trainer | Trained Trainer object |
| t | Pseudotime vector |
| z | Latent space matrix |

Value: Matrix of vector field
Post-inference adjustment to reverse the pseudotime.

    reverse_time(t)

| Argument | Description |
|---|---|
| t | Pseudotime vector |

Value: Reversed pseudotime (1 - t)
Split a dataset into training and validation sets.

    split_data(X, percent, val_frac = 0.1, loss_mode = "nb")

| Argument | Description |
|---|---|
| X | Expression matrix (cells x genes) |
| percent | Percentage of cells used for training |
| val_frac | Fraction of the training set held out for validation (default: 0.1) |
| loss_mode | Loss mode used for data transformation |

Value: List with train_data and val_data
Split indices into training and validation sets.

    split_index(n_cells, percent, val_frac = 0.1)

| Argument | Description |
|---|---|
| n_cells | Total number of cells |
| percent | Percentage of cells for training |
| val_frac | Validation fraction (default: 0.1) |

Value: List with train_idx and val_idx
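One way to produce such a split is sketched below in base R, under stated assumptions: `percent` is taken as a fraction in (0, 1] of all cells to use, and `val_frac` of the sampled cells is held out for validation. The helper name and its `seed` argument are hypothetical, and CellODE's exact sampling scheme may differ.

```r
# Hypothetical sketch: random train/validation split of cell indices.
split_index_sketch <- function(n_cells, percent = 1, val_frac = 0.1, seed = 0L) {
  set.seed(seed)                        # note: changes the global RNG state
  n_use <- round(n_cells * percent)     # cells used at all
  idx <- sample.int(n_cells, n_use)     # random subset in random order
  n_val <- round(n_use * val_frac)      # cells held out for validation
  # head()/tail() avoid the idx[-integer(0)] pitfall when n_val is 0
  list(train_idx = sort(tail(idx, n_use - n_val)),
       val_idx = sort(head(idx, n_val)))
}

s <- split_index_sketch(100, percent = 0.5, val_frac = 0.1)
lengths(s)   # train_idx 45, val_idx 5
```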
Complete model combining a VAE and a Neural ODE for cellular dynamics inference. This is the core model of CellODE.

    TNODE(n_int, n_latent = 5L, n_ode_hidden = 25L, n_vae_hidden = 128L, batch_norm = FALSE, ode_method = "euler", step_size = NULL, alpha_recon_lec = 0.5, alpha_recon_lode = 0.5, alpha_kl = 1, loss_mode = "nb")

| Argument | Description |
|---|---|
| n_int | Number of input features (genes) |
| n_latent | Dimensionality of latent space (default: 5) |
| n_ode_hidden | Hidden layer size for ODE function (default: 25) |
| n_vae_hidden | Hidden layer size for VAE (default: 128) |
| batch_norm | Whether to include a BatchNorm layer (default: FALSE) |
| ode_method | ODE solver method (default: "euler") |
| step_size | Step size multiplier for integration (NULL for default) |
| alpha_recon_lec | Weight for encoder reconstruction loss (default: 0.5) |
| alpha_recon_lode | Weight for ODE reconstruction loss (default: 0.5) |
| alpha_kl | Weight for KL divergence (default: 1.0) |
| loss_mode | Loss function: "mse", "nb", or "zinb" (default: "nb") |

Value: nn_module for TNODE model
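The three alpha_* arguments balance the terms of the training objective. Assuming, as in scTour, that the total loss is a weighted sum of the encoder reconstruction loss, the ODE reconstruction loss, and the KL divergence, the weighting reduces to the sketch below. `tnode_loss_sketch` is a hypothetical helper on scalar values; in the model itself these terms are computed on torch tensors.

```r
# Hypothetical sketch of how the TNODE loss weights combine scalar loss terms.
tnode_loss_sketch <- function(recon_lec, recon_lode, kl,
                              alpha_recon_lec = 0.5, alpha_recon_lode = 0.5,
                              alpha_kl = 1) {
  alpha_recon_lec * recon_lec + alpha_recon_lode * recon_lode + alpha_kl * kl
}

# With the defaults, each reconstruction loss counts half and the KL term fully
tnode_loss_sketch(recon_lec = 2, recon_lode = 4, kl = 1)   # 4
```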
R6 class for implementing the CellODE training process.
Fields:

| Field | Description |
|---|---|
| seurat_obj | Seurat object for training |
| model | TNODE model |
| optimizer | Adam optimizer |
| device | Computation device |
| log | Training log |
| time_reverse | Whether to reverse time |
| model_kwargs | Model configuration |

new()
Initialize Trainer

    Trainer$new(seurat_obj, assay = "RNA", slot = NULL, percent = NULL, n_latent = 5L, n_ode_hidden = 25L, n_vae_hidden = 128L, batch_norm = FALSE, ode_method = "euler", step_size = NULL, alpha_recon_lec = 0.5, alpha_recon_lode = 0.5, alpha_kl = 1, loss_mode = "nb", nepoch = NULL, batch_size = 1024L, drop_last = FALSE, lr = 0.001, wt_decay = 1e-06, eps = 0.01, random_state = 0L, val_frac = 0.1, use_gpu = TRUE)

| Argument | Description |
|---|---|
| seurat_obj | Seurat object with expression data |
| assay | Assay to use (default: "RNA") |
| slot | Slot to use (default: "counts" for nb/zinb, "data" for mse) |
| percent | Percentage of cells for training (default: auto) |
| n_latent | Latent space dimensions (default: 5) |
| n_ode_hidden | ODE hidden layer size (default: 25) |
| n_vae_hidden | VAE hidden layer size (default: 128) |
| batch_norm | Use batch normalization (default: FALSE) |
| ode_method | ODE solver (default: "euler") |
| step_size | Step size multiplier (default: NULL) |
| alpha_recon_lec | Encoder reconstruction weight (default: 0.5) |
| alpha_recon_lode | ODE reconstruction weight (default: 0.5) |
| alpha_kl | KL divergence weight (default: 1.0) |
| loss_mode | Loss mode: "mse", "nb", "zinb" (default: "nb") |
| nepoch | Number of epochs (default: auto) |
| batch_size | Batch size (default: 1024) |
| drop_last | Drop last incomplete batch (default: FALSE) |
| lr | Learning rate (default: 1e-3) |
| wt_decay | Weight decay (default: 1e-6) |
| eps | Adam epsilon (default: 0.01) |
| random_state | Random seed (default: 0) |
| val_frac | Validation fraction (default: 0.1) |
| use_gpu | Use GPU if available (default: TRUE) |

train()
Train the model

    Trainer$train()

get_time()
Get pseudotime for all cells

    Trainer$get_time()

Value: Numeric vector of pseudotime values

get_vector_field()
Get vector field

    Trainer$get_vector_field(t, z)

| Argument | Description |
|---|---|
| t | Pseudotime vector |
| z | Latent space matrix |

Value: Matrix of vector field

get_latentsp()
Get latent space representation

    Trainer$get_latentsp(alpha_z = 0.5, alpha_predz = 0.5, step_size = NULL, step_wise = FALSE, batch_size = NULL)

| Argument | Description |
|---|---|
| alpha_z | Weight for encoder-derived latent (default: 0.5) |
| alpha_predz | Weight for ODE-derived latent (default: 0.5) |
| step_size | Step size for integration (default: NULL) |
| step_wise | Use step-wise integration (default: FALSE) |
| batch_size | Batch size (default: NULL for all cells) |

Value: List with mix_zs, zs, pred_zs matrices

save_model()
Save trained model

    Trainer$save_model(path)

| Argument | Description |
|---|---|
| path | File path (without extension) |

load_model()
Load trained model

    Trainer$load_model(path)

| Argument | Description |
|---|---|
| path | File path (without extension) |

clone()
The objects of this class are cloneable with this method.

    Trainer$clone(deep = FALSE)

| Argument | Description |
|---|---|
| deep | Whether to make a deep clone. |

Mathematical utilities for loss computation and ODE solving
Functions for vector field computation and visualization
Calculate weighted unit displacement vectors in the embedding. This function matches scTour's vector_field_embedding exactly; the calculation borrows ideas from scVelo.

    vector_field_embedding(T_mat, E, scale = 10, self_transition = FALSE)

| Argument | Description |
|---|---|
| T_mat | Cosine similarity sparse matrix (from cosine_similarity) |
| E | Embedding matrix (cells x 2) |
| scale | Scale factor for cosine similarity (default: 10) |
| self_transition | Include self-transition (default: FALSE) |

Value: Matrix of displacement vectors
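A simplified, dense base-R sketch of this scVelo-style projection, for illustration: similarities are sharpened as sign(c) * expm1(|c| * scale), normalized per cell, and used to weight unit displacement vectors toward each neighbor, minus the expectation under uniform weights. It treats zero entries as non-neighbors and omits the self_transition option and any final arrow rescaling; `vector_field_embedding_sketch` is a hypothetical name.

```r
# Hypothetical sketch: project cosine similarities onto a 2-D embedding.
vector_field_embedding_sketch <- function(T_mat, E, scale = 10) {
  T_mat <- as.matrix(T_mat)
  n <- nrow(E)
  # Exponential sharpening of similarities, keeping their sign
  Tw <- sign(T_mat) * expm1(abs(T_mat * scale))
  row_tot <- rowSums(abs(Tw))
  row_tot[row_tot == 0] <- 1             # cells without neighbors stay zero
  Tw <- Tw / row_tot                     # divides row i by row_tot[i]
  V <- matrix(0, n, ncol(E))
  for (i in seq_len(n)) {
    nbrs <- which(T_mat[i, ] != 0)
    if (length(nbrs) == 0) next
    dE <- sweep(E[nbrs, , drop = FALSE], 2, E[i, ])  # neighbor minus cell
    dE <- dE / sqrt(rowSums(dE^2))                   # unit displacement vectors
    dE[!is.finite(dE)] <- 0
    prob <- Tw[i, nbrs]
    # Weighted mean direction minus the uniform-weight correction
    V[i, ] <- colSums(prob * dE) - mean(prob) * colSums(dE)
  }
  V
}

E <- rbind(c(0, 0), c(1, 0), c(0, 1))
T_mat <- matrix(0, 3, 3)
T_mat[1, 2] <- 0.5    # cell 1 is similar to moving toward cell 2
T_mat[1, 3] <- -0.5   # and dissimilar to moving toward cell 3
vector_field_embedding_sketch(T_mat, E)[1, ]   # 0.5 -0.5
```

Subtracting the uniform-weight term means a cell whose neighbors all look equally likely gets no arrow, rather than one pointing at the local center of mass.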
Estimate displacement vectors on a grid. This function matches scTour's vector_field_embedding_grid exactly; the calculation borrows ideas from scVelo.

    vector_field_embedding_grid(E, V, smooth = 0.5, stream = FALSE, density = 1)

| Argument | Description |
|---|---|
| E | Embedding matrix (cells x 2) |
| V | Displacement vectors (cells x 2) |
| smooth | Smoothing factor for the Gaussian pdf (default: 0.5) |
| stream | Adjust the grid for streamplot (default: FALSE) |
| density | Grid density (default: 1.0) |

Value: List with E_grid and V_grid
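A simplified base-R sketch of Gaussian-weighted grid averaging in the spirit of scVelo's grid computation: a regular grid of about 50 x density points per axis is laid over the padded embedding range, each grid point averages the cell vectors with a Gaussian kernel whose bandwidth is `smooth` times the grid spacing, and the sum is divided by max(1, total weight). Unlike the package function, this sketch weights all cells rather than nearest neighbors, ignores the `stream` adjustment, and returns the per-point mass instead of filtering on it; `vector_field_grid_sketch` is a hypothetical name.

```r
# Hypothetical sketch: Gaussian-weighted average of cell vectors on a 2-D grid.
vector_field_grid_sketch <- function(E, V, smooth = 0.5, density = 1) {
  n_per_axis <- max(2L, as.integer(round(50 * density)))
  # Regular grid along each axis, padded by 1% of the range on both sides
  grs <- lapply(1:2, function(d) {
    r <- range(E[, d])
    pad <- 0.01 * diff(r)
    seq(r[1] - pad, r[2] + pad, length.out = n_per_axis)
  })
  E_grid <- as.matrix(expand.grid(x = grs[[1]], y = grs[[2]]))
  # Kernel bandwidth: mean grid spacing times the smoothing factor
  sigma <- mean(vapply(grs, function(g) g[2] - g[1], numeric(1))) * smooth
  # Distance from every grid point (rows) to every cell (columns)
  D <- sqrt(outer(E_grid[, 1], E[, 1], "-")^2 + outer(E_grid[, 2], E[, 2], "-")^2)
  W <- matrix(dnorm(D, sd = sigma), nrow = nrow(D))
  mass <- rowSums(W)
  V_grid <- (W %*% V) / pmax(1, mass)   # divides row i by max(1, mass[i])
  list(E_grid = E_grid, V_grid = V_grid, mass = mass)
}

E <- as.matrix(expand.grid(1:5, 1:5))       # 25 cells on a lattice
V <- cbind(rep(1, 25), rep(0, 25))          # all cells move along +x
g <- vector_field_grid_sketch(E, V, density = 0.2)
dim(g$V_grid)                               # 100 2
```

Dividing by max(1, mass) rather than by the mass itself shrinks arrows in sparsely populated regions, so grid points far from any cell show little or no flow.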