---
title: Diffusion Module v2 Algorithm Notes
tags: [inventory, edm, diffusion, algorithm, boltz, reference]
created: 2024-04-21
updated: 2024-04-21
status: active
related:
---
Diffusion Module v2 Algorithm Notes
Architecture Overview
Model: EDM (Elucidating the Design Space of Diffusion-Based Generative Models, Karras et al.)
Core model (F_θ): self.score_model — atom↔token diffusion transformer
Preconditioning wrapper: preconditioned_network_forward wraps F_θ with input/output scaling
Preconditioning
x_denoised = c_skip(σ) · x_noisy + c_out(σ) · F_θ(c_in(σ) · x_noisy, c_noise(σ))
| Function | Formula | Purpose |
|---|---|---|
| c_in(σ) | 1 / √(σ² + σ_data²) | Normalize input to unit variance |
| c_noise(σ) | 0.25 · log(σ / σ_data) | Noise-level conditioning signal for F_θ |
| c_skip(σ) | σ_data² / (σ² + σ_data²) | Skip connection: →1 at low noise (identity), →0 at high noise |
| c_out(σ) | σ·σ_data / √(σ² + σ_data²) | Scale network output so that F_θ's effective training target has unit variance at every σ |
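Without preconditioning, F_θ's inputs and targets would change scale by orders of magnitude across noise levels, so it would effectively have to learn a different function at each σ. The skip connection handles the identity-at-low-noise case automatically: as σ → 0, c_skip → 1 and c_out → 0, so x_denoised → x_noisy.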
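A minimal pure-Python sketch of the four coefficients and the wrapper, for a single scalar coordinate. `SIGMA_DATA = 16.0` and the stand-in network `toy_f_theta` are assumptions for illustration only; in the module, F_θ is `self.score_model`, which operates on full atom-coordinate tensors.

```python
import math

SIGMA_DATA = 16.0  # assumed value of sigma_data, for illustration only


def c_in(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    # Scale the noisy input to roughly unit variance.
    return 1.0 / math.sqrt(sigma**2 + sigma_data**2)


def c_noise(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    # Noise-level conditioning signal passed to F_theta.
    return 0.25 * math.log(sigma / sigma_data)


def c_skip(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    # Skip-connection weight: ~1 at low noise, ~0 at high noise.
    return sigma_data**2 / (sigma**2 + sigma_data**2)


def c_out(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    # Scale applied to the raw network output.
    return sigma * sigma_data / math.sqrt(sigma**2 + sigma_data**2)


def toy_f_theta(x_scaled: float, noise_cond: float) -> float:
    # Hypothetical stand-in for the real network F_theta: always outputs 0.
    return 0.0


def preconditioned_forward(x_noisy: float, sigma: float, f_theta=toy_f_theta) -> float:
    # x_denoised = c_skip * x_noisy + c_out * F_theta(c_in * x_noisy, c_noise)
    return c_skip(sigma) * x_noisy + c_out(sigma) * f_theta(
        c_in(sigma) * x_noisy, c_noise(sigma)
    )
```

With the zero network the output reduces to c_skip(σ)·x_noisy: at σ = 0.01 it is essentially the input, at σ = 1600 it is close to 0, and at σ = σ_data it is exactly half the input.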
Training
Single-step denoising objective. No multi-step unrolling at training time.
- Get ground-truth coordinates `x₀` from the dataset (`feats["coords"]`).
- Sample a noise level `σ` per sample from a log-normal distribution: `σ = σ_data · exp(P_mean + P_std · N(0,1))`.
- Add noise: `x_noisy = x₀ + σ · ε`, where `ε ~ N(0, I)`.
- Denoise: `x̂₀ = preconditioned_network_forward(x_noisy, σ) = c_skip·x_noisy + c_out·F_θ(c_in·x_noisy, c_noise(σ))`.
- Compute the loss against the GT coordinates.
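The steps above, sketched for a single sample stored as a list of (x, y, z) tuples. The `P_MEAN`, `P_STD` and `SIGMA_DATA` values and the `denoise` callable are assumptions; any function with the hypothetical signature `denoise(x_noisy, sigma)` can be passed in, for example a preconditioned network.

```python
import math
import random
from typing import Callable, List, Tuple

Coord = Tuple[float, ...]

SIGMA_DATA = 16.0  # assumed values, for illustration only
P_MEAN = -1.2
P_STD = 1.5


def sample_training_sigma(rng: random.Random) -> float:
    # Log-normal noise level: sigma = sigma_data * exp(P_mean + P_std * N(0, 1)).
    return SIGMA_DATA * math.exp(P_MEAN + P_STD * rng.gauss(0.0, 1.0))


def training_step(
    x0: List[Coord],
    denoise: Callable[[List[Coord], float], List[Coord]],
    rng: random.Random,
) -> Tuple[float, List[Coord], List[Coord]]:
    sigma = sample_training_sigma(rng)
    # Add isotropic Gaussian noise to every coordinate of the ground truth.
    x_noisy = [tuple(c + sigma * rng.gauss(0.0, 1.0) for c in atom) for atom in x0]
    # A single denoising call; the loss is then computed against x0 itself.
    x_hat = denoise(x_noisy, sigma)
    return sigma, x_noisy, x_hat
```

Passing the identity as `denoise` is only a smoke test; in training it would be the preconditioned network, and the returned `x_hat` goes straight into the loss.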
Loss (compute_loss)
- Build `resolved_atom_mask` (optionally filtered by pLDDT).
- Compute per-atom alignment weights that upweight nucleotide atoms (weight 1 + 5) and ligand atoms (weight 1 + 10).
- Rigid-align the GT coords onto the denoised coords. This removes the global SE(3) pose (rotation and translation), so the loss scores only the relative arrangement of atoms.
- Weighted MSE: `Σ w_i · ‖x̂₀,i - x_aligned,i‖² / Σ 3·w_i`, with `w_i = align_weights_i × resolved_mask_i` (the factor 3 averages over the x, y, z components).
- Weight by `loss_weight(σ) = (σ² + σ_data²) / (σ·σ_data)²` to balance the loss across noise levels.
- Optional smooth lDDT auxiliary loss (local distance agreement).
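A sketch of the weighted-MSE part of the loss, assuming the GT coordinates have already been rigidly aligned onto the prediction (the alignment itself is omitted). The helper `align_weight`, which adds a bonus of 5 for nucleotide atoms and 10 for ligand atoms on top of a base weight of 1, is an assumption (AlphaFold3-style), and the `mol_type` strings are hypothetical labels.

```python
from typing import Sequence, Tuple

Coord = Tuple[float, ...]
SIGMA_DATA = 16.0  # assumed value, for illustration only


def align_weight(mol_type: str, nucleotide_bonus: float = 5.0, ligand_bonus: float = 10.0) -> float:
    # Assumption: additive bonus on top of a base weight of 1.
    if mol_type in ("dna", "rna"):
        return 1.0 + nucleotide_bonus
    if mol_type == "ligand":
        return 1.0 + ligand_bonus
    return 1.0


def loss_weight(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    # EDM weighting that balances the loss across noise levels.
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def weighted_mse(
    x_hat: Sequence[Coord],
    x_aligned: Sequence[Coord],
    mol_types: Sequence[str],
    resolved: Sequence[bool],
) -> float:
    num = 0.0
    den = 0.0
    for pred, gt, mt, ok in zip(x_hat, x_aligned, mol_types, resolved):
        w = align_weight(mt) * (1.0 if ok else 0.0)  # unresolved atoms get weight 0
        sq_dist = sum((p - g) ** 2 for p, g in zip(pred, gt))
        num += w * sq_dist
        den += 3.0 * w  # 3 components per atom -> per-coordinate mean
    return num / den


def diffusion_mse_loss(
    x_hat: Sequence[Coord],
    x_aligned: Sequence[Coord],
    mol_types: Sequence[str],
    resolved: Sequence[bool],
    sigma: float,
) -> float:
    return loss_weight(sigma) * weighted_mse(x_hat, x_aligned, mol_types, resolved)
```

In the example below, the masked third atom contributes nothing even though it is far off, and the perfectly placed ligand atom only enlarges the denominator.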
Key point: the training loss is always computed against x₀ (the actual GT), never against a “previous step” prediction; the reverse (sampling) process is never run during training.
Sampling (Inference)
Iterative denoising from pure noise to clean structure.
- Start from pure noise: `x = σ_0 · ε`, where `σ_0 = σ_max · σ_data` is the first value of the schedule.
- Get the schedule: noise levels decreasing from `σ_max · σ_data` to `σ_min · σ_data`, spaced linearly in σ^(1/ρ) (ρ-spacing, not a geometric progression), followed by a final 0.
- For each step `(σ_{t-1}, σ_t, γ)`:
    - Random augmentation (rotation + translation) of the current coords.
    - Optional stochastic re-noising: `t̂ = σ_{t-1}·(1 + γ)`; inject noise with variance `noise_var = t̂² - σ_{t-1}²`, i.e. `x_noisy = x + √(noise_var) · ε`.
    - Denoise: `x̂₀ = preconditioned_network_forward(x_noisy, t̂)`, the same function as in training.
    - Optional: steering (FK resampling, physical/contact guidance).
    - Optional: alignment reverse diffusion (rigid-align the noisy coords to the denoised coords).
    - Euler step: `x_next = x_noisy + step_scale · (σ_t - t̂) · (x_noisy - x̂₀) / t̂`.
- Return the final `x` as the predicted structure.
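A stripped-down sketch of the loop above for a single structure, in pure Python. It keeps the re-noising, denoising and Euler step and leaves out augmentation, steering and alignment. The `denoise` callable, the function name `sample` and the default `step_scale = 1.0` are assumptions for illustration.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Coord = Tuple[float, ...]


def sample(
    num_atoms: int,
    sigmas: Sequence[float],
    gammas: Sequence[float],
    denoise: Callable[[List[Coord], float], List[Coord]],
    rng: random.Random,
    step_scale: float = 1.0,
) -> List[Coord]:
    # sigmas: decreasing schedule ending in 0; gammas: one value per step.
    x = [tuple(sigmas[0] * rng.gauss(0.0, 1.0) for _ in range(3)) for _ in range(num_atoms)]
    for (sigma_prev, sigma_next), gamma in zip(zip(sigmas[:-1], sigmas[1:]), gammas):
        # Stochastic re-noising: raise the noise level from sigma_prev to t_hat.
        t_hat = sigma_prev * (1.0 + gamma)
        noise_std = math.sqrt(t_hat**2 - sigma_prev**2)
        x_noisy = [tuple(c + noise_std * rng.gauss(0.0, 1.0) for c in atom) for atom in x]
        # Same denoiser as in training, evaluated at t_hat.
        x_hat = denoise(x_noisy, t_hat)
        # Euler step from t_hat to sigma_next along (x_noisy - x_hat) / t_hat.
        x = [
            tuple(
                xn + step_scale * (sigma_next - t_hat) * (xn - xh) / t_hat
                for xn, xh in zip(a_noisy, a_hat)
            )
            for a_noisy, a_hat in zip(x_noisy, x_hat)
        ]
    return x
```

With a denoiser that always returns the same structure and a schedule ending at 0, the last Euler step (step_scale = 1) computes x_noisy - (x_noisy - x̂₀) and so lands on that structure.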
The Euler step explained
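(x_noisy - x̂₀) / t̂ is the derivative dx/dσ of the probability-flow ODE. It equals -t̂ · ∇ₓlog p(x; t̂), a scaled (and sign-flipped) score rather than the score itself. The vector x_noisy - x̂₀ points away from the denoised estimate; because σ_t - t̂ < 0, the Euler step moves x_noisy toward x̂₀, covering the noise-level distance (σ_t - t̂) between the current and the next noise level.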
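Writing the update out with the EDM relations (a derivation, not code from the module) makes two readings explicit: the direction is the probability-flow ODE derivative, and with step_scale = 1 the step is a linear interpolation between the noisy input and the denoised estimate.

```latex
\begin{aligned}
\frac{dx}{d\sigma}
  &= -\sigma\,\nabla_x \log p(x;\sigma)
   = -\sigma\,\frac{D(x;\sigma) - x}{\sigma^2}
   = \frac{x - D(x;\sigma)}{\sigma} \\[4pt]
x_{\text{next}}
  &= x_{\text{noisy}} + (\sigma_t - \hat t)\,\frac{x_{\text{noisy}} - \hat x_0}{\hat t}
   \qquad (\text{step\_scale} = 1) \\
  &= \frac{\sigma_t}{\hat t}\,x_{\text{noisy}}
   + \Bigl(1 - \frac{\sigma_t}{\hat t}\Bigr)\hat x_0
\end{aligned}
```

At the final step σ_t = 0, so x_next = x̂₀; a step_scale above 1 overshoots past this interpolation point.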
Key Relationships
- Training: learns a denoiser `D(x_noisy, σ) ≈ x₀` at arbitrary noise levels in one shot.
- Sampling: uses the same denoiser repeatedly as a score estimator in an ODE solver (with optional stochastic re-noising).
- Connection: `score(x, σ) = (D(x, σ) - x) / σ²`, where D is the denoiser.
- multiplicity: number of parallel samples per input (1 by default, more for steering/particle filters).
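A quick numeric check of the connection on a toy case that is assumed here, not taken from the module: if the data were 1-D Gaussian N(0, σ_data²), then p(x; σ) = N(0, σ² + σ_data²) and the optimal denoiser is D(x, σ) = c_skip(σ)·x. Plugging that D into the formula reproduces the exact score -x / (σ² + σ_data²).

```python
SIGMA_DATA = 16.0  # assumed value, for illustration only


def optimal_denoiser(x: float, sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    # For N(0, sigma_data^2) data, E[x0 | x_noisy] = c_skip(sigma) * x_noisy.
    return sigma_data**2 / (sigma**2 + sigma_data**2) * x


def score_from_denoiser(x: float, sigma: float) -> float:
    # score(x, sigma) = (D(x, sigma) - x) / sigma^2
    return (optimal_denoiser(x, sigma) - x) / sigma**2


def exact_score(x: float, sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    # Gradient of log N(x; 0, sigma^2 + sigma_data^2) with respect to x.
    return -x / (sigma**2 + sigma_data**2)
```

For Gaussian data the skip term alone is already optimal (F_θ = 0), so the network only has to model how real structures deviate from that baseline.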
Noise Schedule
Training: log-normal σ = σ_data · exp(P_mean + P_std · N(0,1))
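Sampling: σ_k = (σ_max^(1/ρ) + k/(N-1) · (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ · σ_data for k = 0, …, N-1, so σ_max and σ_min are expressed in units of σ_data; a final σ_N = 0 ends the schedule.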
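A sketch of the sampling-schedule formula with a trailing 0 appended, so the last Euler step ends at zero noise. The defaults for `sigma_min`, `sigma_max`, `rho` and `sigma_data` are illustrative assumptions, not the module's configuration, and `num_steps` must be at least 2.

```python
from typing import List


def sampling_schedule(
    num_steps: int,
    sigma_min: float = 0.0004,
    sigma_max: float = 160.0,
    rho: float = 7.0,
    sigma_data: float = 16.0,
) -> List[float]:
    # Interpolate linearly in sigma^(1/rho) space, then raise back to the rho-th power.
    inv_rho = 1.0 / rho
    hi = sigma_max**inv_rho
    lo = sigma_min**inv_rho
    sigmas = [
        (hi + k / (num_steps - 1) * (lo - hi)) ** rho * sigma_data
        for k in range(num_steps)
    ]
    return sigmas + [0.0]  # final sigma_N = 0
```

With ρ = 7 the gaps between consecutive values shrink sharply toward the end of the list, so more of the step budget is spent at low noise levels.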