title: EDM Training for Pomodoro - Session Notes
tags: [journal, edm, pomodoro, training, implementation, session-log]
created: 2024-04-23
updated: 2024-04-23
status: active
related:


EDM Training for Pomodoro - Session Notes

Summary

Implemented the EDM training protocol (named after the paper "Elucidating the Design Space of Diffusion-Based Generative Models") for the Pomodoro protein structure model, replacing the previous simple-noise autoregressive approach. Also added gradient checkpointing to reduce activation memory.

Key Files

Model (modified)

  • /home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/pomodoro/src/pomodoro/model.py
    • Added ckpt() helper function (L8-11) for gradient checkpointing
    • Added use_checkpoint: bool = False to Model.Config (L593)
    • Stores self.use_checkpoint in __init__ (L614)
    • All 9 GeometricTransformer calls per layer now use ckpt() (L651-672)

Training notebook (primary work file)

  • /home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/workspace/pomodoro/edm/debug_training.py
    • Contains all EDM functions, training step, validation, and inference

Supporting modules (read-only, referenced)

  • /home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/pomodoro/src/pomodoro/context.py: mapping_and_connectivity_levels, reduce_coordinates, broadcast_coordinates, coordinates_levels, center_coordinates
  • /home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/pomodoro/src/pomodoro/sampler.py — Original multiscale sampler (reference for noise distributions)
  • /home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/pomodoro/src/pomodoro/objectives.py: predicted_displacement (not used in EDM, but kept as a reference)
  • /home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/workspace/pomodoro/edm/dashboard.py — DashMetrics dashboard for training monitoring

Reference (Boltz, read-only)

  • /home/khaos/syncfolder/lemna/lemna-dev/boltz/src/boltz/model/modules/diffusionv2.py — Original EDM implementation studied for protocol design
  • /home/khaos/syncfolder/lemna/lemna-dev/boltz/src/boltz/model/loss/diffusionv2.py — Original EDM loss functions

Notes

What Was Implemented

1. EDM Preconditioning (debug_training.py L171-198)

  • edm_defaults() — hyperparameters: sigma_data=16, P_mean=-1.2, P_std=1.5, sigma_min=1e-3, sigma_max=160, rho=7
  • noise_distribution() — log-normal σ sampling
  • c_skip(), c_out(), c_in(), c_noise(), loss_weight() — EDM preconditioning functions
  • All take SimpleNamespace hp parameter (not self)
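
For reference, the preconditioning functions listed above can be sketched as follows. This is a minimal sketch that assumes the standard formulas from the EDM paper and a plain log-normal form for noise_distribution(); the definitions in debug_training.py are not reproduced in these notes and may differ (for example in how sigma_data enters c_noise or the sampled σ).

```python
import math
import random
from types import SimpleNamespace


def edm_defaults():
    # Hyperparameter values copied from the notes above.
    return SimpleNamespace(sigma_data=16.0, P_mean=-1.2, P_std=1.5,
                           sigma_min=1e-3, sigma_max=160.0, rho=7.0)


def noise_distribution(hp, rng=random):
    # Assumed form: ln(sigma) ~ Normal(P_mean, P_std^2).
    return math.exp(hp.P_mean + hp.P_std * rng.gauss(0.0, 1.0))


def c_skip(sigma, hp):
    # Weight on the noisy input in the denoised estimate.
    return hp.sigma_data ** 2 / (sigma ** 2 + hp.sigma_data ** 2)


def c_out(sigma, hp):
    # Weight on the raw network output in the denoised estimate.
    return sigma * hp.sigma_data / math.sqrt(sigma ** 2 + hp.sigma_data ** 2)


def c_in(sigma, hp):
    # Scales noisy coordinates to roughly unit variance before the network.
    return 1.0 / math.sqrt(sigma ** 2 + hp.sigma_data ** 2)


def c_noise(sigma, hp):
    # Scalar noise-level feature fed to the network (hp unused in this form).
    return 0.25 * math.log(sigma)


def loss_weight(sigma, hp):
    # Chosen so that loss_weight * c_out^2 == 1 at every noise level.
    return (sigma ** 2 + hp.sigma_data ** 2) / (sigma * hp.sigma_data) ** 2
```

With these forms, c_skip(σ) is 0.5 at σ = sigma_data, tends to 1 as σ → 0 (the network output is mostly ignored), and tends to 0 at large σ (the network predicts the structure directly).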

2. Noise Level Conditioning (debug_training.py L224-239)

  • c_noise(σ) is concatenated to qe as a 31st feature (was F=30, now F=31)
  • Model config updated: StateEmbeddingModel.Config(F=31, ...)
  • edm_denoise() concatenates c_noise(σ).expand(N,1) to qe, passes c_in(σ) * X_noised as coordinates
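
The input assembly described above can be illustrated with NumPy. This is a sketch only: build_denoiser_inputs is a hypothetical name, the real code operates on torch tensors inside edm_denoise(), and the c_in and c_noise forms are assumed to be the standard EDM ones.

```python
import numpy as np


def build_denoiser_inputs(qe, X_noised, sigma, sigma_data=16.0):
    """Return ([N, F+1] features, [N, 3] scaled coordinates) for one noise level."""
    n_atoms = qe.shape[0]
    c_noise = 0.25 * np.log(sigma)                      # assumed EDM form
    c_in = 1.0 / np.sqrt(sigma ** 2 + sigma_data ** 2)  # assumed EDM form
    # Same scalar for every atom, appended as the last feature column.
    noise_col = np.full((n_atoms, 1), c_noise, dtype=qe.dtype)
    features = np.concatenate([qe, noise_col], axis=1)
    return features, c_in * X_noised
```

For qe of shape [N, 30] this yields features of shape [N, 31], matching F=31 in the model config.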

3. Multiscale Noise (debug_training.py L201-221)

  • edm_multiscale_noise(X0, Mr, Mc, C, σ) — hierarchical noising through atom/residue/chain levels
  • Noise is decomposed: σ² = σ_a² + σ_r² + σ_c² with weights (1.3, 1.6, 8.2) from original sampler
  • Chain noise → broadcast to residues → broadcast to atoms, plus per-level noise
  • Used in both training (step) and inference (edm_sample initial state)
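
One way to realize the variance split above is sketched below with NumPy. Several points are assumptions rather than facts from the notes: the weights are treated as relative standard deviations and normalized so that σ_a² + σ_r² + σ_c² = σ², Mr and Mc are one-hot atom-to-group matrices, and the connectivity argument C is omitted. The function names are hypothetical.

```python
import numpy as np


def split_sigma(sigma, weights=(1.3, 1.6, 8.2)):
    """Split sigma into (atom, residue, chain) parts whose squares sum to sigma^2."""
    w = np.asarray(weights, dtype=float)
    return sigma * w / np.sqrt(np.sum(w ** 2))


def multiscale_noise(X0, Mr, Mc, sigma, rng, weights=(1.3, 1.6, 8.2)):
    """Add atom-level noise plus noise shared within each residue and chain."""
    s_a, s_r, s_c = split_sigma(sigma, weights)
    eps_a = rng.standard_normal(X0.shape)            # one draw per atom
    eps_r = rng.standard_normal((Mr.shape[1], 3))    # one draw per residue
    eps_c = rng.standard_normal((Mc.shape[1], 3))    # one draw per chain
    # Mr @ eps_r copies each residue's draw to its atoms; likewise for chains.
    return X0 + s_a * eps_a + s_r * (Mr @ eps_r) + s_c * (Mc @ eps_c)
```

Under these assumptions each atom still receives total noise variance σ² per coordinate, but atoms in the same residue or chain move together, so most of the perturbation is a rigid shift of whole chains.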

4. Multiscale Denoising & Loss (debug_training.py L224-288)

  • edm_denoise() returns (Xa_hat, Xr_hat, Xc_hat) — all three denoised levels
  • Residue/chain noised coords computed via reduce_coordinates(X_noised, Mar/Mrc)
  • Preconditioning applied at each level: c_skip(σ) * X_noised_level + c_out(σ) * model_output_level
  • step() computes the MSE loss at the atom, residue, and chain levels after rigid alignment, plus a push-pull loss
  • Metrics logged: loss, la, lr, lc, rmsd, lddt
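
The notes do not say which rigid alignment is used. As an illustration, a Kabsch-style alignment followed by RMSD can be sketched in NumPy as below; kabsch_align and rmsd are hypothetical names, and the actual step() works on torch tensors and may align differently.

```python
import numpy as np


def kabsch_align(P, Q):
    """Rigidly superimpose P onto Q (both [N, 3]) and return the moved P."""
    P_mean, Q_mean = P.mean(axis=0), Q.mean(axis=0)
    Pc, Qc = P - P_mean, Q - Q_mean
    H = Pc.T @ Qc                        # 3x3 covariance between the point sets
    U, _, Vt = np.linalg.svd(H)
    # Flip the last axis if needed so the result is a rotation, not a reflection.
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    return Pc @ R.T + Q_mean


def rmsd(A, B):
    """Root-mean-square deviation between matched [N, 3] coordinate sets."""
    return float(np.sqrt(np.mean(np.sum((A - B) ** 2, axis=1))))
```

Aligning the prediction to the ground truth before computing the per-level MSE makes the loss insensitive to global rotation and translation.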

5. EDM Sampling (debug_training.py L407-430)

  • edm_sample() — iterative Euler sampling from pure noise
  • Initial state: edm_multiscale_noise(zeros, Mr, Mc, C, σ_max) for multiscale initial noise
  • Optional stochastic sampler via gamma (re-noising) and noise_scale parameters
  • Returns denoised trajectory for visualization
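
The deterministic part of the sampler can be sketched as a plain Euler loop over the standard EDM noise schedule. This is an assumption-laden sketch, not the code in edm_sample(): karras_sigmas and euler_sample are hypothetical names, denoise stands in for the preconditioned model call, and the gamma and noise_scale re-noising is left out.

```python
def karras_sigmas(n_steps, sigma_min=1e-3, sigma_max=160.0, rho=7.0):
    """Noise levels decreasing from sigma_max to sigma_min, then a final 0."""
    hi, lo = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    sigmas = [(hi + i / (n_steps - 1) * (lo - hi)) ** rho for i in range(n_steps)]
    return sigmas + [0.0]


def euler_sample(denoise, x_init, sigmas):
    """Integrate dx/dsigma = (x - D(x, sigma)) / sigma with Euler steps.

    Returns the final sample and the list of denoised estimates per step.
    """
    x, trajectory = x_init, []
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        x_hat = denoise(x, sigma)            # denoised estimate at this level
        d = (x - x_hat) / sigma              # direction toward lower noise
        x = x + (sigma_next - sigma) * d     # Euler step to the next level
        trajectory.append(x_hat)
    return x, trajectory
```

Because the last step goes to σ = 0, the final sample equals the last denoised estimate; x_init would be the multiscale noise drawn at σ_max.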

6. Validation (debug_training.py L318-340)

  • validate() — runs 20-step EDM sampling, computes RMSD and lDDT against GT
  • Logs val_rmsd and val_lddt to dashboard
  • Returns trajectory + GT coords for PDB saving
  • Called every 100 training steps
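
For orientation, a simplified lDDT can be sketched in NumPy as below. It is an assumption that validate() uses something similar: this version scores all atom pairs within a 15 Å reference cutoff with the usual 0.5/1/2/4 Å thresholds and does not exclude pairs within the same residue, so its values may not match the lddt metric logged here.

```python
import numpy as np


def lddt(X_pred, X_true, cutoff=15.0, thresholds=(0.5, 1.0, 2.0, 4.0)):
    """Fraction of preserved local distances, averaged over the thresholds."""
    d_true = np.linalg.norm(X_true[:, None, :] - X_true[None, :, :], axis=-1)
    d_pred = np.linalg.norm(X_pred[:, None, :] - X_pred[None, :, :], axis=-1)
    # Score pairs that are close in the reference, excluding self-pairs.
    mask = (d_true < cutoff) & ~np.eye(len(X_true), dtype=bool)
    if not mask.any():
        return 0.0
    diff = np.abs(d_true - d_pred)[mask]
    return float(np.mean([(diff < t).mean() for t in thresholds]))
```

Unlike RMSD, this score depends only on interatomic distances, so it needs no alignment and is less dominated by a single misplaced chain.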

7. Gradient Checkpointing (model.py L8-11, L593, L614, L651-672)

  • ckpt(use_ckpt, fn, *args) — wrapper for torch.utils.checkpoint.checkpoint
  • When use_checkpoint=True, each GeometricTransformer forward call is checkpointed
  • Reduces activation memory ~36x (9 GT calls × 4 layers) at ~2x compute cost
  • Default False; enabled in debug_training.py config
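
A helper with the signature above could look like the sketch below. The body is an assumption, since model.py is not reproduced here; use_reentrant=False is the setting PyTorch recommends for new code and may not be what the real helper passes.

```python
import torch
from torch.utils.checkpoint import checkpoint


def ckpt(use_ckpt, fn, *args):
    """Call fn(*args), optionally discarding its activations until backward."""
    if use_ckpt:
        # Activations inside fn are recomputed during the backward pass.
        return checkpoint(fn, *args, use_reentrant=False)
    return fn(*args)
```

Inside the model, a call such as `ckpt(self.use_checkpoint, layer, x)` would then replace `layer(x)`; the outputs and gradients are the same, only the memory/compute trade-off changes.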

Architecture Notes

Model inputs/outputs

  • model.forward(qa, Xa, C, Mr, Mc) → (Xa, Xr, Xc)
  • qa: [N, 31] (was 30, +1 for c_noise conditioning)
  • Xa: [N, 3] atom coords (preconditioned: c_in(σ) * X_noised)
  • Returns 3 levels: atom/residue/chain decoded coords (raw, before preconditioning wrap)

Data shapes (no batch dimension)

  • X: [N, 3] — N atoms, 3D coordinates
  • qe: [N, 30] — atom element one-hot encoding
  • Mr: [N, Nr] — atom-to-residue mapping
  • Mc: [N, Nc] — atom-to-chain mapping
  • C: [N, N] — connectivity matrix
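
To make the mapping shapes concrete, the sketch below builds a one-hot [N, G] mapping and uses it to pool and broadcast coordinates. The helper names are hypothetical, and mean pooling is an assumption about what reduce_coordinates does in context.py.

```python
import numpy as np


def one_hot_mapping(group_index, n_groups):
    """[N] integer group ids -> [N, G] one-hot mapping matrix (like Mr or Mc)."""
    M = np.zeros((len(group_index), n_groups))
    M[np.arange(len(group_index)), group_index] = 1.0
    return M


def reduce_mean(X, M):
    """[N, 3] atom coords -> [G, 3] group centroids (assumed reduction)."""
    return (M.T @ X) / M.sum(axis=0)[:, None]


def broadcast(Xg, M):
    """[G, 3] group coords -> [N, 3], copying each group's row to its atoms."""
    return M @ Xg
```

For three atoms with residue ids [0, 0, 1], Mr has shape [3, 2]; reduce_mean returns the centroid of the first two atoms and the position of the third.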

Current config

  • S=64, K=8, L=8, r=1e-3, Hs=64
  • atom: F=31 (30 elements + 1 noise level)
  • F=31 for stm_a, F=1 for stm_r and stm_c

Potential Next Steps

  • Tune EDM hyperparameters (sigma_data, P_mean, P_std, rho)
  • Add per-noise-level loss tracking (bin by σ ranges)
  • Add EMA smoothing on validation metrics
  • Monitor σ distribution during training
  • Tune multiscale noise weights (currently 1.3/1.6/8.2 from old sampler)
  • Add gradient norm monitoring (preconditioning can produce large gradients at low σ)