---
title: "Advanced Usage and Best Practices"
author: "Zaoqu Liu"
date: "`r Sys.Date()`"
output:
  rmarkdown::html_vignette:
    toc: true
    toc_depth: 3
    fig_width: 7
    fig_height: 5
vignette: >
  %\VignetteIndexEntry{Advanced Usage and Best Practices}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.align = "center",
  message = FALSE,
  warning = FALSE,
  eval = FALSE
)
```

## Introduction

This vignette covers advanced features and best practices for using CellODE effectively in your single-cell analysis workflows. We'll discuss:

- Model hyperparameter tuning
- Handling complex trajectories
- Prediction on query datasets
- Integration with Seurat workflows
- Performance optimization

## Hyperparameter Tuning

### Latent Space Dimensionality

The `n_latent` parameter controls the dimensionality of the learned latent space:

```{r latent-dim}
# For simple linear trajectories
trainer <- Trainer$new(seurat_obj, n_latent = 3)

# For complex branching trajectories
trainer <- Trainer$new(seurat_obj, n_latent = 10)

# For very complex data with multiple lineages
trainer <- Trainer$new(seurat_obj, n_latent = 20)
```

**Guidelines:**

| Data Complexity | Recommended `n_latent` |
|-----------------|------------------------|
| Simple linear | 3-5 |
| Single branch | 5-10 |
| Multiple branches | 10-20 |
| Very complex | 20-50 |

### ODE Network Architecture

```{r ode-arch}
# Default: lightweight
trainer <- Trainer$new(
  seurat_obj,
  n_ode_hidden = 25  # Hidden layer size in ODE function
)

# For more complex dynamics
trainer <- Trainer$new(
  seurat_obj,
  n_ode_hidden = 100
)
```

### VAE Network Architecture

```{r vae-arch}
# Default architecture
trainer <- Trainer$new(
  seurat_obj,
  n_vae_hidden = 128
)

# For large gene sets or complex data
trainer <- Trainer$new(
  seurat_obj,
  n_vae_hidden = 256,
  batch_norm = TRUE  # Can help with large networks
)
```

### Loss Function Selection

```{r loss-mode}
# Negative Binomial (default) - Best for UMI counts
trainer <- Trainer$new(seurat_obj, loss_mode = "nb")

# Zero-Inflated NB - For data with many zeros (e.g., SMART-seq)
trainer <- Trainer$new(seurat_obj, loss_mode = "zinb")

# MSE - For already normalized data
trainer <- Trainer$new(seurat_obj, loss_mode = "mse", slot = "data")
```

### Loss Weight Balancing

The loss function combines multiple terms with adjustable weights:

```{r loss-weights}
trainer <- Trainer$new(
  seurat_obj,
  alpha_recon_lec = 0.5,   # Encoder reconstruction weight
  alpha_recon_lode = 0.5,  # ODE reconstruction weight
  alpha_kl = 1.0           # KL divergence weight
)

# If latent space is too smooth, increase KL weight
trainer <- Trainer$new(seurat_obj, alpha_kl = 2.0)

# If reconstruction is poor, decrease KL weight
trainer <- Trainer$new(seurat_obj, alpha_kl = 0.5)
```

## Training Configuration

### Learning Rate Scheduling

```{r training}
# Standard training
trainer <- Trainer$new(
  seurat_obj,
  lr = 1e-3,        # Learning rate
  wt_decay = 1e-6,  # L2 regularization
  eps = 0.01        # Adam optimizer epsilon
)

# For unstable training, use smaller learning rate
trainer <- Trainer$new(seurat_obj, lr = 1e-4)

# For faster convergence on simple data
trainer <- Trainer$new(seurat_obj, lr = 5e-3)
```

### Batch Size and Epochs

```{r batch-epoch}
# Auto-determined (recommended)
trainer <- Trainer$new(seurat_obj)

# Manual specification
trainer <- Trainer$new(
  seurat_obj,
  batch_size = 512,  # Smaller batch = more noise, larger = more memory
  nepoch = 200       # More epochs for complex data
)
```

### Data Subsampling

For large datasets, train on a subset:

```{r percent}
# Auto-determined based on dataset size
# > 10,000 cells: 20%
# <= 10,000 cells: 90%

# Manual override
trainer <- Trainer$new(seurat_obj, percent = 0.3)  # Use 30% of cells
```

## Handling Complex Trajectories

### Branching Trajectories

For data with multiple differentiation branches:

```{r branching}
trainer <- Trainer$new(
  seurat_obj,
  n_latent = 15,      # More dimensions to capture branches
  n_ode_hidden = 50,  # Larger ODE network
  nepoch = 300        # More training epochs
)

# After training, use step-wise integration for better accuracy
latent <- trainer$get_latentsp(step_wise = TRUE)
```

### Cyclic Trajectories

For cell cycle data:

```{r cyclic}
# Cyclic data may have pseudotime wrapping issues
# Consider using larger latent space
trainer <- Trainer$new(seurat_obj, n_latent = 10)

# Manual time reversal may be needed
pseudotime <- trainer$get_time()

# If time direction seems wrong:
pseudotime <- reverse_time(pseudotime)
```

## Prediction on Query Data

### Coarse vs Fine Mode

```{r prediction-modes}
# Coarse mode: Fast, independent of training data
query_latent <- predict_latentsp(trainer, query_seurat, mode = "coarse")

# Fine mode: More accurate, uses training data as reference
query_latent <- predict_latentsp(trainer, query_seurat, mode = "fine")
```

### Predicting Future States

Interpolate latent space at unobserved time points:

```{r future-states}
# Define query time points
query_times <- seq(0.5, 1.0, by = 0.1)

# Predict latent representations
future_latent <- predict_ltsp_from_time(
  trainer,
  t = query_times,
  k = 20,           # Number of neighbors for interpolation
  step_wise = TRUE  # More accurate integration
)
```

## Integration with Seurat Workflows

### Adding Results to Seurat Object

```{r seurat-integration}
# Pseudotime
seurat_obj$cellode_time <- trainer$get_time()

# Latent space as dimensional reduction
latent <- trainer$get_latentsp()
seurat_obj[["cellode"]] <- Seurat::CreateDimReducObject(
  embeddings = latent$mix_zs,
  key = "CELLODE_",
  assay = "RNA"
)

# Vector field
vf <- trainer$get_vector_field(seurat_obj$cellode_time, latent$mix_zs)
seurat_obj@misc$X_VF <- vf
seurat_obj@misc$X_zs <- latent$mix_zs
```

### Downstream Analysis

```{r downstream}
library(Seurat)

# Find genes correlated with pseudotime: Pearson correlation between each
# gene's normalized expression and the inferred pseudotime
gene_time_cor <- cor(
  as.matrix(t(seurat_obj[["RNA"]]@data)),
  seurat_obj$cellode_time
)

# Identify trajectory-associated genes
top_genes <- names(sort(abs(gene_time_cor[, 1]), decreasing = TRUE)[1:100])
```

## Performance Optimization

### GPU Acceleration

```{r gpu}
# Auto-detect (default)
trainer <- Trainer$new(seurat_obj, use_gpu = TRUE)

# Force CPU (for debugging or memory issues)
trainer <- Trainer$new(seurat_obj, use_gpu = FALSE)
```

### Memory Management

```{r memory}
# For large datasets, use smaller batch size
trainer <- Trainer$new(
  seurat_obj,
  batch_size = 256,  # Reduce memory usage
  percent = 0.2      # Train on subset
)

# Clear GPU memory after training
gc()
torch::cuda_empty_cache()  # If using CUDA
```

### Batched Inference

```{r batched-inference}
# For large datasets, use batched latent space computation
latent <- trainer$get_latentsp(batch_size = 1000)
```

## Model Persistence

### Saving and Loading

```{r save-load}
# Save model
trainer$save_model("path/to/model")

# Load model for prediction
loaded_trainer <- load_model("path/to/model", seurat_obj)

# Continue training (if needed)
loaded_trainer$train()
```

### Model Inspection

```{r inspect}
# View model architecture
print(trainer$model)

# Check training history
plot_training_history(trainer)

# Access model parameters
trainer$model_kwargs
```

## Troubleshooting

### Training Issues

| Problem | Solution |
|---------|----------|
| Loss not decreasing | Reduce learning rate |
| Loss oscillating | Increase batch size |
| NaN/Inf loss | Check input data normalization |
| Slow training | Use GPU, reduce `n_latent` |

### Quality Issues

| Problem | Solution |
|---------|----------|
| Poor pseudotime ordering | Check time direction, increase epochs |
| Noisy vector field | Increase `n_neigh` in `cosine_similarity` |
| Discontinuous trajectory | Increase `n_latent` |

### Memory Issues

| Problem | Solution |
|---------|----------|
| GPU out of memory | Reduce `batch_size` |
| CPU out of memory | Reduce `percent`, use sparse matrices |

## Reproducibility

```{r reproducibility}
# Set seeds for reproducibility
trainer <- Trainer$new(
  seurat_obj,
  random_state = 42  # Fixed seed
)

# Full reproducibility requires:
# 1. Same random_state
# 2. Same data ordering
# 3. Same hardware (GPU results may vary)
```

## Best Practices Summary

1. **Start simple**: Begin with default parameters, adjust as needed
2. **Monitor training**: Use `plot_loss()` to check convergence
3. **Validate results**: Compare with known markers or annotations
4. **Document settings**: Save model parameters for reproducibility
5. **Use appropriate loss**: NB for UMI counts, MSE for normalized data
6. **Consider data size**: Use `percent` for large datasets, more epochs for small

## Session Info

```{r session, eval=TRUE}
sessionInfo()
```
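
## Appendix: Pseudotime Correlation on Simulated Data

The gene-by-pseudotime correlation from the Downstream Analysis section can be tried without a trained model. The sketch below is a minimal illustration in base R only: the pseudotime vector and the three gene names are simulated stand-ins (assumptions, not CellODE or Seurat output), and Spearman correlation is swapped in for the Pearson default because it also picks up monotonic but non-linear trends.

```r
# Simulated stand-in data; nothing here calls CellODE or Seurat.
set.seed(42)
n_cells <- 200
pseudotime <- sort(runif(n_cells))  # stand-in for trainer$get_time()

# Cells x genes matrix with hypothetical gene names.
expr <- cbind(
  gene_up    = 2 * pseudotime + rnorm(n_cells, sd = 0.2),   # rises along the trajectory
  gene_down  = -2 * pseudotime + rnorm(n_cells, sd = 0.2),  # falls along the trajectory
  gene_noise = rnorm(n_cells)                               # unrelated to pseudotime
)

# Rank-based correlation of every gene (column) with pseudotime.
# The result is a genes x 1 matrix whose row names are the gene names.
gene_time_cor <- cor(expr, pseudotime, method = "spearman")

# Rank genes by absolute correlation and cap the selection at the number
# of genes available, so the result never contains NA names.
ranked <- sort(abs(gene_time_cor[, 1]), decreasing = TRUE)
n_top <- min(2, length(ranked))
top_genes <- names(ranked)[seq_len(n_top)]
print(top_genes)
```

With real data, `expr` would be the transposed normalized expression matrix and `pseudotime` the values stored in `seurat_obj$cellode_time`. The `min()` guard matters when fewer genes are available than the requested number of top genes, since indexing past the end of a named vector returns `NA` names.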