title: EDM Algorithm Review — Pomodoro Debug Training
tags: [inventory, edm, pomodoro, algorithm, debugging, multiscale-noise]
created: 2024-04-23
updated: 2024-04-23
status: active
related:


EDM Algorithm Review — Pomodoro debug_training

Symptom: Residues don’t move much during sampling

During EDM sampling, atom positions denoise but residue-level positions (the coarse backbone) appear nearly frozen, barely shifting from their starting position. The structure converges to something close-ish in local geometry but wrong in global fold.


Pitfall 1: Multiscale noise is dominated by chain-level, starving residue-level

The math:

w_c = 8.2,  w_r = 1.6,  w_a = 1.3
w_sum = 11.1

sc = sqrt(8.2/11.1) * σ ≈ 0.859 * σ
sr = sqrt(1.6/11.1) * σ ≈ 0.380 * σ
sa = sqrt(1.3/11.1) * σ ≈ 0.342 * σ
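The split above can be checked numerically. This is a minimal sketch; the function name `split_sigma` is illustrative and not taken from the codebase.

```python
import math

def split_sigma(sigma, w_c=8.2, w_r=1.6, w_a=1.3):
    """Split a total noise level sigma into chain/residue/atom scales.

    The three variances add up to sigma**2 by construction.
    """
    w_sum = w_c + w_r + w_a
    return tuple(math.sqrt(w / w_sum) * sigma for w in (w_c, w_r, w_a))

sc, sr, sa = split_sigma(1.0)
fractions = [s * s for s in (sc, sr, sa)]  # share of total variance per level
print("scales   :", [round(s, 3) for s in (sc, sr, sa)])
print("variance :", [round(f, 3) for f in fractions])
```

The variance shares are simply w / w_sum, so the chain share is 8.2/11.1 regardless of σ.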

These weights come from the old autoregressive sampler (sampler.py L27-29), where ms0=1.3, ms1=1.6, ms2=8.2 are mean sigma values for a completely different sampling regime (iterative local displacement refinement). In the old sampler:

  • Chain sigma scales with n2^(1/3) (chain count to the 1/3 power)
  • Residue sigma scales with n1^(1/3) (residue count)
  • The noise is per-level and interpolated via alpha (a blending parameter)

In EDM, the weights are used to decompose a single σ into three additive variances. This is a fundamentally different usage. The old weights were designed so that larger structures (chains) get more noise in absolute terms, which makes sense for autoregressive refinement. But in EDM, they mean ~74% of the total noise variance (8.2/11.1) goes to the chain level, which broadcasts rigidly to all residues and atoms. This has a specific bad consequence:

During training, the model sees a noised input where residue-level displacement is small relative to chain-level displacement. The chain noise moves the whole residue cluster as a rigid body, so the model learns: “to denoise residue positions, just undo the rigid-body chain shift.” It never needs to learn to move residues relative to each other within a chain, because residue-relative noise is only ~14% of total variance (1.6/11.1).

During sampling, when the model denoises, it correctly removes chain-level noise but has barely learned to adjust residue-level positions, so residues stay put relative to each other.

Fix options:

  • Rebalance weights: try w_c=1.0, w_r=3.0, w_a=3.0 or similar to give residue movement more signal
  • Derive weights from the actual coordinate variance at each level (measure σ_r from data)
  • Make weights a function of σ so that at high noise chain dominates, at low noise residue/atom dominate
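One way the σ-dependent option could look is sketched below. Everything here is hypothetical: the function names, the log-linear blend, the two weight profiles, and the σ range are placeholders for experimentation, not values from the codebase.

```python
import math

def sigma_dependent_weights(sigma, sigma_lo=1.0, sigma_hi=80.0,
                            w_low=(1.0, 3.0, 3.0), w_high=(8.2, 1.6, 1.3)):
    """Blend (w_c, w_r, w_a) between a low-noise and a high-noise profile.

    The blend is linear in log(sigma) and clamped outside [sigma_lo, sigma_hi].
    All defaults are placeholders, not tuned values.
    """
    t = (math.log(sigma) - math.log(sigma_lo)) / (math.log(sigma_hi) - math.log(sigma_lo))
    t = min(1.0, max(0.0, t))
    return tuple((1.0 - t) * lo + t * hi for lo, hi in zip(w_low, w_high))

def level_sigmas(sigma, weights):
    """Per-level noise scales whose variances sum to sigma**2."""
    w_sum = sum(weights)
    return tuple(math.sqrt(w / w_sum) * sigma for w in weights)

for sigma in (1.0, 10.0, 80.0):
    w = sigma_dependent_weights(sigma)
    print(sigma, [round(x / sum(w), 3) for x in w])  # variance share (chain, residue, atom)
```

Because the weights are renormalized inside `level_sigmas`, the total variance stays σ² at every noise level; only the split between levels changes.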

Pitfall 2: Noise accumulation via hierarchical broadcast creates correlated noise, not independent noise

The current scheme (L236-244):

offset_c = sc * randn_like(X0c)                          # chain noise
offset_r = sr * randn_like(X0r) + broadcast(offset_c)   # residue = own + chain
offset_a = sa * randn_like(X0) + broadcast(offset_r)     # atom = own + residue + chain

Final atom noise: N_a = sa*ε_a + sr*ε_r_broadcast + sc*ε_c_broadcast

The variance at the atom level is exactly sa² + sr² + sc² = σ², because the three terms are independent draws, so the total variance is correct. However, the noise is heavily correlated across atoms that share a residue, and across residues that share a chain.

From the model’s perspective, the noised coordinates look like: “all atoms in this residue moved together by some amount, and all residues in this chain moved together by some amount.” The model learns to undo these correlated shifts efficiently (it’s easy — just predict the chain-level displacement and broadcast it back). The hard part — learning to rearrange residues into the correct fold — is undersampled in training because residue-relative variance is low.

This is the core reason residues don’t move: the model literally hasn’t been trained on inputs where residues need large relative rearrangements.
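A quick Monte Carlo check of this argument, as a sketch: it draws 1-D hierarchical noise with the current weights and measures how much of the variance actually moves residues relative to their chain center. The function name, the toy sizes (4 residues per chain, 5 atoms per residue), and the use of NumPy broadcasting in place of the project's `broadcast` are all assumptions.

```python
import numpy as np

def hierarchical_noise(n_chains, res_per_chain, atoms_per_res, sigma,
                       w=(8.2, 1.6, 1.3), rng=None):
    """Draw 1-D hierarchical noise: chain offset + residue offset + atom offset."""
    rng = np.random.default_rng() if rng is None else rng
    sc, sr, sa = np.sqrt(np.array(w) / np.sum(w)) * sigma
    eps_c = sc * rng.standard_normal((n_chains, 1, 1))
    eps_r = sr * rng.standard_normal((n_chains, res_per_chain, 1))
    eps_a = sa * rng.standard_normal((n_chains, res_per_chain, atoms_per_res))
    return eps_c + eps_r + eps_a  # NumPy broadcasting stands in for broadcast()

rng = np.random.default_rng(0)
noise = hierarchical_noise(20000, 4, 5, sigma=1.0, rng=rng)

total_var = noise.var()                                   # per-atom variance
res_centers = noise.mean(axis=2)                          # noise seen by residue centers
rel_res = res_centers - res_centers.mean(axis=1, keepdims=True)
rel_var = rel_res.var()                                   # residue motion relative to chain
print(f"total variance            : {total_var:.3f}")
print(f"residue-relative variance : {rel_var:.3f}")
```

In this toy setup the per-atom variance comes out close to σ², while the residue-relative part is only a small fraction of it, which is the imbalance described above.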


Pitfall 3: sigma_max is probably too low for full noising at the chain level

Current: sigma_max = 80.0 (was 160.0 in session notes, changed at some point)

With sigma_data = 16.0 and P_mean = -1.2, P_std = 1.5:

Typical σ values during training: σ = 16 * exp(-1.2 + 1.5*N(0,1))

  • Median: 16 * exp(-1.2) ≈ 4.8
  • 84th percentile: 16 * exp(0.3) ≈ 21.6
  • +3 standard deviations (99.87th percentile): 16 * exp(3.3) ≈ 434 (but σ_max caps at 80)
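The quantiles of this log-normal distribution can be computed directly with the standard library. A small sketch, with the parameter values taken from the notes above and a hypothetical function name:

```python
import math
from statistics import NormalDist

def sigma_quantile(q, sigma_data=16.0, p_mean=-1.2, p_std=1.5):
    """Quantile q of sigma = sigma_data * exp(p_mean + p_std * N(0, 1))."""
    z = NormalDist().inv_cdf(q)  # standard normal quantile
    return sigma_data * math.exp(p_mean + p_std * z)

for q in (0.5, 0.84, 0.99):
    print(f"q={q:.2f}  sigma={sigma_quantile(q):.1f}")
```

This makes it easy to see what fraction of training samples falls below any given σ, for example how rarely training reaches σ near σ_max.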

At σ=80: c_skip = 256/(6400+256) ≈ 0.038, c_out = 80*16/sqrt(6656) ≈ 15.7. The skip connection is almost zero and the model must predict from scratch.
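For reference, the coefficients can be evaluated at any σ with the standard EDM preconditioning formulas. This sketch assumes Pomodoro uses the standard forms; `c_in` is included for completeness and may not match the project's input scaling.

```python
import math

def edm_coefficients(sigma, sigma_data=16.0):
    """Standard EDM preconditioning: D(x) = c_skip * x + c_out * F(c_in * x)."""
    denom = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    return c_skip, c_out, c_in

for sigma in (4.8, 21.6, 80.0):
    c_skip, c_out, _ = edm_coefficients(sigma)
    print(f"sigma={sigma:5.1f}  c_skip={c_skip:.3f}  c_out={c_out:.2f}")
```

At the median training σ the skip connection still carries most of the signal, which is another way of seeing that most training samples ask the network for only a small correction.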

But at the residue level, with the weight decomposition, sr = 0.380 * 80 ≈ 30.4, i.e. a residue-level noise of about 30 Å. What’s the typical spread of residue positions? In proteins, adjacent Cα-Cα distances are ~3.8Å, and a 32-residue fragment might span ~50Å. A residue noise of 30 Å is comparable to the fragment size, so it should be enough to fully noise the residue positions. But remember this is the max: most training samples see much lower σ, where residue noise is a fraction of the spatial extent.

For sampling, σ_max=80 with sigmas scaled by σ_data:

sample_schedule sigmas = (σ_max^(1/ρ) + ...) ^ ρ * σ_data

Wait — there’s a problem here. Looking at sample_schedule (L29-37):

sigmas = (σ_max^(1/ρ) + step_frac * (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ
sigmas = sigmas * σ_data  # <-- THIS LINE

The schedule multiplies by σ_data AGAIN, but σ_max and σ_min are already supposed to be the actual sigma values (not divided by σ_data). In Boltz (L282-293):

sigmas = (self.sigma_max**inv_rho + steps/(N-1) * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)) ** self.rho
sigmas = sigmas * self.sigma_data

Boltz multiplies by σ_data as well, so this is a shared convention rather than a Pomodoro bug: σ_max and σ_min are expressed in units of σ_data. At the first step the bracket reduces to σ_max^(1/ρ), and raising it to the power ρ gives σ_max back.

In Boltz: σ_max=160 (unscaled), σ_data=16, so the first value is 160 * 16 = 2560. That’s the initial noise level in Å.

But in Pomodoro: σ_max=80, σ_data=16: first value = 80 * 16 = 1280.

And the training noise distribution is: σ = σ_data * exp(P_mean + P_std * randn). Median σ ≈ 4.8.

The initial sampling noise is ~266x the median training noise (1280 / 4.8). This is by design in EDM: the schedule starts at high noise and works down.

But here’s the key: σ_max in Pomodoro (80) is half of Boltz’s (160), but the schedule scales both by σ_data, so the initial noise levels in absolute terms are 1280 vs 2560. Boltz operates on all-atom coordinates of full proteins (which can span 100+ Å). Pomodoro operates on 32-residue fragments. 1280Å of initial noise is massively overkill for the fragment scale, so this is not the issue — the fragments are plenty noised.
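The schedule arithmetic can be reproduced with a few lines. This is a sketch of the formula quoted above, not the project's `sample_schedule`; ρ=7 follows the 160^(1/7) in these notes, and σ_min=0.0004 is an assumed value.

```python
def sample_schedule(n_steps, sigma_max=80.0, sigma_min=0.0004, rho=7.0, sigma_data=16.0):
    """Karras-style schedule, scaled by sigma_data as in the snippets above.

    sigma_min and rho are assumptions; sigma_max and sigma_data come from the notes.
    """
    inv_rho = 1.0 / rho
    sigmas = []
    for i in range(n_steps):
        frac = i / (n_steps - 1)
        s = (sigma_max**inv_rho + frac * (sigma_min**inv_rho - sigma_max**inv_rho)) ** rho
        sigmas.append(s * sigma_data)
    return sigmas

pomodoro = sample_schedule(50)                   # sigma_max = 80
boltz = sample_schedule(50, sigma_max=160.0)     # sigma_max = 160
print(f"first sigma: pomodoro={pomodoro[0]:.1f}  boltz={boltz[0]:.1f}")
print(f"last sigma : pomodoro={pomodoro[-1]:.4f}")
```

With ρ=7 most of the steps are spent at low σ, so the very large starting value costs only a few steps of the trajectory.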


Pitfall 4: Loss alignment direction is correct, but check gradient flow through the SVD

In step() (L287-293):

X0_aligned = superpose(Xa_hat.detach(), X0)
la = mean(sum(square(X0_aligned - Xa_hat), dim=1))

superpose(X_source, X_target) aligns X_target onto X_source. So this aligns GT onto the prediction.

In Boltz (compute_loss L645-648):

atom_coords_aligned_ground_truth = weighted_rigid_align(
    atom_coords.detach(),          # GT (aligned onto pred)
    denoised_atom_coords.detach(), # pred
    align_weights.detach(),
    ...)

weighted_rigid_align(true_coords, pred_coords, ...) — in Boltz, the first argument is true (GT) and the second is predicted. The function aligns true_coords onto pred_coords (see the SVD: B = pred_centered^T @ true_centered). So Boltz also aligns GT onto prediction.

But then the loss is (pred - aligned_GT)², which measures how far the prediction is from the GT after GT has been rotated to match the prediction. This is the standard way — you want the loss to be SE(3)-invariant, so you align GT to pred and compute residuals.

Wait — in superpose (L98-117), the function does:

B = X0c^T @ Xc / N   # X0=source, X=target

Calling superpose(Xa_hat.detach(), X0) means X0=Xa_hat (pred, detached), X=X0 (GT). So B = pred_centered^T @ GT_centered. Then result = (GT - GT_mean) @ R + pred_mean. This rotates GT to match pred.

Then la = |X0_aligned - Xa_hat|² = |aligned_GT - prediction|². This is the same as Boltz. OK, the alignment direction is correct.

However, there’s a subtle point around gradients: the .detach() is on Xa_hat, not on X0. If X0 is plain ground-truth data with requires_grad=False, X0_aligned carries no gradient and the SVD stays outside the autograd graph, so this is harmless. If X0 is ever produced by a differentiable operation, gradients would flow through the SVD, which has ill-defined gradients at degenerate configurations and could cause NaN gradients or training instability. Boltz sidesteps the question by detaching both the GT and predicted coords before alignment (.detach() on both). Doing the same in Pomodoro is a cheap safeguard.
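The alignment convention discussed above can be illustrated with a generic Kabsch alignment. This NumPy sketch is not the project's `superpose`; it only mirrors the argument order (source first, target aligned onto it). In a PyTorch version, calling `.detach()` on both inputs, or running the alignment under `torch.no_grad()`, would keep the SVD out of the autograd graph.

```python
import numpy as np

def superpose(X_source, X_target):
    """Rigidly align X_target onto X_source (Kabsch). Both are (N, 3) arrays."""
    mu_s, mu_t = X_source.mean(axis=0), X_target.mean(axis=0)
    S, T = X_source - mu_s, X_target - mu_t
    H = T.T @ S                                  # 3x3 cross-covariance, target -> source
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(U @ Vt))           # guard against reflections
    R = U @ np.diag([1.0, 1.0, d]) @ Vt          # rotation, applied as T @ R
    return T @ R + mu_s

def aligned_mse(pred, gt):
    """Mean squared distance between pred and GT after aligning GT onto pred."""
    gt_aligned = superpose(pred, gt)
    return float(((gt_aligned - pred) ** 2).sum(axis=1).mean())

rng = np.random.default_rng(0)
gt = rng.standard_normal((10, 3))
c, s = np.cos(0.7), np.sin(0.7)
Rz = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
pred = gt @ Rz + np.array([1.0, -2.0, 3.0])      # rotated and shifted copy of GT
print(aligned_mse(pred, gt))                     # close to zero: loss ignores rigid motion
```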


Pitfall 5: Residue/chain loss uses noised coordinates reduced from the atom level, not the ground truth

In edm_denoise (L256-260):

Xr_noised = reduce_coordinates(X_noised, Mar)   # reduce NOISED atom coords
Xc_noised = reduce_coordinates(Xr_noised, Mrc)
Xr_hat = c_skip * Xr_noised + c_out * Xr
Xc_hat = c_skip * Xc_noised + c_out * Xc

The residue/chain denoised outputs use noised residue/chain coords for the skip connection. This is correct in principle (analogous to the atom level).

But in the loss (L290-294):

X0r_aligned = superpose(Xr_hat.detach(), X0r)
lr = mean(sum(square(X0r_aligned - Xr_hat), dim=1))

The target is X0r = reduce_coordinates(X0, Mar) — the clean residue coordinates. This is correct.

However, the issue is that Xr_hat = c_skip * Xr_noised + c_out * Xr_model_output. And Xr_noised is the mean of noised atom positions for each residue. Due to the hierarchical noise, atoms within a residue share the same residue+chain noise offset, so they move coherently — the residue center thus gets noise sr*ε_r + sc*ε_c_broadcast. The atom-level noise sa*ε_a averages out (if enough atoms per residue).

So the residue-level denoising task is: “given a residue center that’s been shifted by ~ σ_r, predict the shift back.” With σ_r being only ~38% of σ, this is a relatively easy task — the model can do it well even without much capacity devoted to the residue track.

The loss on residue/chain levels (lr, lc) therefore trains quickly and converges early, but doesn’t push the model to learn the harder structural rearrangement task. The atom-level loss (la + lb + lnb) is dominant but the push-pull loss lb operates on distances, not absolute positions — so it can be satisfied even with wrong global topology if local geometry is correct.


Pitfall 6: Push-pull loss uses GT geometry computed from aligned prediction, creating a circular dependency

In step() (L296-298):

D0, R0 = kw.extract_geometry(X0_aligned, X0_aligned)
L = kw.connected_distance_matrix(C)
lb, lnb = push_pull_loss(Xa_hat, D0, R0, L)

X0_aligned is the GT aligned onto the prediction. Then D0, R0 are the distances and directions of the aligned GT. The push-pull loss measures how well the prediction preserves these geometric constraints.

This seems fine: the GT is aligned so we compare local geometry in the same frame. But there’s a problem: X0_aligned depends on Xa_hat.detach() (through the SVD), so the alignment changes with the prediction on every training step. Distances are invariant under a rigid alignment, so D0 is unaffected, but the direction targets R0 (if they are expressed in the global frame) rotate with the alignment and become noisy. This adds variance to the gradient but is not fundamentally broken.
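To see which targets the alignment can actually affect, the sketch below compares pairwise distances and global-frame unit directions before and after a rigid transform. The helper names are hypothetical and are not `kw.extract_geometry`; whether R0 is stored in the global frame is an assumption.

```python
import numpy as np

def pairwise_distances(X):
    """(N, N) matrix of Euclidean distances between rows of X."""
    diff = X[:, None, :] - X[None, :, :]
    return np.linalg.norm(diff, axis=-1)

def pairwise_directions(X, eps=1e-9):
    """(N, N, 3) unit vectors between rows of X, expressed in the global frame."""
    diff = X[:, None, :] - X[None, :, :]
    return diff / (np.linalg.norm(diff, axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(2)
X0 = rng.standard_normal((6, 3))
c, s = np.cos(1.1), np.sin(1.1)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
X0_aligned = X0 @ R + np.array([3.0, 0.0, -1.0])   # stand-in for a rigid alignment

same_dist = np.allclose(pairwise_distances(X0), pairwise_distances(X0_aligned))
same_dir = np.allclose(pairwise_directions(X0), pairwise_directions(X0_aligned))
print(same_dist, same_dir)
```

If this holds for the real geometry helpers, the distance targets could be computed once from the unaligned GT, leaving only the direction targets dependent on the per-step alignment.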


Pitfall 7: Sampling — no random augmentation, no centering between steps

In Boltz sampling (L351-357):

random_R, random_tr = compute_random_augmentation(...)
atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdims=True)  # center
atom_coords = einsum(atom_coords, random_R) + random_tr              # random rotation

In Pomodoro edm_sample (L314-339): No centering, no random augmentation between steps.

EDM sampling is SE(3) equivariant in principle — the denoiser should give the same result regardless of rotation. But in practice, neural networks are not perfectly equivariant. Without random augmentation during sampling, small rotational biases accumulate over 50-300 Euler steps. This can cause the structure to drift in a preferred direction or get stuck in a local optimum.

Additionally, without centering between steps, the coordinates can drift (accumulate translation) since the denoiser is not perfectly translation-invariant. This doesn’t directly explain frozen residues, but it adds noise to the sampling trajectory.
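A per-step centering and random augmentation could be sketched as follows. This is a NumPy illustration for a single (N, 3) structure, not Boltz's `compute_random_augmentation`; the function names and the `trans_scale` parameter are assumptions.

```python
import numpy as np

def random_rotation(rng):
    """Random proper rotation from the QR decomposition of a Gaussian matrix."""
    Q, R = np.linalg.qr(rng.standard_normal((3, 3)))
    Q = Q * np.sign(np.diag(R))        # fix column signs so the draw is not biased by QR
    if np.linalg.det(Q) < 0:           # flip one axis so that det = +1
        Q[:, 0] = -Q[:, 0]
    return Q

def center_and_augment(X, rng, trans_scale=1.0):
    """Center (N, 3) coordinates, then apply a random rotation and translation."""
    X = X - X.mean(axis=0, keepdims=True)
    return X @ random_rotation(rng) + trans_scale * rng.standard_normal(3)

rng = np.random.default_rng(1)
X = rng.standard_normal((8, 3)) + 5.0
Y = center_and_augment(X, rng, trans_scale=0.0)
print(Y.mean(axis=0))                  # approximately zero after centering
```

Applied at the start of every sampling step, this would remove accumulated translation and keep any orientation bias of the denoiser from compounding in one direction.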


Pitfall 8: The denoiser only uses the atom-level output during sampling

In edm_sample (L332):

Xa_hat, _, _ = edm_denoise(model, X_noised, t_hat, qe, C, Mr, Mc, hp)
d = (X_noised - Xa_hat) / t_hat

The score estimate d = (X_noised - Xa_hat) / t_hat uses only the atom-level denoised output. The residue and chain denoised outputs are discarded.

In a multiscale system, the residue/chain outputs provide cleaner estimates of large-scale structure (since they operate on coarser, less noisy representations). By ignoring them, the sampler relies entirely on atom-level denoising for all scale information.

This is likely a major contributor to the frozen-residue symptom. The atom-level denoiser may correctly predict local structure (bond lengths, angles) but doesn’t move residues globally because:

  1. It was trained with mostly chain-correlated noise (pitfall 1 & 2), so it handles rigid-body shifts well but not relative rearrangements
  2. The residue/chain outputs which could provide cleaner long-range structure signals are ignored

Fix: Use all three scales in the sampling update

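Xa_hat, Xr_hat, Xc_hat = edm_denoise(model, X_noised, t_hat, qe, C, Mr, Mc, hp)

# Noised coordinates at each level, reduced the same way as in edm_denoise
# (Mar: atom -> residue map, Mrc: residue -> chain map)
Xr_noised = reduce_coordinates(X_noised, Mar)
Xc_noised = reduce_coordinates(Xr_noised, Mrc)

# Compute score at each level
d_a = (X_noised - Xa_hat) / t_hat
d_r = (Xr_noised - Xr_hat) / t_hat
d_c = (Xc_noised - Xc_hat) / t_hat

# Combine: atom-level score + broadcast of higher-level scores.
# Caveat: d_a already contains the residue and chain components, so a plain sum
# appears to overcount them; the terms likely need weights or mean-subtraction.
d = d_a + broadcast(d_r, Mar) + broadcast(broadcast(d_c, Mrc), Mar)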

Or more simply, replace the atom-level denoised output with a multi-scale composite:

Xr_a = reduce(Xa_hat, Mar)   # residue centers implied by the atom output
Xa_hat_ms = Xa_hat + broadcast(Xr_hat - Xr_a, Mar) + broadcast(broadcast(Xc_hat - reduce(Xr_a, Mrc), Mrc), Mar)

This way, residue-level corrections (which are trained to move residues relative to each other) actually influence the sampling trajectory.
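The composite can be prototyped outside the model with plain index arrays. In this sketch, `atom_to_res` and `res_to_chain` are hypothetical integer maps standing in for the project's reduction matrices, every group is assumed to be non-empty, and `reduce_mean` is a stand-in for `reduce_coordinates`.

```python
import numpy as np

def reduce_mean(X, index, n_groups):
    """Mean of the rows of X per group; index[i] is the group of row i."""
    sums = np.zeros((n_groups, X.shape[1]))
    np.add.at(sums, index, X)
    counts = np.bincount(index, minlength=n_groups)[:, None]
    return sums / counts

def multiscale_composite(Xa_hat, Xr_hat, Xc_hat, atom_to_res, res_to_chain):
    """Shift atoms so that residue and chain centers follow the coarse tracks."""
    n_res, n_chain = Xr_hat.shape[0], Xc_hat.shape[0]
    Xr_from_a = reduce_mean(Xa_hat, atom_to_res, n_res)       # centers implied by atoms
    Xc_from_a = reduce_mean(Xr_from_a, res_to_chain, n_chain)
    res_corr = Xr_hat - Xr_from_a                             # residue-track correction
    chain_corr = Xc_hat - Xc_from_a                           # chain-track correction
    return Xa_hat + res_corr[atom_to_res] + chain_corr[res_to_chain][atom_to_res]

rng = np.random.default_rng(4)
atom_to_res = np.array([0, 0, 1, 1, 2, 2])
res_to_chain = np.array([0, 0, 1])
Xa = rng.standard_normal((6, 3))
Xr = reduce_mean(Xa, atom_to_res, 3)
Xc = reduce_mean(Xr, res_to_chain, 2)
Xr_shift = Xr + rng.standard_normal((3, 3))                   # residue track disagrees with atoms
out = multiscale_composite(Xa, Xr_shift, Xc, atom_to_res, res_to_chain)
```

When the three outputs agree the composite is a no-op, and when only the residue track disagrees the residue centers of the composite follow the residue track exactly. If both coarse tracks disagree at once, the two corrections add, so their interaction should be checked before relying on this form.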


Pitfall 9: Loss weighting — all scales weighted equally

Current: loss = (la + lb + lnb + lr + lc) * lw

All scale losses are summed with equal weight before multiplying by the EDM loss weight lw = (σ² + σ_data²) / (σ·σ_data)². This means:

  • At high σ (where the model is learning coarse structure), the chain loss lc dominates because chain noise is largest
  • At low σ (where the model refines local geometry), the atom loss la dominates because the chain/residue positions are already nearly correct

This might seem reasonable, but it means the model never gets a strong gradient signal specifically for residue-relative rearrangement. The residue loss lr is always smaller than lc (since σ_r < σ_c) and comparable to la (since σ_r ≈ σ_a). But la includes the push-pull loss which is much more informative for local geometry.

Recommendation: Give lr a higher weight, or use scale-dependent EMA to track progress per level.
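A per-level weighting could be wired in as below. The weights `w_r` and `w_c` are hypothetical knobs and the default of 3.0 is only a starting point; the loss-weight formula is the one quoted above.

```python
def edm_loss_weight(sigma, sigma_data=16.0):
    """EDM loss weight lw = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2

def total_loss(la, lb, lnb, lr, lc, sigma, w_r=3.0, w_c=1.0, sigma_data=16.0):
    """Weighted sum of per-level losses; w_r and w_c are hypothetical knobs."""
    return (la + lb + lnb + w_r * lr + w_c * lc) * edm_loss_weight(sigma, sigma_data)

print(total_loss(1.0, 1.0, 1.0, 1.0, 1.0, sigma=16.0))
```

The functions only use arithmetic operators, so they should work unchanged on scalar tensors as well as Python floats.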


Summary of Root Causes (Ordered by Impact)

# | Issue | Impact | Fix Difficulty
--|-------|--------|---------------
1 | Multiscale noise weights inherited from old sampler; chain dominates (74% of variance) | HIGH: model never sees residue-level rearrangement during training | Easy
2 | Hierarchical broadcast creates correlated noise; residue-relative displacement is undersampled | HIGH: same as #1, different angle | Medium
8 | Sampling discards residue/chain denoised outputs | HIGH: even if model learns residue movement, sampling ignores it | Easy
5 | Residue/chain losses are too easy (small σ_r) | MEDIUM: model converges on easy residue task, never pushed harder | Medium
9 | Equal loss weighting across scales | MEDIUM: residue rearrangement never gets priority | Easy
7 | No random augmentation or centering during sampling | MEDIUM: compounds other issues via drift | Easy
3 | σ_max possibly too low | LOW: 1280Å initial noise is plenty for 32-residue fragments | Easy
4 | Alignment gradient through SVD (potential instability) | LOW: detach both sides like Boltz | Easy
6 | Push-pull loss uses aligned GT (gradient variance) | LOW: not a correctness issue | Hard

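Recommended fixes, in priority order:

  1. Rebalance multiscale noise weights. Try w_c=1.0, w_r=4.0, w_a=4.0 (or w_c=2.0, w_r=4.0, w_a=3.0) so that residues see much more relative displacement during training. With fixed weights the chain share is the same at every σ, so if the chain should still dominate at high noise, make the weights a function of σ.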

  2. Use multi-scale denoised output in sampling. Broadcast residue/chain corrections onto atoms and use the composite as the denoised estimate for the Euler step.

  3. Add random augmentation + centering in sampling (match Boltz’s sampling loop).

  4. Weight lr higher in the loss (e.g., loss = (la + lb + lnb + 3*lr + lc) * lw) to force the residue track to learn non-trivial rearrangements.

  5. Monitor per-level noise statistics during training. Log σ, σ_a, σ_r, σ_c and the actual displacement magnitudes at each level to confirm the hypothesis that residue movement is too small.
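The monitoring in step 5 could start from a helper like the one below, which splits the displacement between clean and noised coordinates into chain, residue-relative, and atom-relative parts. The index arrays and function names are hypothetical, and every residue and chain is assumed to be non-empty.

```python
import numpy as np

def group_mean(X, index, n_groups):
    """Mean of the rows of X per group; index[i] is the group of row i."""
    sums = np.zeros((n_groups, X.shape[1]))
    np.add.at(sums, index, X)
    return sums / np.bincount(index, minlength=n_groups)[:, None]

def rms(V):
    """Root-mean-square length of the rows of V."""
    return float(np.sqrt((V ** 2).sum(axis=1).mean()))

def level_displacements(X0, X_noised, atom_to_res, res_to_chain):
    """RMS displacement at chain level, residue level (relative to chain),
    and atom level (relative to residue)."""
    disp = X_noised - X0
    n_res = int(atom_to_res.max()) + 1
    n_chain = int(res_to_chain.max()) + 1
    res_disp = group_mean(disp, atom_to_res, n_res)
    chain_disp = group_mean(res_disp, res_to_chain, n_chain)
    return {
        "chain": rms(chain_disp),
        "residue_rel": rms(res_disp - chain_disp[res_to_chain]),
        "atom_rel": rms(disp - res_disp[atom_to_res]),
    }

atom_to_res = np.array([0, 0, 1, 1, 2, 2])
res_to_chain = np.array([0, 0, 1])
X0 = np.random.default_rng(3).standard_normal((6, 3))
stats = level_displacements(X0, X0 + np.array([3.0, 4.0, 0.0]), atom_to_res, res_to_chain)
print(stats)  # a pure rigid shift shows up only in the "chain" entry
```

Logging these three numbers alongside σ for each training batch would confirm or refute the hypothesis that "residue_rel" stays small across most of the noise range.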