title: EDM Pomodoro Algorithm Review — Potential Pitfalls
tags: [inventory, edm, pomodoro, algorithm, debugging, multiscale-noise, critical]
created: 2024-04-23
updated: 2024-04-23
status: active
related:
EDM Pomodoro Algorithm Review — Potential Pitfalls
Core Issue: Residues Not Moving During Sampling
The observed symptom — residues barely moving during sampling — has multiple likely causes rooted in how the multiscale noise, EDM preconditioning, and Euler sampling interact.
1. Multiscale Noise Decomposition Drastically Under-Noises Residues and Chains
File: debug_training.py:225-244
w_c, w_r, w_a = 8.2, 1.6, 1.3
w_sum = w_c + w_r + w_a # = 11.1
sc = sqrt(8.2/11.1) * sigma # ≈ 0.86 * sigma
sr = sqrt(1.6/11.1) * sigma # ≈ 0.38 * sigma
sa = sqrt(1.3/11.1) * sigma # ≈ 0.34 * sigma
Pitfall: The weights are borrowed from the old autoregressive sampler (sampler.py:14-16), where they served a completely different purpose. In the old sampler:
- ms2=8.2 was a mean shift for chain-level sampling, not an isotropic noise weight
- The old sampler also had per-element stochastic widths (ss0=0.2, ss1=0.4, ss2=3.4) and N^(1/3) scaling, both absent in the EDM version
- Most critically, the old sampler’s noise was interpolated between signal and noise (alpha blending), not added to a clean structure
Result at training: The residue-level effective noise is only 0.38σ before broadcast. Since chain noise broadcasts down (which is correct), the total residue noise is roughly sqrt(0.38² + 0.86²) * σ ≈ 0.94σ. But the independent residue noise component is only 0.38σ. This means residues primarily inherit their noise from the chain level, not from their own level.
Result at sampling: At the start, edm_multiscale_noise(zeros, ...) with σ_max = 80 creates:
- Chain offset ~ 0.86 * 80 * randn ≈ 68.8 * randn: large global displacements
- Residue independent offset ~ 0.38 * 80 * randn ≈ 30.4 * randn: smaller
- Atom independent offset ~ 0.34 * 80 * randn ≈ 27.2 * randn: smallest
The chain-level noise dominates the magnitude, making the initial state look like a globally-shifted blob with minor per-residue and per-atom variation. The model learns that “residue positions are mostly determined by chain position + small perturbation,” and faithfully reproduces this during sampling — hence residues move as a rigid group, not independently.
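To see how lopsided the split is, the sketch below recomputes the per-level scales and variance fractions. It is a minimal illustration, not the code in debug_training.py: the helper names level_sigmas and variance_fractions are hypothetical, and it assumes the sqrt(w / w_sum) * sigma split shown above with independent components.

```python
import math


def level_sigmas(w_c, w_r, w_a, sigma):
    """Split a total noise level into chain/residue/atom components.

    Hypothetical helper: assumes s_level = sqrt(w_level / w_sum) * sigma,
    the split quoted in this review.
    """
    w_sum = w_c + w_r + w_a
    sc = math.sqrt(w_c / w_sum) * sigma
    sr = math.sqrt(w_r / w_sum) * sigma
    sa = math.sqrt(w_a / w_sum) * sigma
    return sc, sr, sa


def variance_fractions(w_c, w_r, w_a):
    """Fraction of the per-atom noise variance contributed by each level."""
    sc, sr, sa = level_sigmas(w_c, w_r, w_a, 1.0)
    return sc ** 2, sr ** 2, sa ** 2


# Weights inherited from the old sampler.
old = variance_fractions(8.2, 1.6, 1.3)
# Equal weights, one of the alternatives considered in this review.
equal = variance_fractions(1.0, 1.0, 1.0)

print("old   (chain, residue, atom):", [round(f, 3) for f in old])
print("equal (chain, residue, atom):", [round(f, 3) for f in equal])
```

With the inherited weights, roughly 74% of the per-atom noise variance comes from the chain level; with equal weights each level contributes one third.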
Recommendation
- Decouple the noise weights from old sampler values. The old weights were designed for a different algorithm.
- Consider making residue and atom noise a larger fraction of total noise, e.g., w_a : w_r : w_c = 1 : 1 : 1 or even 3 : 2 : 1, to ensure each level is fully corrupted.
- Alternatively, add a per-level σ schedule where residue noise is σ_r = σ * f_r(σ) with an increasing function f_r, so that at high σ levels residues are fully noised independently.
2. Preconditioning Applied Per-Level After Reduction Is Inconsistent
File: debug_training.py:256-260
Xr_noised = reduce_coordinates(X_noised, Mar)
Xc_noised = reduce_coordinates(Xr_noised, Mrc)
Xr_hat = c_skip(σ) * Xr_noised + c_out(σ) * Xr
Xc_hat = c_skip(σ) * Xc_noised + c_out(σ) * Xc
Pitfall: The noised coordinates at residue/chain levels are obtained by averaging the atom-level noised coordinates. But the EDM preconditioning functions (c_skip, c_out, c_in) were derived assuming a single-scale isotropic Gaussian noise model: x_noised = x_0 + σ·ε where ε ~ N(0,I).
At the atom level, this holds (approximately). But at the residue level, the effective noise is:
Xr_noised = reduce(X0 + noise_atoms, Mar) = X0r + reduce(noise_atoms, Mar)
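Because reduce_coordinates averages over the atoms in each residue, any noise that is independent per atom has its variance reduced by the number of atoms per residue. If a residue has ~5 atoms and the atom noise were fully independent, the residue-level noise variance would be σ² / 5, not σ². With the multiscale decomposition from Pitfall 1, only the atom-level component averages out this way; the broadcast residue and chain components pass through the average unchanged. Either way, the noise level at the residue level is not σ, so applying c_skip(σ) and c_out(σ) as if it were uses the wrong σ for that level.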
Consequence: The preconditioning is too conservative at residue/chain levels. c_skip(σ) is smaller than it should be (it thinks there’s more noise than there actually is at the reduced level), and c_out(σ) is larger. This over-weights the network output, creating instability in residue/chain predictions.
Recommendation
- Compute the effective σ at each level by accounting for the noise reduction from averaging. For noise that is independent per atom: σ_r = σ / sqrt(avg_atoms_per_residue), σ_c = σ_r / sqrt(avg_residues_per_chain).
- Apply c_skip(σ_r), c_out(σ_r), etc. at the residue level with the effective σ, not the raw σ.
- Or: generate noise independently at each level (as now), and use the actual noise level at each level for preconditioning. The current noise decomposition already gives you sr and sc; use those as the σ for each level’s preconditioning.
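The effect of using the wrong σ can be illustrated with the standard EDM preconditioning formulas. This sketch assumes c_skip, c_out and c_in follow the usual EDM definitions with σ_data = 16 (the project's own implementations may differ), and effective_sigma_after_mean is a hypothetical helper that applies only to noise that is independent across the averaged elements.

```python
import math

SIGMA_DATA = 16.0  # value quoted in this review


def c_skip(sigma, sigma_data=SIGMA_DATA):
    """Weight on the noisy input (assumed standard EDM form)."""
    return sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)


def c_out(sigma, sigma_data=SIGMA_DATA):
    """Weight on the network output (assumed standard EDM form)."""
    return sigma * sigma_data / math.sqrt(sigma ** 2 + sigma_data ** 2)


def c_in(sigma, sigma_data=SIGMA_DATA):
    """Input scaling (assumed standard EDM form)."""
    return 1.0 / math.sqrt(sigma ** 2 + sigma_data ** 2)


def effective_sigma_after_mean(sigma, group_size):
    """Std of the mean of group_size independent draws with std sigma."""
    return sigma / math.sqrt(group_size)


sigma = 10.0
sigma_r = effective_sigma_after_mean(sigma, 5)  # ~5 atoms per residue

print("raw sigma:       c_skip=%.3f c_out=%.3f" % (c_skip(sigma), c_out(sigma)))
print("effective sigma: c_skip=%.3f c_out=%.3f" % (c_skip(sigma_r), c_out(sigma_r)))
```

For σ = 10 and five atoms per residue, c_skip rises from about 0.72 to about 0.93 once the effective σ is used, which is the direction described in the Consequence paragraph: the raw σ under-weights the noisy input and over-weights the network output.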
3. Sampling Starts From Multiscale Noise of Zeros — Not Pure Noise
File: debug_training.py:318
X = edm_multiscale_noise(pt.zeros((N, 3), device=device), Mr, Mc, C, sigmas[0])
Pitfall: Boltz and standard EDM start sampling from σ_max · ε (pure isotropic noise). Here, the initial state is constructed as 0 + multiscale_noise(σ_max). Due to the hierarchical broadcast structure, this initial noise is not isotropic: atoms within the same residue share correlated noise from the chain and residue levels.
This differs from the isotropic start used by standard EDM, but not necessarily from what the model saw in training. During training, edm_multiscale_noise(X0, ..., σ) adds hierarchical noise to a real structure, so the noised distribution is X0 + correlated_noise. At σ_max = 80 the X0 term is negligible, so the training input is close to pure noise, but it is correlated noise: atoms within a residue still share perturbations.
Concretely, at σ_max the model sees c_in(σ_max) * X_noised = c_in(80) * (X0 + 80*correlated_noise) ≈ (1/81.6) * 80*correlated_noise, i.e. roughly unit-scale correlated noise. The model is trained on this specific correlation structure, and sampling starts with the same structure, so the two are consistent. This part is correct by design, as long as the multiscale noise function is used consistently in both training and sampling.
However, the concern is that this correlated noise structure biases the model to treat co-located atoms as a group rather than independently, reinforcing the “residues don’t move independently” behavior.
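How strongly atoms are tied together by the broadcast can be quantified. The sketch below is a one-coordinate toy model, not the project's edm_multiscale_noise: it assumes chain noise is shared by every atom in a chain, residue noise by every atom in a residue, and atom noise is independent, with the weights quoted in Pitfall 1. The function names are hypothetical.

```python
import math
import random

W_C, W_R, W_A = 8.2, 1.6, 1.3  # weights quoted in this review


def expected_correlation(same_residue, w_c=W_C, w_r=W_R, w_a=W_A):
    """Analytic noise correlation between two distinct atoms of one chain."""
    shared = w_c + (w_r if same_residue else 0.0)
    return shared / (w_c + w_r + w_a)


def simulated_correlation(same_residue, n_samples=20000, seed=0):
    """Monte Carlo estimate of the same correlation (single coordinate)."""
    rng = random.Random(seed)
    w_sum = W_C + W_R + W_A
    sc, sr, sa = (math.sqrt(w / w_sum) for w in (W_C, W_R, W_A))
    xs, ys = [], []
    for _ in range(n_samples):
        chain = sc * rng.gauss(0.0, 1.0)          # shared by both atoms
        res_x = sr * rng.gauss(0.0, 1.0)
        res_y = res_x if same_residue else sr * rng.gauss(0.0, 1.0)
        xs.append(chain + res_x + sa * rng.gauss(0.0, 1.0))
        ys.append(chain + res_y + sa * rng.gauss(0.0, 1.0))
    mx, my = sum(xs) / n_samples, sum(ys) / n_samples
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / n_samples
    vx = sum((x - mx) ** 2 for x in xs) / n_samples
    vy = sum((y - my) ** 2 for y in ys) / n_samples
    return cov / math.sqrt(vx * vy)


for same in (True, False):
    print("same residue:" if same else "other residue:",
          round(expected_correlation(same), 3),
          round(simulated_correlation(same), 3))
```

Under these assumptions two atoms in the same residue have a noise correlation of about 0.88, and atoms in different residues of the same chain about 0.74, compared with 0 for isotropic noise.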
Recommendation
- This is a design choice, not clearly a bug. But if you want residues to move more independently, try starting from isotropic noise (σ_max * randn(N, 3)) instead. The model will need to learn to handle both correlated and uncorrelated noise, which may require more training.
- Alternatively, add random augmentation (rotation + translation) at each sampling step, as Boltz does (diffusionv2.py:351-357). This breaks the correlation between the global frame and atom positions at each step.
4. Missing Random Augmentation During Sampling
File: debug_training.py:314-339 vs boltz/.../diffusionv2.py:351-358
Boltz applies random rotation + translation at every sampling step:
random_R, random_tr = compute_random_augmentation(multiplicity, device, dtype)
atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdims=True)
atom_coords = torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr
The pomodoro EDM sampler has no augmentation at all. This creates a subtle problem: the model can overfit to the absolute orientation of the noised structure. Without augmentation, the model implicitly learns the coordinate frame of the data, and during sampling it may get “stuck” in an orientation that doesn’t match any real structure, producing small updates.
Recommendation
- Add centering + random rotation augmentation at each sampling step.
- Also consider adding augmentation during training (Boltz does this too, lines 568-569).
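A dependency-free sketch of the centering + random rotation step is given below. It is not a port of Boltz's compute_random_augmentation: the function names are hypothetical, the rotation is drawn from a normalized Gaussian quaternion, and the random translation is left out for brevity.

```python
import math
import random


def random_rotation(rng):
    """Uniform random 3x3 rotation from a normalized Gaussian quaternion."""
    w, x, y, z = (rng.gauss(0.0, 1.0) for _ in range(4))
    n = math.sqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w / n, x / n, y / n, z / n
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def center_and_rotate(coords, rng):
    """Center a list of [x, y, z] points, then apply one random rotation."""
    n = len(coords)
    centroid = [sum(p[k] for p in coords) / n for k in range(3)]
    centered = [[p[k] - centroid[k] for k in range(3)] for p in coords]
    rot = random_rotation(rng)
    # Row-vector convention, new[j] = sum_i old[i] * rot[i][j],
    # the same contraction as the einsum "bmd,bds->bms" quoted above.
    return [[sum(p[i] * rot[i][j] for i in range(3)) for j in range(3)]
            for p in centered]


rng = random.Random(0)
points = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]
augmented = center_and_rotate(points, rng)
print(augmented)
```

Because the transform is rigid, all pairwise distances are preserved and only the global frame changes, which is the property the augmentation relies on.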
5. Loss Composition and Backpropagation Issues
File: debug_training.py:287-301
X0_aligned = superpose(Xa_hat.detach(), X0) # GT aligned to prediction (detached)
la = mean(sum(square(X0_aligned - Xa_hat), dim=1)) # loss on atom coords
...
loss = (la + lb + lnb + lr + lc) * lw
Pitfall 1 (alignment detach direction): superpose(Xa_hat.detach(), X0) aligns the GT to the prediction with the prediction detached. In the loss term X0_aligned - Xa_hat, X0_aligned therefore carries no gradient and Xa_hat does, which is correct for the atom-level MSE loss. The push_pull_loss (lines 296-298) doesn’t use the aligned coordinates directly; it recomputes distances from Xa_hat and compares them against D0 and R0, which come from the aligned GT. The alignment depends on Xa_hat only through the detach boundary, so this should be fine.
Pitfall 2 (relative weighting of levels): loss = (la + lb + lnb + lr + lc) * lw sums all levels with equal weight. But the scales are very different:
- la is per-atom MSE: O(N_atoms) terms
- lr is per-residue MSE: O(N_residues) terms
- lc is per-chain MSE: O(1-3) terms
With 32 residues × ~5 atoms/residue = ~160 atoms, la could dominate lr by 5× and lc by 50× if the terms are accumulated as sums. Note that la as shown above is already a mean over atoms, so the size of the imbalance depends on how lr and lc are reduced; this is worth checking before reweighting. If the imbalance holds, the residue and chain losses are nearly negligible, and the model has almost no gradient signal telling it where residues/chains should be, which is another reason residues don’t move correctly.
Recommendation
- Weight losses inversely to the number of elements: loss = (la/N_a + lb + lnb + α_r*lr/N_r + α_c*lc/N_c) * lw with α_r, α_c >= 1.
- Or compute per-element (per-atom, per-residue, per-chain) average losses before summing.
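The difference between summed and per-element losses can be shown on a toy case. In the sketch below every atom, residue and chain is off by exactly one unit, so each level is "equally wrong"; the helper names and the 160/32/1 element counts are illustrative assumptions, not the project's loss code.

```python
def squared_errors(pred, target):
    """Squared Euclidean error for each element (atom, residue or chain)."""
    return [sum((p - t) ** 2 for p, t in zip(pv, tv))
            for pv, tv in zip(pred, target)]


def summed_loss(pred, target):
    """Loss accumulated as a plain sum over elements."""
    return sum(squared_errors(pred, target))


def mean_loss(pred, target):
    """Loss averaged over elements, independent of the element count."""
    errs = squared_errors(pred, target)
    return sum(errs) / len(errs)


def offset_level(n):
    """n elements whose prediction is off by one unit along x."""
    target = [[0.0, 0.0, 0.0] for _ in range(n)]
    pred = [[1.0, 0.0, 0.0] for _ in range(n)]
    return pred, target


levels = {
    "atom": offset_level(160),
    "residue": offset_level(32),
    "chain": offset_level(1),
}
summed = {name: summed_loss(p, t) for name, (p, t) in levels.items()}
averaged = {name: mean_loss(p, t) for name, (p, t) in levels.items()}

print("summed:  ", summed)
print("averaged:", averaged)
```

With sums, the atom term outweighs the chain term by the ratio of element counts even though the per-element error is identical; with means, all three levels contribute equally and any further emphasis can be set explicitly through α_r and α_c.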
6. Noise Conditioning Is Scalar and Shared Across All Levels
File: debug_training.py:249-251
t = c_noise(σ, hp)
qe_t = cat([qe, t.expand(qe.shape[0], 1)], dim=-1) # 31st feature, same for all atoms
The noise level conditioning c_noise(σ) is a single scalar concatenated to every atom’s feature vector. All atoms, residues, and chains see the same noise level.
Pitfall: Due to the multiscale noise decomposition, the effective noise level differs between atom/residue/chain levels. With n_a atoms per residue and n_r residues per chain, and assuming independent components, the atom level has per-coordinate variance σ_a² + σ_r² + σ_c² (own noise plus broadcast). After averaging over atoms, the residue level has σ_a²/n_a + σ_r² + σ_c², and the chain level has σ_a²/(n_a·n_r) + σ_r²/n_r + σ_c². But the model sees c_noise(σ) everywhere; it doesn’t know that the effective noise differs per level.
This is a major information asymmetry: the model must learn different denoising strategies for each level but only receives one global noise signal. The model can presumably infer the level-dependent noise from the input coordinates (since c_in(σ) * X_noised has different variances at different levels), but this puts an unnecessary burden on the model.
Recommendation
- Provide per-level noise conditioning: pass c_noise(σ), c_noise(σ_r), c_noise(σ_c) as separate features to the atom, residue, and chain scale tracks.
- Or embed the noise level as a 3-dimensional vector [c_noise(σ), c_noise(σ_r), c_noise(σ_c)].
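One way the 3-dimensional conditioning could be assembled is sketched below. It assumes the EDM-style form c_noise(σ) = 0.25 · ln(σ); the project's c_noise(σ, hp) takes extra hyperparameters and may differ. The helper names and the example σ values are hypothetical.

```python
import math


def c_noise(sigma):
    """Noise conditioning scalar (assumed EDM form, 0.25 * ln(sigma))."""
    return 0.25 * math.log(sigma)


def per_level_conditioning(sigma_a, sigma_r, sigma_c):
    """Conditioning vector with one entry per level (atom, residue, chain)."""
    return [c_noise(sigma_a), c_noise(sigma_r), c_noise(sigma_c)]


def append_conditioning(features, conditioning):
    """Append the same conditioning vector to every feature row."""
    return [list(row) + list(conditioning) for row in features]


# Two atoms with a toy 2-dimensional one-hot element encoding.
one_hot = [[1.0, 0.0], [0.0, 1.0]]
# Illustrative per-level effective noise levels, not measured values.
cond = per_level_conditioning(10.0, 9.5, 8.6)
rows = append_conditioning(one_hot, cond)
print(rows)
```

Compared with the single scalar, the feature width grows by two columns, and each scale track can read the entry that matches its own effective noise level.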
7. Model Architecture: Residue/Chain Scale Tracks Get Zero Input Features
File: debug_training.py:391-404
stm_r=ScaleTrackModel(sem=StateEmbeddingModel.Config(F=1, S=S, ...))
stm_c=ScaleTrackModel(sem=StateEmbeddingModel.Config(F=1, S=S, ...))
The residue and chain scale tracks have F=1, a single feature dimension. Looking at the model forward (model.py:645-646):
qr = zeros((Xr.shape[0], self.Fr), device=qa.device) # F=1 → ZEROS
qc = zeros((Xc.shape[0], self.Fc), device=qa.device) # F=1 → ZEROS
Pitfall: Residue and chain scalar tracks start from zero features (no identity information). They only receive information through cross-attention from the atom level (gtl_c[i]). This means:
- The residue track has no direct signal about which residues are which — it learns identity only through coupling with atom features
- The chain track has even less information — it only aggregates from residues which themselves have no identity
- The noise conditioning (F=1) is the only non-zero scalar input to these tracks
This is intentional (the model should learn structure from geometry), but combined with Pitfall 6, it means the residue/chain tracks have very limited information to distinguish noise levels and respond appropriately.
Additionally: The atom track has F=31 (30 element types + 1 noise level). The residue and chain tracks have F=1 (just noise level). This means the residue and chain tracks see their geometry through c_in(σ) * X_noised_reduced but have minimal scalar context. The model’s capacity to learn residue-level denoising is severely limited.
Recommendation
- Increase F for residue and chain tracks. Pass residue type embeddings (e.g., amino acid identity) and chain metadata.
- At minimum, pass different noise conditioning per level as discussed in Pitfall 6.
8. Euler Step Scale and Schedule Issues
File: debug_training.py:334-335
d = (X_noised - Xa_hat) / t_hat # score direction
X = X_noised + step_scale * (s_next - t_hat) * d # Euler step
With step_scale=1.0 and gamma=0.0, this is the deterministic Euler ODE solver.
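The update above can be exercised on a toy problem to see when it moves points and when it cannot. The sketch below is a plain-Python stand-in for the sampler loop with gamma = 0; the two denoisers are hypothetical test doubles, not the trained model, and the short schedule is illustrative.

```python
def euler_sample(x, sigmas, denoise, step_scale=1.0):
    """Deterministic Euler integration of the sampling ODE (gamma = 0)."""
    for t_hat, s_next in zip(sigmas[:-1], sigmas[1:]):
        x_hat = denoise(x, t_hat)
        # Same direction as in the snippet above: (x - denoised) / sigma.
        d = [(xi - hi) / t_hat for xi, hi in zip(x, x_hat)]
        x = [xi + step_scale * (s_next - t_hat) * di for xi, di in zip(x, d)]
    return x


X0 = [1.0, -2.0, 0.5]                  # single "true" position
SIGMAS = [80.0, 20.0, 5.0, 1.0, 0.0]   # short illustrative schedule


def ideal_denoiser(x, sigma):
    """Exact denoiser when the data distribution is the single point X0."""
    return list(X0)


def identity_denoiser(x, sigma):
    """Degenerate denoiser that returns its input unchanged."""
    return list(x)


start = [50.0, 30.0, -70.0]
converged = euler_sample(start, SIGMAS, ideal_denoiser)
stuck = euler_sample(start, SIGMAS, identity_denoiser)

print("ideal denoiser:   ", converged)
print("identity denoiser:", stuck)
```

With the ideal denoiser the sample lands on X0 once σ reaches 0. With a denoiser whose output stays close to its input, d is close to zero and the sample does not move regardless of the schedule, so comparing Xa_hat against X_noised at each step is a cheap diagnostic for the "residues not moving" symptom.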
Pitfall 1: The code uses sigma_max=80.0, whereas the session notes say sigma_max=160, which is also the value Boltz uses. With σ_max=80 and rho=7, the schedule may not spend enough steps at high noise levels where residue positions need to be established.
Pitfall 2: sample_schedule (lines 32-37) multiplies by σ_data at the end: sigmas = sigmas * σ_data. This matches Boltz (line 290). The training noise distribution (line 196) also uses σ_data as a multiplier: σ = σ_data * exp(P_mean + P_std * randn). So σ_data=16 scales both. The median training σ is 16 * exp(-1.2) ≈ 4.8, while the schedule runs from 80 * 16 = 1280 down to 1e-3 * 16 = 0.016.
sigma_max=80.0 is the raw parameter, so the schedule maximum after scaling is 80 * 16 = 1280. In training, even a 3σ draw gives only 16 * exp(-1.2 + 4.5) ≈ 16 * 27.1 ≈ 434. Sampling therefore starts at σ=1280, a level the model has essentially never been trained on.
This is a critical mismatch: the model’s first denoising step at σ=1280 is operating completely out of distribution. The model will likely return near-zero output (since it’s never seen this σ), and c_skip(1280) * X_noised ≈ 0.0002 * X_noised, so Xa_hat is essentially garbage. The Euler step then computes a near-random direction and takes a tiny step.
Recommendation
- Fix the schedule: Either use sigma_max without the σ_data multiplier in the schedule, or align sigma_max such that sigma_max * σ_data is within the training distribution’s range.
- Boltz’s sigma_max=160 with sigma_data=16 gives schedule max = 160 * 16 = 2560, but Boltz trains with σ = σ_data * exp(...) which can reach these values at 3σ+ events. The key difference is that Boltz’s training distribution has a wider effective range due to the per-sample σ, and the model is trained for longer with more data, so it handles extreme σ better.
- For pomodoro: Lower sigma_max to something like 10 so the schedule max is 10 * 16 = 160, which is reachable in training. Or don’t multiply by σ_data in the schedule.
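The mismatch can be checked numerically before changing any settings. The sketch below assumes a Karras-style rho schedule scaled by σ_data and a log-normal training distribution with P_mean = -1.2 and P_std = 1.5 (P_std is inferred from the 3σ figure of 4.5 used above). The function names are hypothetical and this is not the project's sample_schedule.

```python
import math


def karras_schedule(num_steps, sigma_min, sigma_max, rho, sigma_data):
    """Rho-spaced noise levels from sigma_max to sigma_min, times sigma_data."""
    inv = 1.0 / rho
    hi, lo = sigma_max ** inv, sigma_min ** inv
    raw = [(hi + i / (num_steps - 1) * (lo - hi)) ** rho
           for i in range(num_steps)]
    return [s * sigma_data for s in raw]


def training_z_score(sigma, sigma_data, p_mean, p_std):
    """Std devs by which ln(sigma / sigma_data) exceeds the training log-mean."""
    return (math.log(sigma / sigma_data) - p_mean) / p_std


SIGMA_DATA, P_MEAN, P_STD = 16.0, -1.2, 1.5

current = karras_schedule(50, 1e-3, 80.0, 7.0, SIGMA_DATA)
lowered = karras_schedule(50, 1e-3, 10.0, 7.0, SIGMA_DATA)

for name, sched in (("sigma_max=80", current), ("sigma_max=10", lowered)):
    z = training_z_score(sched[0], SIGMA_DATA, P_MEAN, P_STD)
    print("%s: start sigma=%.1f, z=%.2f" % (name, sched[0], z))
```

Under these assumptions the current schedule starts about 3.7 standard deviations above the training log-mean, while sigma_max=10 starts about 2.3 standard deviations above it, a region that training draws do visit.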
Summary of Likely Root Causes for “Residues Not Moving”
| # | Issue | Severity | Description |
|---|---|---|---|
| 1 | Noise weights from old sampler | Critical | Residue independent noise is only 0.38σ — too low to fully corrupt residue positions |
| 8 | Schedule/training σ mismatch | Critical | Sampling starts at σ=1280 but even a 3σ training draw reaches only σ≈434; first denoising step is out of distribution |
| 2 | Wrong preconditioning at reduced levels | High | Using raw σ for residue/chain preconditioning when effective noise is lower due to averaging |
| 5 | Loss weighting ignores level sizes | High | Chain loss (1-3 terms) ≪ atom loss (~160 terms): chain/residue gradients are negligible |
| 4 | No random augmentation | Medium | Model can overfit to absolute orientation; Boltz augments every step |
| 6 | Single noise conditioning for all levels | Medium | Model doesn’t know per-level effective noise |
| 7 | Residue/chain tracks have near-zero input features | Medium | F=1 with zero initialization limits capacity for residue-level reasoning |
| 3 | Correlated initial noise | Low | By design, but reinforces group behavior |
Priority Fixes
- Fix noise weights: Increase residue noise fraction. Try w_a : w_r : w_c = 3 : 2 : 1 or even 1 : 1 : 1.
- Fix schedule σ mismatch: Either stop multiplying by σ_data in sample_schedule, or reduce sigma_max so the schedule stays within training support.
- Use per-level σ for preconditioning: In edm_denoise, apply c_skip(sr), c_out(sr) etc. for residue level, c_skip(sc) for chain level.
- Add loss weighting by level size: Weight lr and lc inversely to the number of residues/chains.
- Add random augmentation: Center + random rotation at each sampling step.