title: EDM Training for Pomodoro - Session Notes
tags: [journal, edm, pomodoro, training, implementation, session-log]
created: 2024-04-23
updated: 2024-04-23
status: active
related:
- "diffusionv2-algorithm"
- "edm-algorithm-review"
- "edm-pomodoro-algorithm-review"
- "edm-noise-exploration"
EDM Training for Pomodoro - Session Notes
Summary
Implemented EDM (Elucidating the Design Space of Diffusion-Based Generative Models) training protocol for the Pomodoro protein structure model, replacing the previous simple-noise autoregressive approach. Also added gradient checkpointing for memory efficiency.
Key Files
Model (modified)
`/home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/pomodoro/src/pomodoro/model.py`
- Added `ckpt()` helper function (L8-11) for gradient checkpointing
- Added `use_checkpoint: bool = False` to `Model.Config` (L593)
- Stores `self.use_checkpoint` in `__init__` (L614)
- All 9 GeometricTransformer calls per layer now use `ckpt()` (L651-672)
Training notebook (primary work file)
`/home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/workspace/pomodoro/edm/debug_training.py`
- Contains all EDM functions, training step, validation, and inference
Supporting modules (read-only, referenced)
- `/home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/pomodoro/src/pomodoro/context.py`: `mapping_and_connectivity_levels`, `reduce_coordinates`, `broadcast_coordinates`, `coordinates_levels`, `center_coordinates`
- `/home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/pomodoro/src/pomodoro/sampler.py`: Original multiscale sampler (reference for noise distributions)
- `/home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/pomodoro/src/pomodoro/objectives.py`: `predicted_displacement` (not used in EDM, but kept as a reference)
- `/home/khaos/syncfolder/lemna/lemna-dev/models/pomodoro/workspace/pomodoro/edm/dashboard.py`: DashMetrics dashboard for training monitoring
Reference (Boltz, read-only)
- `/home/khaos/syncfolder/lemna/lemna-dev/boltz/src/boltz/model/modules/diffusionv2.py`: Original EDM implementation studied for protocol design
- `/home/khaos/syncfolder/lemna/lemna-dev/boltz/src/boltz/model/loss/diffusionv2.py`: Original EDM loss functions
Notes
- diffusionv2-algorithm — Algorithm notes from studying Boltz EDM
What Was Implemented
1. EDM Preconditioning (debug_training.py L171-198)
- `edm_defaults()`: hyperparameters `sigma_data=16`, `P_mean=-1.2`, `P_std=1.5`, `sigma_min=1e-3`, `sigma_max=160`, `rho=7`
- `noise_distribution()`: log-normal σ sampling
- `c_skip()`, `c_out()`, `c_in()`, `c_noise()`, `loss_weight()`: EDM preconditioning functions
- All take a `SimpleNamespace` `hp` parameter (not `self`)
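For reference, the functions listed above can be sketched with the standard EDM formulas. This is a minimal scalar version, not the code in debug_training.py: the argument order `(sigma, hp)`, the scaling of the sampled σ by `sigma_data`, and the `σ / sigma_data` form of `c_noise` are assumptions (the original EDM paper uses ¼·ln σ).

```python
import math
import random
from types import SimpleNamespace


def edm_defaults():
    # Hyperparameter values taken from the notes above.
    return SimpleNamespace(sigma_data=16.0, P_mean=-1.2, P_std=1.5,
                           sigma_min=1e-3, sigma_max=160.0, rho=7.0)


def noise_distribution(hp, rng=random):
    # Log-normal sampling: ln(sigma / sigma_data) ~ N(P_mean, P_std^2).
    # Multiplying by sigma_data is an assumption, not confirmed by the notes.
    return hp.sigma_data * math.exp(hp.P_mean + hp.P_std * rng.gauss(0.0, 1.0))


def c_skip(sigma, hp):
    return hp.sigma_data ** 2 / (sigma ** 2 + hp.sigma_data ** 2)


def c_out(sigma, hp):
    return sigma * hp.sigma_data / math.sqrt(sigma ** 2 + hp.sigma_data ** 2)


def c_in(sigma, hp):
    return 1.0 / math.sqrt(sigma ** 2 + hp.sigma_data ** 2)


def c_noise(sigma, hp):
    # Assumed variant; plain EDM would use 0.25 * math.log(sigma).
    return 0.25 * math.log(sigma / hp.sigma_data)


def loss_weight(sigma, hp):
    return (sigma ** 2 + hp.sigma_data ** 2) / (sigma * hp.sigma_data) ** 2
```

A useful sanity check is that `loss_weight(σ) * c_out(σ)**2 == 1` for any σ, which is the property that keeps the effective loss magnitude uniform across noise levels.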
2. Noise Level Conditioning (debug_training.py L224-239)
- `c_noise(σ)` is concatenated to `qe` as a 31st feature (was F=30, now F=31)
- Model config updated: `StateEmbeddingModel.Config(F=31, ...)`
- `edm_denoise()` concatenates `c_noise(σ).expand(N,1)` to `qe`, passes `c_in(σ) * X_noised` as coordinates
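The feature concatenation described above amounts to appending one constant column to the one-hot element matrix. A small NumPy illustration follows; the helper name `append_noise_feature` is hypothetical, and the real code operates on torch tensors via `.expand(N, 1)`.

```python
import math

import numpy as np


def append_noise_feature(qe, c_noise_value):
    """Append the scalar noise conditioning as an extra per-atom feature column."""
    n_atoms = qe.shape[0]
    col = np.full((n_atoms, 1), c_noise_value, dtype=qe.dtype)
    return np.concatenate([qe, col], axis=-1)  # [N, 30] -> [N, 31]


# Four atoms with one-hot element encodings over 30 element types.
qe = np.eye(30, dtype=np.float32)[[5, 6, 7, 5]]
qa = append_noise_feature(qe, 0.25 * math.log(2.0))
```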
3. Multiscale Noise (debug_training.py L201-221)
- `edm_multiscale_noise(X0, Mr, Mc, C, σ)`: hierarchical noising through atom/residue/chain levels
- Noise is decomposed: `σ² = σ_a² + σ_r² + σ_c²` with weights (1.3, 1.6, 8.2) from original sampler
- Chain noise → broadcast to residues → broadcast to atoms, plus per-level noise
- Used in both training (`step`) and inference (`edm_sample` initial state)
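One way to realize the variance decomposition above is sketched below. It assumes the weights (1.3, 1.6, 8.2) are relative standard deviations that get normalized so the three variances sum to σ², and that `Mr` and `Mc` are one-hot atom-to-residue and atom-to-chain matrices. The connectivity argument `C` is omitted, and the function names are hypothetical, so this is not the actual `edm_multiscale_noise`.

```python
import numpy as np


def split_sigma(sigma, weights=(1.3, 1.6, 8.2)):
    """Split sigma into (sigma_a, sigma_r, sigma_c) with sigma_a^2 + sigma_r^2 + sigma_c^2 = sigma^2."""
    w = np.asarray(weights, dtype=np.float64)
    return sigma * w / np.sqrt(np.sum(w ** 2))


def multiscale_noise_sketch(X0, Mr, Mc, sigma, rng):
    """Add independent chain, residue, and atom noise to atom coordinates X0 [N, 3]."""
    sa, sr, sc = split_sigma(sigma)
    n_atoms, n_res, n_chain = X0.shape[0], Mr.shape[1], Mc.shape[1]
    eps_c = sc * rng.standard_normal((n_chain, 3))  # one offset per chain
    eps_r = sr * rng.standard_normal((n_res, 3))    # one offset per residue
    eps_a = sa * rng.standard_normal((n_atoms, 3))  # one offset per atom
    # One-hot matrices broadcast the coarse offsets down to every atom.
    return X0 + Mc @ eps_c + Mr @ eps_r + eps_a
```

Because the three noise sources are independent, each atom coordinate still has total standard deviation σ, while atoms in the same residue or chain move together.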
4. Multiscale Denoising & Loss (debug_training.py L224-288)
- `edm_denoise()` returns `(Xa_hat, Xr_hat, Xc_hat)`: all three denoised levels
- Residue/chain noised coords computed via `reduce_coordinates(X_noised, Mar/Mrc)`
- Preconditioning applied at each level: `c_skip(σ) * X_noised_level + c_out(σ) * model_output_level`
- `step()` computes MSE loss at atom, residue, and chain levels with rigid alignment + push-pull loss
- Metrics logged: loss, la, lr, lc, rmsd, lddt
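The per-level wrap can be illustrated as follows. Here `reduce_mean` is a hypothetical stand-in for `reduce_coordinates` that assumes mean pooling over one-hot, non-empty groups; the real function may weight or compose mappings differently. Rigid alignment and the push-pull loss are left out.

```python
import numpy as np


def reduce_mean(X, M):
    """Mean-pool coordinates X [N, 3] into groups given a one-hot mapping M [N, G]."""
    counts = M.sum(axis=0)[:, None]  # atoms per group, assumed non-zero
    return (M.T @ X) / counts


def denoise_levels(X_noised, model_outs, Mr, Mc, sigma, sigma_data=16.0):
    """Apply c_skip * noised + c_out * output at atom, residue, and chain level."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / np.sqrt(denom)
    noised = (X_noised, reduce_mean(X_noised, Mr), reduce_mean(X_noised, Mc))
    return tuple(c_skip * xn + c_out * f for xn, f in zip(noised, model_outs))
```

At σ = 0 the wrap returns the noised input unchanged (`c_skip = 1`, `c_out = 0`), and as σ grows the network output dominates.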
5. EDM Sampling (debug_training.py L407-430)
- `edm_sample()`: iterative Euler sampling from pure noise
- Initial state: `edm_multiscale_noise(zeros, Mr, Mc, C, σ_max)` for multiscale initial noise
- Optional stochastic sampler via `gamma` (re-noising) and `noise_scale` parameters
- Returns denoised trajectory for visualization
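A deterministic version of the sampling loop can be sketched as below, using the ρ-spaced σ schedule from the EDM paper and a plain Euler step. Whether `edm_sample()` uses exactly this schedule is an assumption; the stochastic `gamma` / `noise_scale` branch is not shown, and `denoise` is a placeholder callable standing in for `edm_denoise()`.

```python
def karras_sigmas(n_steps, sigma_min=1e-3, sigma_max=160.0, rho=7.0):
    """Return n_steps noise levels from sigma_max down to sigma_min, followed by 0."""
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    ramp = [i / (n_steps - 1) for i in range(n_steps)]
    return [(hi + t * (lo - hi)) ** rho for t in ramp] + [0.0]


def euler_sample(x, denoise, sigmas):
    """Integrate dx/dsigma = (x - D(x, sigma)) / sigma with Euler steps; return the trajectory."""
    trajectory = [x]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma   # derivative of the probability-flow ODE
        x = x + (sigma_next - sigma) * d      # step toward the next (lower) noise level
        trajectory.append(x)
    return trajectory


# Toy usage on a scalar: a "denoiser" that always predicts 5.0.
sigmas = karras_sigmas(20)
traj = euler_sample(100.0, lambda x, sigma: 5.0, sigmas)
```

The final step goes to σ = 0, so the last trajectory entry equals the denoiser's prediction at the smallest noise level.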
6. Validation (debug_training.py L318-340)
- `validate()`: runs 20-step EDM sampling, computes RMSD and lDDT against GT
- Logs `val_rmsd` and `val_lddt` to dashboard
- Returns trajectory + GT coords for PDB saving
- Called every 100 training steps
7. Gradient Checkpointing (model.py L8-11, L593, L614, L651-672)
- `ckpt(use_ckpt, fn, *args)`: wrapper for `torch.utils.checkpoint.checkpoint`
- When `use_checkpoint=True`, each GeometricTransformer forward call is checkpointed
- Reduces activation memory ~36x (9 GT calls × 4 layers) at ~2x compute cost
- Default `False`; enabled in debug_training.py config
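A helper with the signature above might look like the following sketch. The body is a guess at what model.py L8-11 does, and passing `use_reentrant=False` is a choice made here, not something the notes confirm. The `Sequential` layer is only a stand-in for a GeometricTransformer call.

```python
import torch
from torch.utils.checkpoint import checkpoint


def ckpt(use_ckpt, fn, *args):
    """Call fn(*args); when use_ckpt is True, recompute its activations during backward."""
    if use_ckpt:
        # Non-reentrant checkpointing; intermediate activations inside fn are not stored.
        return checkpoint(fn, *args, use_reentrant=False)
    return fn(*args)


# Stand-in for one GeometricTransformer call.
layer = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.GELU(), torch.nn.Linear(8, 8)
)
x = torch.randn(4, 8, requires_grad=True)
y_plain = ckpt(False, layer, x)
y_ckpt = ckpt(True, layer, x)
```

Both paths produce the same outputs and gradients; the checkpointed path trades one extra forward pass through `fn` during backward for not keeping its internal activations alive.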
Architecture Notes
Model inputs/outputs
- `model.forward(qa, Xa, C, Mr, Mc)` → `(Xa, Xr, Xc)`
- `qa`: `[N, 31]` (was 30, +1 for c_noise conditioning)
- `Xa`: `[N, 3]` atom coords (preconditioned: `c_in(σ) * X_noised`)
- Returns 3 levels: atom/residue/chain decoded coords (raw, before preconditioning wrap)
Data shapes (no batch dimension)
- `X`: `[N, 3]`, N atoms, 3D coordinates
- `qe`: `[N, 30]`, atom element one-hot encoding
- `Mr`: `[N, Nr]`, atom-to-residue mapping
- `Mc`: `[N, Nc]`, atom-to-chain mapping
- `C`: `[N, N]`, connectivity matrix
Current config
- S=64, K=8, L=8, r=1e-3, Hs=64
- atom: F=31 (30 elements + 1 noise level)
- F=31 for stm_a, F=1 for stm_r and stm_c
Potential Next Steps
- Tune EDM hyperparameters (sigma_data, P_mean, P_std, rho)
- Add per-noise-level loss tracking (bin by σ ranges)
- Add EMA smoothing on validation metrics
- Monitor σ distribution during training
- Tune multiscale noise weights (currently 1.3/1.6/8.2 from old sampler)
- Add gradient norm monitoring (preconditioning can produce large gradients at low σ)
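As a starting point for the per-noise-level loss tracking item, a small tracker that bins losses by log-spaced σ ranges could look like this. The class name, bin count, and clamping behaviour are all hypothetical choices, not existing code.

```python
import math
from collections import defaultdict


class SigmaBinTracker:
    """Accumulate mean loss per log-spaced sigma bin."""

    def __init__(self, sigma_min=1e-3, sigma_max=160.0, n_bins=6):
        self.lo, self.hi, self.n_bins = math.log(sigma_min), math.log(sigma_max), n_bins
        self.sums = defaultdict(float)
        self.counts = defaultdict(int)

    def bin_index(self, sigma):
        t = (math.log(sigma) - self.lo) / (self.hi - self.lo)
        # Sampled sigmas can fall outside [sigma_min, sigma_max]; clamp to the edge bins.
        return min(self.n_bins - 1, max(0, int(t * self.n_bins)))

    def update(self, sigma, loss):
        b = self.bin_index(sigma)
        self.sums[b] += loss
        self.counts[b] += 1

    def means(self):
        return {b: self.sums[b] / self.counts[b] for b in sorted(self.counts)}
```

Calling `update(σ, loss)` once per training step and logging `means()` periodically would show whether low-σ or high-σ samples dominate the loss.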