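This vignette covers advanced features and best practices for using CellODE effectively in your single-cell analysis workflows. We'll discuss latent dimensionality, loss functions and their weights, optimizer settings, Seurat integration, downstream analysis, troubleshooting, and best practices.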
The n_latent parameter controls the dimensionality of the learned latent space:
# For simple linear trajectories
trainer <- Trainer$new(seurat_obj, n_latent = 3)
# For complex branching trajectories
trainer <- Trainer$new(seurat_obj, n_latent = 10)
# For very complex data with multiple lineages
trainer <- Trainer$new(seurat_obj, n_latent = 20)

Guidelines:
| Data Complexity | Recommended n_latent |
|---|---|
| Simple linear | 3-5 |
| Single branch | 5-10 |
| Multiple branches | 10-20 |
| Very complex | 20-50 |
# Negative Binomial (default) - Best for UMI counts
trainer <- Trainer$new(seurat_obj, loss_mode = "nb")
# Zero-Inflated NB - For data with many zeros (e.g., SMART-seq)
trainer <- Trainer$new(seurat_obj, loss_mode = "zinb")
# MSE - For already normalized data
trainer <- Trainer$new(seurat_obj, loss_mode = "mse", slot = "data")

The loss function combines multiple terms with adjustable weights:
trainer <- Trainer$new(
  seurat_obj,
  alpha_recon_lec = 0.5,   # Encoder reconstruction weight
  alpha_recon_lode = 0.5,  # ODE reconstruction weight
  alpha_kl = 1.0           # KL divergence weight
)
# If the latent space is too irregular (under-regularized), increase KL weight
trainer <- Trainer$new(seurat_obj, alpha_kl = 2.0)
# If reconstruction is poor, decrease KL weight
trainer <- Trainer$new(seurat_obj, alpha_kl = 0.5)

# Standard training
trainer <- Trainer$new(
  seurat_obj,
  lr = 1e-3,        # Learning rate
  wt_decay = 1e-6,  # L2 regularization
  eps = 0.01        # Adam optimizer epsilon
)
# For unstable training, use smaller learning rate
trainer <- Trainer$new(seurat_obj, lr = 1e-4)
# For faster convergence on simple data
trainer <- Trainer$new(seurat_obj, lr = 5e-3)

For data with multiple differentiation branches:
For cell cycle data:
Interpolate latent space at unobserved time points:
# Pseudotime
seurat_obj$cellode_time <- trainer$get_time()
# Latent space as dimensional reduction
latent <- trainer$get_latentsp()
# Embeddings need cell names as row names, matching colnames(seurat_obj)
seurat_obj[["cellode"]] <- Seurat::CreateDimReducObject(
  embeddings = latent$mix_zs,
  key = "CELLODE_",
  assay = "RNA"
)
# Vector field
vf <- trainer$get_vector_field(seurat_obj$cellode_time, latent$mix_zs)
seurat_obj@misc$X_VF <- vf
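seurat_obj@misc$X_zs <- latent$mix_zs

# Find genes correlated with pseudotime
library(Seurat)
# Pearson correlation between each gene's normalized expression and pseudotime
# (converting to a dense matrix can be memory-heavy on large datasets)
expr <- GetAssayData(seurat_obj, assay = "RNA", layer = "data")  # Seurat v5; use slot = "data" on v4
gene_time_cor <- cor(
  as.matrix(t(expr)),
  seurat_obj$cellode_time
)
# Identify trajectory-associated genes (top 100 by absolute correlation)
top_genes <- names(sort(abs(gene_time_cor[, 1]), decreasing = TRUE)[1:100])

| Problem | Solution |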
|---|---|
| Loss not decreasing | Reduce learning rate |
| Loss oscillating | Increase batch size |
| NaN/Inf loss | Check input data normalization |
| Slow training | Use GPU, reduce n_latent |
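
For the NaN/Inf row in the table above, a quick look at the input matrix often narrows things down. The sketch below is not part of CellODE: `check_counts` is a hypothetical helper, and the data are simulated. It reports whether a genes x cells matrix contains non-finite or negative values, whether it looks like raw integer counts (which the NB and ZINB modes expect), and what fraction of entries is zero.

```r
library(Matrix)  # recommended package shipped with R

# Hypothetical helper (not a CellODE function): summarise matrix properties
# that commonly cause NaN/Inf losses with count-based likelihoods.
check_counts <- function(x) {
  # For dgCMatrix input, inspect only the stored entries to avoid densifying
  vals <- if (inherits(x, "dgCMatrix")) x@x else as.vector(as.matrix(x))
  n_total <- as.numeric(nrow(x)) * ncol(x)
  list(
    has_nonfinite = any(!is.finite(vals)),            # NA, NaN, or Inf present
    has_negative  = any(vals < 0, na.rm = TRUE),      # invalid for NB / ZINB
    is_integer    = all(vals == round(vals), na.rm = TRUE),
    zero_fraction = 1 - sum(vals != 0, na.rm = TRUE) / n_total
  )
}

set.seed(1)
# Simulated UMI-like counts (200 genes x 50 cells), stored as doubles
raw <- matrix(as.numeric(rpois(200 * 50, lambda = 0.5)), nrow = 200)
counts <- Matrix(raw, sparse = TRUE)

# Simulated normalized values, as one might pass with loss_mode = "mse"
norm <- log1p(raw / 10)

# A deliberately corrupted copy with one NaN entry
bad <- raw
bad[1, 1] <- NaN

res_counts <- check_counts(counts)
res_norm   <- check_counts(norm)
res_bad    <- check_counts(bad)
str(res_counts)
```

If `is_integer` is `FALSE`, the matrix has probably been normalized already, in which case the MSE mode on the normalized slot is the more plausible match.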
| Problem | Solution |
|---|---|
| Poor pseudotime ordering | Check time direction, increase epochs |
| Noisy vector field | Increase n_neigh in cosine_similarity |
| Discontinuous trajectory | Increase n_latent |
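
One way to check the time direction is to correlate pseudotime with a marker whose trend along the trajectory is already known. The following is a minimal sketch on simulated data: `orient_pseudotime` and `late_marker` are hypothetical names, and in practice the pseudotime vector would be the output of `trainer$get_time()`.

```r
set.seed(7)
n_cells <- 300
true_time <- runif(n_cells)

# Simulated marker assumed to increase along differentiation
late_marker <- 2 * true_time + rnorm(n_cells, sd = 0.2)

# Simulated inferred pseudotime that came out reversed
pseudotime <- 1 - true_time + rnorm(n_cells, sd = 0.05)

# Hypothetical helper: reverse pseudotime if it is negatively rank-correlated
# with a marker that is expected to increase over time.
orient_pseudotime <- function(pt, marker) {
  rho <- cor(pt, marker, method = "spearman")
  if (rho < 0) {
    # Mirror within the observed range so min and max are preserved
    pt <- max(pt) + min(pt) - pt
  }
  pt
}

rho_before <- cor(pseudotime, late_marker, method = "spearman")
oriented   <- orient_pseudotime(pseudotime, late_marker)
rho_after  <- cor(oriented, late_marker, method = "spearman")
```

Mirroring within the observed range keeps the scale of the original values, so downstream plots and thresholds remain comparable.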
| Problem | Solution |
|---|---|
| GPU out of memory | Reduce batch_size |
| CPU out of memory | Reduce percent, use sparse matrices |
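
To illustrate why sparse matrices help with CPU memory, the sketch below compares the footprint of a dense matrix with its sparse equivalent from the Matrix package. The data are simulated with roughly 90% zeros, which is an assumption chosen to resemble typical single-cell count matrices; actual savings depend on the sparsity of your data.

```r
library(Matrix)  # recommended package shipped with R

set.seed(42)
n_genes <- 2000
n_cells <- 500

# Simulated count matrix in which about 10% of entries are non-zero
dense <- matrix(0, nrow = n_genes, ncol = n_cells)
nonzero <- sample(length(dense), size = 0.1 * length(dense))
dense[nonzero] <- rpois(length(nonzero), lambda = 3) + 1

# Same values in compressed sparse column form
sparse <- Matrix(dense, sparse = TRUE)

dense_bytes  <- as.numeric(object.size(dense))
sparse_bytes <- as.numeric(object.size(sparse))
cat(sprintf("dense: %.1f MB, sparse: %.1f MB\n",
            dense_bytes / 1e6, sparse_bytes / 1e6))
```

Seurat already stores counts as sparse matrices, so the main risk is an accidental conversion with `as.matrix()` somewhere in a preprocessing step.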
Start simple: Begin with default parameters and adjust as needed.
Monitor training: Use plot_loss() to check convergence.
Validate results: Compare with known markers or annotations.
Document settings: Save model parameters for reproducibility.
Use appropriate loss: NB for UMI counts, MSE for normalized data.
Consider data size: Use percent for large datasets and more epochs for small ones.
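
For the "Document settings" point, one simple option is to keep the arguments in a named list and write that list to disk next to the results. This is a sketch using base R only. The parameter names are the ones shown earlier in this vignette, while the file name and the choice of RDS as the format are assumptions.

```r
# Settings collected in one place (values taken from the examples above)
params <- list(
  n_latent = 10,
  loss_mode = "nb",
  alpha_recon_lec = 0.5,
  alpha_recon_lode = 0.5,
  alpha_kl = 1.0,
  lr = 1e-3,
  wt_decay = 1e-6,
  eps = 0.01
)

# Hypothetical file location; use a project directory in real analyses
params_file <- file.path(tempdir(), "cellode_params.rds")
saveRDS(params, params_file)

# Later, or on another machine: restore exactly the same settings
restored <- readRDS(params_file)
str(restored)
```

Assuming `Trainer$new()` accepts these arguments together, the restored list could then be passed along with the Seurat object via `do.call()`. Recording `sessionInfo()` alongside it documents the package versions as well.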
sessionInfo()
#> R version 4.6.0 (2026-04-24)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 24.04.4 LTS
#>
#> Matrix products: default
#> BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
#> LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
#>
#> locale:
#> [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
#> [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
#> [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
#> [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
#> [9] LC_ADDRESS=C LC_TELEPHONE=C
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
#>
#> time zone: Etc/UTC
#> tzcode source: system (glibc)
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] rmarkdown_2.31
#>
#> loaded via a namespace (and not attached):
#> [1] digest_0.6.39 R6_2.6.1 fastmap_1.2.0 xfun_0.57
#> [5] maketools_1.3.2 cachem_1.1.0 knitr_1.51 htmltools_0.5.9
#> [9] buildtools_1.0.0 lifecycle_1.0.5 cli_3.6.6 sass_0.4.10
#> [13] jquerylib_0.1.4 compiler_4.6.0 sys_3.4.3 tools_4.6.0
#> [17] evaluate_1.0.5 bslib_0.11.0 yaml_2.3.12 otel_0.2.0
#> [21] jsonlite_2.0.0 rlang_1.2.0