title: Diffusion Module v2 Algorithm Notes
tags: [inventory, edm, diffusion, algorithm, boltz, reference]
created: 2024-04-21
updated: 2024-04-21
status: active
related:


Diffusion Module v2 Algorithm Notes

Architecture Overview

  • Model: EDM (Elucidating the Design Space of Diffusion-Based Generative Models, Karras et al.)
  • Core model (F_θ): self.score_model, an atom↔token diffusion transformer
  • Preconditioning wrapper: preconditioned_network_forward wraps F_θ with input/output scaling

Preconditioning

x_denoised = c_skip(σ) · x_noisy + c_out(σ) · F_θ(c_in(σ) · x_noisy, c_noise(σ))
| Function | Formula | Purpose |
|---|---|---|
| c_in(σ) | 1 / √(σ² + σ_data²) | Normalize input to unit variance |
| c_noise(σ) | 0.25 · log(σ / σ_data) | Noise-level conditioning signal for F_θ |
| c_skip(σ) | σ_data² / (σ² + σ_data²) | Skip connection: → 1 at low noise (identity), → 0 at high noise |
| c_out(σ) | σ · σ_data / √(σ² + σ_data²) | Scale network output: unit-variance target at every σ |
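The "unit-variance target" property can be checked directly. A short derivation, assuming x₀ has per-coordinate variance σ_data² and ε ~ N(0, I) is independent of x₀ (the standard EDM assumptions). The effective regression target for F_θ is whatever makes x̂₀ equal x₀, namely F_target = (x₀ - c_skip·x_noisy) / c_out:

```latex
\begin{aligned}
x_{\text{noisy}} &= x_0 + \sigma\varepsilon,
  \qquad \operatorname{Var}[x_0] = \sigma_{\text{data}}^2,\ \varepsilon \sim \mathcal{N}(0, I) \\
\operatorname{Var}\!\left[c_{\text{in}}(\sigma)\, x_{\text{noisy}}\right]
  &= \frac{\sigma_{\text{data}}^2 + \sigma^2}{\sigma^2 + \sigma_{\text{data}}^2} = 1 \\
F_{\text{target}} &= \frac{x_0 - c_{\text{skip}}(\sigma)\, x_{\text{noisy}}}{c_{\text{out}}(\sigma)} \\
\operatorname{Var}\!\left[x_0 - c_{\text{skip}}\, x_{\text{noisy}}\right]
  &= (1 - c_{\text{skip}})^2 \sigma_{\text{data}}^2 + c_{\text{skip}}^2 \sigma^2
   = \frac{\sigma^4 \sigma_{\text{data}}^2 + \sigma_{\text{data}}^4 \sigma^2}{(\sigma^2 + \sigma_{\text{data}}^2)^2}
   = \frac{\sigma^2 \sigma_{\text{data}}^2}{\sigma^2 + \sigma_{\text{data}}^2}
   = c_{\text{out}}(\sigma)^2 \\
\Rightarrow\quad \operatorname{Var}[F_{\text{target}}] &= 1
\end{aligned}
```

Since x̂₀ - x₀ = c_out·(F_θ - F_target), multiplying the coordinate MSE by 1/c_out(σ)² turns it into an unweighted MSE on F_θ. This is the origin of the loss_weight(σ) used in training.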

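Without preconditioning, F_θ would have to handle inputs and targets whose scale changes by orders of magnitude across noise levels. With it, the network input c_in·x_noisy and its effective target both have unit variance at every σ. The skip connection handles the identity-at-low-noise part automatically.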
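A minimal scalar sketch of the four coefficients and the wrapper in plain Python. σ_data = 16.0 is an assumed value (an AF3-style default), not read from this module's config. preconditioned_forward is a hypothetical stand-in for preconditioned_network_forward that operates on a flat list of floats instead of batched tensors.

```python
import math

SIGMA_DATA = 16.0  # assumed value (AF3-style default); check the module's config


def c_in(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    """Scale x_noisy so the network input has unit variance."""
    return 1.0 / math.sqrt(sigma**2 + sigma_data**2)


def c_noise(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    """Noise-level conditioning signal fed to F_theta."""
    return 0.25 * math.log(sigma / sigma_data)


def c_skip(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    """Weight of the identity (skip) path: ~1 at low noise, ~0 at high noise."""
    return sigma_data**2 / (sigma**2 + sigma_data**2)


def c_out(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    """Scale applied to the raw network output."""
    return sigma * sigma_data / math.sqrt(sigma**2 + sigma_data**2)


def preconditioned_forward(f_theta, x_noisy, sigma, sigma_data=SIGMA_DATA):
    """x_denoised = c_skip * x_noisy + c_out * F(c_in * x_noisy, c_noise).

    f_theta is any callable (inputs, c_noise) -> outputs of the same length.
    """
    scaled = [c_in(sigma, sigma_data) * v for v in x_noisy]
    raw = f_theta(scaled, c_noise(sigma, sigma_data))
    return [c_skip(sigma, sigma_data) * xn + c_out(sigma, sigma_data) * r
            for xn, r in zip(x_noisy, raw)]
```

With a network that outputs zeros, the wrapper returns c_skip·x_noisy. This is close to x_noisy at low σ and close to 0 at high σ, matching the table.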

Training

Single-step denoising objective. No multi-step unrolling at training time.

  1. Get ground-truth coordinates x₀ from dataset (feats["coords"])
  2. Sample noise level σ per sample from log-normal distribution:
    σ = σ_data · exp(P_mean + P_std · N(0,1))
    
  3. Add noise: x_noisy = x₀ + σ · ε where ε ~ N(0,I)
  4. Denoise: x̂₀ = preconditioned_network_forward(x_noisy, σ) = c_skip·x_noisy + c_out·F_θ(c_in·x_noisy, c_noise(σ))
  5. Compute loss against GT coordinates
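Steps 2 to 4 can be sketched in plain Python as follows. P_mean = -1.2, P_std = 1.5 and σ_data = 16.0 are assumed AF3-style defaults, not values confirmed for this module. Coordinates are nested lists instead of tensors. `denoise` is any callable standing in for preconditioned_network_forward, and `training_example` is a hypothetical helper name.

```python
import math
import random

SIGMA_DATA = 16.0  # assumed AF3-style defaults; check the module config
P_MEAN = -1.2
P_STD = 1.5


def sample_sigma(rng: random.Random) -> float:
    """Log-normal noise level: sigma = sigma_data * exp(P_mean + P_std * N(0,1))."""
    return SIGMA_DATA * math.exp(P_MEAN + P_STD * rng.gauss(0.0, 1.0))


def add_noise(x0, sigma, rng):
    """x_noisy = x0 + sigma * eps with eps ~ N(0, I), independently per coordinate."""
    return [[c + sigma * rng.gauss(0.0, 1.0) for c in atom] for atom in x0]


def training_example(x0, denoise, rng):
    """One single-step training example: returns (sigma, x_noisy, x_hat0).

    There is no unrolling: x_hat0 is compared directly against x0 by the loss.
    """
    sigma = sample_sigma(rng)
    x_noisy = add_noise(x0, sigma, rng)
    x_hat0 = denoise(x_noisy, sigma)
    return sigma, x_noisy, x_hat0
```

Because log(σ/σ_data) is Gaussian with mean P_mean, the median training noise level under these assumed defaults is σ_data·exp(-1.2), which is about 0.3·σ_data.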

Loss (compute_loss)

  1. Build resolved_atom_mask (optionally filtered by pLDDT)
  2. Compute per-atom alignment weights: upweight nucleotides and ligands with w = 1 + 5·is_nucleotide + 10·is_ligand (6× and 11× the protein weight)
  3. Rigid-align GT coords onto the denoised coords (removes the global SE(3) pose, so the loss is invariant to a rigid rotation and translation of the prediction)
  4. Weighted MSE: Σ w_i · ‖x̂₀,i - x_aligned,i‖² / Σ 3·w_i, with w_i = align_weights × resolved_mask
  5. Weight by loss_weight(σ) = (σ² + σ_data²) / (σ·σ_data)² = 1 / c_out(σ)² to balance across noise levels
  6. Optional smooth lDDT auxiliary loss (local distance agreement)
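Steps 2, 4 and 5 can be sketched in plain Python for a single structure. This sketch makes several assumptions:

- The rigid alignment of step 3 has already produced x_aligned. A weighted Kabsch/SVD alignment is not reproduced here.
- The molecule-type labels are hypothetical.
- The weighting is the AF3-style additive form w = 1 + 5·is_nucleotide + 10·is_ligand.
- The constant and function names are illustrative, not the module's.
- The smooth lDDT term is omitted.

```python
SIGMA_DATA = 16.0             # assumed value; check the module config
NUCLEOTIDE_LOSS_WEIGHT = 5.0  # assumed additive multipliers (AF3-style)
LIGAND_LOSS_WEIGHT = 10.0


def align_weights(mol_types):
    """Per-atom weights 1 + 5 * is_nucleotide + 10 * is_ligand.

    mol_types uses hypothetical labels "protein", "dna", "rna", "ligand".
    """
    return [1.0
            + NUCLEOTIDE_LOSS_WEIGHT * (t in ("dna", "rna"))
            + LIGAND_LOSS_WEIGHT * (t == "ligand")
            for t in mol_types]


def weighted_mse(x_hat0, x_aligned, weights, resolved):
    """sum_i w_i * ||x_hat0_i - x_aligned_i||^2 / sum_i 3 * w_i, with w_i masked by resolved."""
    num = den = 0.0
    for p, q, w, m in zip(x_hat0, x_aligned, weights, resolved):
        w = w * m  # unresolved atoms (m = 0) contribute nothing
        num += w * sum((a - b) ** 2 for a, b in zip(p, q))
        den += 3.0 * w
    return num / den


def loss_weight(sigma, sigma_data=SIGMA_DATA):
    """(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, i.e. 1 / c_out(sigma)^2."""
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def denoising_loss(x_hat0, x_aligned, mol_types, resolved, sigma):
    """EDM-weighted MSE for one structure; rigid alignment is assumed already done."""
    w = align_weights(mol_types)
    return loss_weight(sigma) * weighted_mse(x_hat0, x_aligned, w, resolved)
```

The factor 3 in the denominator turns the sum of squared 3-D errors into a per-coordinate mean. This keeps the loss scale comparable across structures with different numbers of resolved atoms.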

Key point: The training loss always targets x₀ (the actual GT), never a “previous step” prediction. The reverse (sampling) process is never run during training.

Sampling (Inference)

Iterative denoising from pure noise to clean structure.

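  1. Start from pure noise: x = σ_0 · ε, where σ_0 = σ_max · σ_data is the first value of the sampling schedule
  2. Get schedule: σ values decreasing from σ_max · σ_data to σ_min · σ_data, evenly spaced in σ^(1/ρ) rather than geometrically (see Noise Schedule), followed by a final σ = 0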
  3. For each step (σ_{t-1}, σ_t, γ):
    • Random augmentation (rotation + translation) of current coords
    • Optional stochastic re-noising: t̂ = σ_{t-1}·(1+γ), then inject noise with variance noise_var = t̂² - σ_{t-1}² (with γ = 0, t̂ = σ_{t-1} and no noise is added)
    • x_noisy = x + √(noise_var) · ε
    • Denoise: x̂₀ = preconditioned_network_forward(x_noisy, t̂), the same function used in training
    • Optional: steering (FK resampling, physical/contact guidance)
    • Optional: alignment reverse diffusion (rigid-align noisy to denoised)
    • Euler step: x_next = x_noisy + step_scale · (σ_t - t̂) · (x_noisy - x̂₀) / t̂
  4. Return final x as the predicted structure
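The loop can be exercised end to end without a trained network using a 1-D toy. Data x₀ ~ N(0, s²) has an exact denoiser D(x, σ) = s²/(s²+σ²)·x, so sampling with it should return samples whose standard deviation is close to s. The sketch rests on these assumptions:

- The schedule defaults (σ_min = 4e-4, σ_max = 160, ρ = 7, σ_data = 16) are AF3-style values not confirmed for this module.
- γ is a single constant, not the module's per-step schedule.
- Random augmentation, steering and alignment are skipped, because rotations carry no meaning in 1-D.
- step_scale defaults to 1.0, which gives plain Euler.
- `sample` and `karras_schedule` are hypothetical names.

```python
import math
import random


def karras_schedule(n_steps, sigma_min, sigma_max, rho, sigma_data):
    """rho-spaced levels from sigma_max to sigma_min (times sigma_data), then a final 0."""
    hi, lo = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    return [(hi + k / (n_steps - 1) * (lo - hi)) ** rho * sigma_data
            for k in range(n_steps)] + [0.0]


def sample(denoise, n_samples, n_steps=200, sigma_min=4e-4, sigma_max=160.0,
           rho=7.0, sigma_data=16.0, gamma=0.0, step_scale=1.0, seed=0):
    """Toy 1-D version of the sampling loop (no augmentation, steering or alignment)."""
    rng = random.Random(seed)
    sigmas = karras_schedule(n_steps, sigma_min, sigma_max, rho, sigma_data)
    x = [sigmas[0] * rng.gauss(0.0, 1.0) for _ in range(n_samples)]  # pure noise
    for sigma_prev, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        t_hat = sigma_prev * (1.0 + gamma)            # raised noise level
        noise_var = t_hat**2 - sigma_prev**2          # exactly 0 when gamma == 0
        x_noisy = [v + math.sqrt(noise_var) * rng.gauss(0.0, 1.0) for v in x]
        x_hat0 = [denoise(v, t_hat) for v in x_noisy]  # same denoiser as training
        x = [xn + step_scale * (sigma_next - t_hat) * (xn - d) / t_hat  # Euler step
             for xn, d in zip(x_noisy, x_hat0)]
    return x


if __name__ == "__main__":
    s = 16.0
    ideal = lambda v, sigma: s**2 / (s**2 + sigma**2) * v  # exact denoiser for N(0, s^2)
    out = sample(ideal, n_samples=2000)
    mean = sum(out) / len(out)
    print(math.sqrt(sum((v - mean) ** 2 for v in out) / len(out)))  # should be close to s
```

With γ = 0 the run is a deterministic Euler solve of the probability-flow ODE. The remaining gap between the sample standard deviation and s comes from Euler discretization and shrinks as n_steps grows.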

The Euler step explained

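(x_noisy - x̂₀) / t̂ is the derivative dx/dσ of the probability-flow ODE, which equals -t̂ · ∇ₓ log p(x; t̂): the score scaled by -t̂. It points from the denoised estimate toward the noisy point, that is, away from the data. Because σ_t - t̂ < 0, the Euler step moves x_noisy against this direction, toward x̂₀, covering a fraction step_scale · (t̂ - σ_t) / t̂ of the remaining distance.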
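A small sketch of the update on plain lists, where euler_step is a hypothetical name. Each step moves x_noisy a fraction step_scale·(t̂ - σ_t)/t̂ of the way toward x̂₀. With step_scale = 1, a final step to σ_t = 0 lands exactly on x̂₀. With step_scale > 1, the step overshoots along the same direction.

```python
def euler_step(x_noisy, x_hat0, t_hat, sigma_next, step_scale=1.0):
    """x_next = x_noisy + step_scale * (sigma_next - t_hat) * (x_noisy - x_hat0) / t_hat."""
    return [xn + step_scale * (sigma_next - t_hat) * (xn - d) / t_hat
            for xn, d in zip(x_noisy, x_hat0)]


if __name__ == "__main__":
    x_noisy, x_hat0 = [40.0, -10.0], [8.0, 2.0]
    print(euler_step(x_noisy, x_hat0, t_hat=32.0, sigma_next=24.0))  # [32.0, -7.0]: 1/4 of the way
    print(euler_step(x_noisy, x_hat0, t_hat=32.0, sigma_next=0.0))   # [8.0, 2.0]: lands on x_hat0
    print(euler_step(x_noisy, x_hat0, t_hat=32.0, sigma_next=24.0, step_scale=1.5))  # [28.0, -5.5]
```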

Key Relationships

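  • Training: learns a denoiser D(x_noisy, σ) ≈ x₀ that recovers clean coordinates in one shot at arbitrary noise levels
  • Sampling: uses the same denoiser repeatedly as a score estimator inside an Euler solver for the probability-flow ODE, with optional stochastic re-noising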
  • Connection: score(x, σ) = (D(x, σ) - x) / σ² where D is the denoiser
  • multiplicity: number of parallel samples per input (1 by default, more for steering/particle filters)
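The Connection identity is Tweedie's formula for the posterior-mean denoiser. It can be checked numerically on a 1-D Gaussian toy, an illustrative assumption rather than the molecular setting. With x₀ ~ N(0, s²), the noisy marginal is N(0, s² + σ²), whose score is known in closed form, and the ideal denoiser is E[x₀ | x_noisy = x] = s²/(s² + σ²)·x.

```python
def ideal_denoiser(x: float, sigma: float, s: float) -> float:
    """E[x0 | x_noisy = x] for 1-D Gaussian data x0 ~ N(0, s^2)."""
    return s**2 / (s**2 + sigma**2) * x


def score_from_denoiser(x: float, sigma: float, s: float) -> float:
    """score(x, sigma) = (D(x, sigma) - x) / sigma^2."""
    return (ideal_denoiser(x, sigma, s) - x) / sigma**2


def exact_score(x: float, sigma: float, s: float) -> float:
    """grad_x log N(x; 0, s^2 + sigma^2) = -x / (s^2 + sigma^2)."""
    return -x / (s**2 + sigma**2)
```

Algebraically, (s²x/(s²+σ²) - x)/σ² = -x/(s²+σ²), so the two functions agree for every x and σ > 0. A trained D only approximates the posterior mean, so in practice the identity holds only up to the network's error.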

Noise Schedule

  • Training: log-normal σ = σ_data · exp(P_mean + P_std · N(0,1))
  • Sampling: σ_k = (σ_max^(1/ρ) + k/(N-1) · (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ · σ_data for k = 0, …, N-1, followed by a final σ_N = 0
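A plain-Python sketch of the sampling schedule, with the trailing σ = 0 appended as the final Euler target. σ_min = 4e-4, σ_max = 160 and σ_data = 16 are assumed AF3-style defaults, and karras_schedule and steps_below are hypothetical names. Comparing ρ = 7 with ρ = 1 (uniform spacing in σ) shows the effect of ρ: a larger ρ shortens the steps near σ_min, so more of the N levels sit at low noise.

```python
def karras_schedule(n_steps, sigma_min, sigma_max, rho, sigma_data):
    """sigma_k for k = 0..n_steps-1, evenly spaced in sigma**(1/rho), times sigma_data, then 0."""
    inv_rho = 1.0 / rho
    hi, lo = sigma_max**inv_rho, sigma_min**inv_rho
    sigmas = [(hi + k / (n_steps - 1) * (lo - hi)) ** rho * sigma_data
              for k in range(n_steps)]
    return sigmas + [0.0]


def steps_below(sigmas, threshold):
    """Number of nonzero schedule levels strictly below `threshold`."""
    return sum(1 for s in sigmas[:-1] if s < threshold)


if __name__ == "__main__":
    for rho in (1.0, 7.0):
        sched = karras_schedule(200, 4e-4, 160.0, rho, 16.0)
        print(f"rho={rho}: {steps_below(sched, 16.0)} of 200 levels below sigma_data")
```

With these settings, roughly 40% of the ρ = 7 levels fall below σ_data, compared with only the last couple of levels for ρ = 1.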