CellODE is an R package for inferring cellular dynamics from single-cell RNA sequencing data using deep generative models. The package combines:

- Variational Autoencoders (VAE) for dimensionality reduction
- Neural Ordinary Differential Equations (Neural ODE) for continuous dynamics modeling
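To make the "continuous dynamics" idea concrete, here is a minimal base-R sketch of integrating a latent state through a small network-defined vector field with fixed-step Euler updates. Everything in it is an assumption made for illustration: the random weights, the `vector_field()` and `integrate_euler()` helpers, and the choice of Euler integration are hypothetical and are not CellODE's network or solver.

```r
# Illustrative only: a tiny vector field with random weights, NOT CellODE's model.
set.seed(1)
n_latent <- 5
n_hidden <- 25
W1 <- matrix(rnorm(n_hidden * n_latent, sd = 0.3), n_hidden, n_latent)
W2 <- matrix(rnorm(n_latent * n_hidden, sd = 0.3), n_latent, n_hidden)

# dz/dt = f(z): one hidden layer with tanh activation
vector_field <- function(z) as.numeric(W2 %*% tanh(W1 %*% z))

# Fixed-step Euler integration from t = 0 to t = t_end
integrate_euler <- function(z0, t_end, n_steps = 100) {
  dt <- t_end / n_steps
  path <- matrix(NA_real_, n_steps + 1, length(z0))
  path[1, ] <- z0
  for (k in seq_len(n_steps)) {
    path[k + 1, ] <- path[k, ] + dt * vector_field(path[k, ])
  }
  path
}

z0 <- rnorm(n_latent)                 # initial latent state
trajectory <- integrate_euler(z0, t_end = 1)
dim(trajectory)                       # time points by latent dimensions
```

Each row of `trajectory` is the latent state at one time point, so a cell's position along the path can be read as a continuous time coordinate rather than a discrete cluster label.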
Key capabilities include:
CellODE works with Seurat objects. Your data should have:

- Raw UMI counts in the counts slot (for nb/zinb modes)
- Or log-normalized data in the data slot (for mse mode)
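A quick sanity check can catch the common mistake of passing normalized values to a count-based mode. The sketch below uses a hypothetical helper, `looks_like_counts()`, which is not part of CellODE; it only tests whether a matrix holds non-negative whole numbers. It runs on simulated matrices here. With a real Seurat object you would pass in the matrix extracted from the relevant slot, for example via `GetAssayData()`, whose slot/layer argument name depends on your Seurat version.

```r
# Hypothetical helper (not part of CellODE): TRUE if every value is a
# non-negative whole number, which is what raw UMI counts look like.
looks_like_counts <- function(x) {
  x <- as.matrix(x)  # note: densifies sparse input, so subset large data first
  all(is.finite(x)) && all(x >= 0) && all(x == round(x))
}

set.seed(1)
raw <- matrix(rpois(20, lambda = 3), nrow = 4)    # simulated UMI counts
lognorm <- log1p(raw / rowSums(raw + 1) * 1e4)    # simulated log-normalized values

looks_like_counts(raw)      # counts-like input, suitable for nb/zinb
looks_like_counts(lognorm)  # not counts-like, so mse is the matching mode
```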
```r
library(Seurat)
library(CellODE)

# Load your Seurat object
seurat_obj <- readRDS("your_data.rds")

# Standard preprocessing (if not already done)
seurat_obj <- NormalizeData(seurat_obj)
seurat_obj <- FindVariableFeatures(seurat_obj, nfeatures = 2000)
seurat_obj <- ScaleData(seurat_obj)
seurat_obj <- RunPCA(seurat_obj)
seurat_obj <- RunUMAP(seurat_obj, dims = 1:30)

# Create trainer
trainer <- Trainer$new(
  seurat_obj = seurat_obj,
  n_latent = 5,        # Latent space dimensions
  n_ode_hidden = 25,   # ODE network hidden units
  n_vae_hidden = 128,  # VAE hidden units
  loss_mode = "nb",    # Negative binomial for UMI counts
  nepoch = 100,        # Training epochs
  batch_size = 1024    # Batch size
)

# Train the model
trainer$train()

# Plot training history
plot_training_history(trainer)
```

Here’s a minimal working example:
```r
library(CellODE)
library(torch)

# Create synthetic data
set.seed(42)
n_cells <- 200
n_genes <- 50

# Simulate expression with pseudo-trajectory
true_time <- runif(n_cells)
X <- matrix(0, n_cells, n_genes)
for (i in 1:n_cells) {
  X[i, 1:25] <- rpois(25, lambda = 20 * (1 - true_time[i]) + 2)
  X[i, 26:50] <- rpois(25, lambda = 20 * true_time[i] + 2)
}

# Create model directly (without Seurat)
model <- TNODE(
  n_int = n_genes,
  n_latent = 5L,
  n_ode_hidden = 25L,
  n_vae_hidden = 64L,
  loss_mode = "nb"
)

# Prepare tensors
X_log <- log1p(X)
X_tensor <- torch::torch_tensor(X_log, dtype = torch::torch_float())
y_tensor <- X_tensor$sum(dim = 2)

# Train
optimizer <- torch::optim_adam(model$parameters, lr = 0.001)
for (epoch in 1:50) {
  optimizer$zero_grad()
  result <- model(X_tensor, y_tensor)
  result$loss$backward()
  optimizer$step()
  if (epoch %% 10 == 0) {
    cat(sprintf("Epoch %d: loss = %.4f\n", epoch, result$loss$item()))
  }
}

# Get pseudotime
model$eval()
torch::with_no_grad({
  enc_out <- model$encoder(X_tensor)
  pseudotime <- as.numeric(enc_out$t$squeeze(-1)$cpu())
})

# Check correlation with true time
cat(sprintf("Correlation: %.4f\n", abs(cor(pseudotime, true_time))))
```

| Data Type | Recommended Mode |
|---|---|
| UMI counts | "nb" (default) |
| Sparse UMI counts | "zinb" |
| Log-normalized | "mse" |
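The table does not say how sparse is sparse enough to prefer "zinb". One rough heuristic, offered here as an assumption rather than as CellODE's own rule, is to compare each gene's observed zero fraction with the zero fraction a moment-matched negative binomial would predict. The helper `excess_zeros()` below is hypothetical and uses only base R; the data are simulated as a cells-by-genes matrix.

```r
# Hypothetical diagnostic (not part of CellODE): per-gene observed zero
# fraction minus the zero fraction expected under a moment-matched NB.
excess_zeros <- function(counts) {
  mu <- colMeans(counts)
  v <- apply(counts, 2, var)
  overdispersed <- v > mu
  # Poisson zero probability is the fallback when variance <= mean
  expected <- exp(-mu)
  # Method-of-moments NB size: var = mu + mu^2 / size
  size <- mu[overdispersed]^2 / (v[overdispersed] - mu[overdispersed])
  expected[overdispersed] <- dnbinom(0, size = size, mu = mu[overdispersed])
  colMeans(counts == 0) - expected
}

set.seed(7)
n_cells <- 500
n_genes <- 20
# Plain NB counts
nb_counts <- matrix(rnbinom(n_cells * n_genes, size = 2, mu = 5), n_cells, n_genes)
# Same distribution with 40% structural zeros added
keep <- matrix(rbinom(n_cells * n_genes, 1, 0.6), n_cells, n_genes)
zi_counts <- nb_counts * keep

ex_nb <- excess_zeros(nb_counts)
ex_zi <- excess_zeros(zi_counts)
c(nb = mean(ex_nb), zero_inflated = mean(ex_zi))
```

An average excess close to zero suggests "nb" already accounts for the zeros, while a clearly positive excess points toward "zinb". Where to draw the line is a judgment call that this sketch does not settle.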

| Trajectory Complexity | Recommended n_latent |
|---|---|
| Simple linear | 3-5 |
| Single branch | 5-10 |
| Multiple branches | 10-20 |