Work in Progress

Zero-Overhead Adaptive Preconditioning via Sinkhorn Geometry

SINK-V
Korata Hiu*
* nickname — adv_optm dev
Abstract

SinkGD is an optimization algorithm that preconditions gradients by iteratively rescaling their rows and columns to a uniform norm, geometrically balancing signal and noise. While this effectively prevents vanishing or exploding updates across channels, standard SinkGD lacks an adaptive variance mechanism for damping noisy dimensions, which is important for stable convergence.

In this post, we propose two techniques that add variance preconditioning to SinkGD with zero memory overhead. First, we define Normalization-then-Momentum (NtM), a structural reordering that applies Sinkhorn normalization before momentum accumulation. Second, building on the geometric constraints that NtM provides, we introduce Sinkhorn Implicit Variance (SINK-V). Because every incoming normalized gradient lies on a hypersphere of unit RMS radius, we show that, assuming the momentum tracks the mean normalized gradient, the spatial variance can be computed in closed form from the momentum buffer's norm. This yields adaptivity without allocating any explicit variance trackers.

§ 0 Standard SinkGD & The Missing Variance

Standard SinkGD preconditions each gradient matrix with an iterative, geometric normalization of its rows and columns, designed to balance signal and noise propagation across the network's dimensions.

By computing diagonal row and column scaling factors on the fly (together, an element-wise scaling given by the rank-1 outer product of the two factor vectors), the algorithm distributes energy evenly, so that every row and column contributes comparably to the update.
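The row/column normalization described above can be sketched in plain Python. This is a minimal sketch under stated assumptions: `sinkhorn_normalize` is a hypothetical stand-in for the post's `apply_sr_sinkhorn` (whose code is not shown), assumed to alternate between rescaling every row and every column to RMS 1. It also returns the accumulated diagonal factors, so you can check that the output is exactly `d_row[i] * g[i][j] * d_col[j]`, the rank-1 scaling pattern. `iters=50` is used only so this tiny example converges visibly; the listings in this post use 5 iterations.

```python
import math


def sinkhorn_normalize(g, iters=5, eps=1e-30):
    """Hypothetical stand-in for apply_sr_sinkhorn: alternately rescale the rows
    and columns of g (a list of lists) to RMS 1. Returns the normalized matrix
    and the accumulated diagonal row and column scaling factors."""
    n_rows, n_cols = len(g), len(g[0])
    x = [list(row) for row in g]
    row_scale = [1.0] * n_rows
    col_scale = [1.0] * n_cols
    for _ in range(iters):
        # Row step: divide each row by its current RMS.
        for i in range(n_rows):
            r = max(math.sqrt(sum(v * v for v in x[i]) / n_cols), eps)
            x[i] = [v / r for v in x[i]]
            row_scale[i] /= r
        # Column step: divide each column by its current RMS.
        for j in range(n_cols):
            c = max(math.sqrt(sum(x[i][j] ** 2 for i in range(n_rows)) / n_rows), eps)
            for i in range(n_rows):
                x[i][j] /= c
            col_scale[j] /= c
    return x, row_scale, col_scale


g = [[1.0, -2.0, 0.5], [3.0, 1.0, -1.0]]
g_hat, d_row, d_col = sinkhorn_normalize(g, iters=50)

row_rms = [math.sqrt(sum(v * v for v in row) / len(row)) for row in g_hat]
col_rms = [math.sqrt(sum(row[j] ** 2 for row in g_hat) / len(g_hat)) for j in range(len(g[0]))]
print(row_rms, col_rms)  # every entry is close to 1.0

# The result is the diagonal scaling diag(d_row) @ g @ diag(d_col).
print(g_hat[1][2], d_row[1] * g[1][2] * d_col[2])
```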

In a traditional SinkGD implementation, the optimizer accumulates standard gradients into a momentum buffer, and then applies the Sinkhorn normalization to that buffer to compute the final update:

$$ \mathbf{m}_t = \beta \mathbf{m}_{t-1} + (1 - \beta) \mathbf{g}_t $$ $$ \Delta \mathbf{w}_t = \text{Sinkhorn}(\mathbf{m}_t) $$
Algorithm — Standard SinkGD
# 1. Accumulate raw gradient into momentum: m_t = beta * m_{t-1} + (1 - beta) * grad
m_t.lerp_(grad, 1 - beta)

# 2. Apply Sinkhorn preconditioning to momentum (apply_sr_sinkhorn is defined elsewhere)
update = apply_sr_sinkhorn(m_t, iters=5)

# 3. Apply update
param.sub_(update * lr)

Adaptive optimizers scale each update element inversely to an estimate of its gradient noise (Adam, for example, divides by the square root of its second-moment estimate). This lets them take larger steps along consistent directions while damping steps along noisy ones. Standard SinkGD lacks this mechanism: because the Sinkhorn operator is the final step before the update, every row and column is forced to the same norm, which erases any magnitude-based measure of confidence in the momentum.
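To make the lost confidence signal concrete, here is a minimal sketch assuming the same hypothetical `sinkhorn_normalize` helper (alternating row and column RMS normalization, not the post's own `apply_sr_sinkhorn`). Sinkhorn output does not depend on the overall magnitude of its input, so a momentum buffer shrunk by noise cancellation produces the same update as a large, consistent one.

```python
import math


def sinkhorn_normalize(g, iters=5, eps=1e-30):
    """Hypothetical helper: alternately rescale rows and columns of g to RMS 1."""
    n_rows, n_cols = len(g), len(g[0])
    x = [list(row) for row in g]
    for _ in range(iters):
        for i in range(n_rows):  # row step
            r = max(math.sqrt(sum(v * v for v in x[i]) / n_cols), eps)
            x[i] = [v / r for v in x[i]]
        for j in range(n_cols):  # column step
            c = max(math.sqrt(sum(x[i][j] ** 2 for i in range(n_rows)) / n_rows), eps)
            for i in range(n_rows):
                x[i][j] /= c
    return x


# A large, consistent momentum buffer...
m_confident = [[0.9, -0.8, 0.7], [0.6, 0.9, -0.8]]
# ...and the same pattern shrunk 100x, as if noise had cancelled most of it.
m_shrunk = [[0.01 * v for v in row] for row in m_confident]

u_confident = sinkhorn_normalize(m_confident)
u_shrunk = sinkhorn_normalize(m_shrunk)

max_diff = max(abs(a - b) for ra, rb in zip(u_confident, u_shrunk) for a, b in zip(ra, rb))
print(max_diff)  # only floating-point rounding: the magnitude information is gone
```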

§ 1 Normalization-then-Momentum (NtM)

To solve this, our first proposal is a simple structural reordering of operations called Normalization-then-Momentum (NtM). Instead of normalizing the momentum buffer, we normalize the raw gradient first and accumulate the normalized gradients into the momentum buffer.

$$ \hat{\mathbf{g}}_t = \text{Sinkhorn}(\mathbf{g}_t) $$ $$ \mathbf{m}_t = \beta \mathbf{m}_{t-1} + (1 - \beta) \hat{\mathbf{g}}_t $$ $$ \Delta \mathbf{w}_t = \mathbf{m}_t $$
Algorithm — Normalization-then-Momentum (NtM)
# 1. Apply Sinkhorn to the incoming raw gradient (apply_sr_sinkhorn is defined elsewhere)
g_norm = apply_sr_sinkhorn(grad, iters=5)

# 2. Accumulate normalized gradient into momentum: m_t = beta * m_{t-1} + (1 - beta) * g_norm
m_t.lerp_(g_norm, 1 - beta)

# 3. Apply momentum as the update
param.sub_(m_t * lr)

This reordering is crucial. By enforcing the Sinkhorn constraint on $\mathbf{g}_t$ directly, we ensure that the root mean square (RMS) energy of every incoming gradient row and column is $1.0$, up to the convergence of the finite number of Sinkhorn iterations. As these bounded gradients are exponentially averaged into $\mathbf{m}_t$, the RMS of $\mathbf{m}_t$ shrinks when gradient directions are noisy, because successive normalized gradients partially cancel, and grows toward $1.0$ when directions are consistent. This sets the mathematical foundation for SINK-V.
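The shrink-or-grow behavior of the momentum under NtM can be checked with a small simulation. This is a sketch under stated assumptions: `sinkhorn_normalize`, `rms`, and `run_ntm` are hypothetical helpers, the normalization is assumed to alternate row and column RMS rescaling, and `beta=0.9`, 200 steps, and a 4 × 6 matrix are arbitrary choices for illustration.

```python
import math
import random


def sinkhorn_normalize(g, iters=5, eps=1e-30):
    """Hypothetical helper: alternately rescale rows and columns of g to RMS 1."""
    n_rows, n_cols = len(g), len(g[0])
    x = [list(row) for row in g]
    for _ in range(iters):
        for i in range(n_rows):  # row step
            r = max(math.sqrt(sum(v * v for v in x[i]) / n_cols), eps)
            x[i] = [v / r for v in x[i]]
        for j in range(n_cols):  # column step
            c = max(math.sqrt(sum(x[i][j] ** 2 for i in range(n_rows)) / n_rows), eps)
            for i in range(n_rows):
                x[i][j] /= c
    return x


def rms(x):
    """RMS over all entries of a matrix."""
    vals = [v for row in x for v in row]
    return math.sqrt(sum(v * v for v in vals) / len(vals))


def run_ntm(grad_fn, steps=200, beta=0.9, shape=(4, 6)):
    """NtM: Sinkhorn-normalize each raw gradient, then EMA it into the momentum."""
    m = [[0.0] * shape[1] for _ in range(shape[0])]
    for t in range(steps):
        g_hat = sinkhorn_normalize(grad_fn(t))
        # Same as m.lerp_(g_hat, 1 - beta): m = beta * m + (1 - beta) * g_hat
        m = [[beta * mv + (1 - beta) * gv for mv, gv in zip(mr, gr)]
             for mr, gr in zip(m, g_hat)]
    return m


rng = random.Random(0)
fixed = [[rng.gauss(0, 1) for _ in range(6)] for _ in range(4)]

m_consistent = run_ntm(lambda t: fixed)  # identical gradient every step
m_noisy = run_ntm(lambda t: [[rng.gauss(0, 1) for _ in range(6)] for _ in range(4)])  # fresh noise each step

print(rms(m_consistent), rms(m_noisy))  # near 1.0 vs. well below 1.0
```

For independent zero-mean noise, the momentum's mean square settles around $(1-\beta)/(1+\beta)$, so the noisy buffer's RMS stays far below the consistent one's.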

§ 2 The Mathematical Proof (SINK-V)

We now establish Sinkhorn Implicit Variance (SINK-V). Let $\hat{\mathbf{g}}_t \in \mathbb{R}^d$ be a Sinkhorn-normalized gradient row at step $t$ under the NtM regime. Let $\mathbf{m}_t$ be the momentum buffer for that row, acting as the expected value: $\mathbf{m}_t \approx \mathbb{E}[\hat{\mathbf{g}}_t]$.

We wish to find the true spatial variance $V$ of the normalized gradient around its momentum:

Definition of Variance
$$V = \mathbb{E}_{\text{time}} \left[ \frac{1}{d} \sum_{j=1}^{d} (\hat{\mathbf{g}}_{t,j} - \mathbf{m}_{t,j})^2 \right]$$
Expanding the square
$$V = \mathbb{E}_{\text{time}} \left[ \frac{1}{d} \sum_{j=1}^{d} (\hat{\mathbf{g}}_{t,j}^2 - 2 \hat{\mathbf{g}}_{t,j} \mathbf{m}_{t,j} + \mathbf{m}_{t,j}^2) \right]$$
Applying the Sinkhorn Constraint (RMS = 1.0)
Because of NtM Sinkhorn normalization, the "total energy" of the row is structurally fixed: $\frac{1}{d} \sum \hat{\mathbf{g}}_{t,j}^2 = 1$. Substituting this, treating $\mathbf{m}_t$ as fixed with respect to the expectation (that is, ignoring the small contribution of $\hat{\mathbf{g}}_t$ to $\mathbf{m}_t$ itself), and distributing the expectation: $$V = 1 - 2 \left( \frac{1}{d} \sum_{j=1}^{d} \mathbf{m}_{t,j} \mathbb{E}[\hat{\mathbf{g}}_{t,j}] \right) + \frac{1}{d} \sum_{j=1}^{d} \mathbf{m}_{t,j}^2$$
Substituting the Momentum Mean
Since $\mathbf{m}_t$ is the expected value of $\hat{\mathbf{g}}_t$, we replace $\mathbb{E}[\hat{\mathbf{g}}_{t,j}]$ with $\mathbf{m}_{t,j}$: $$V = 1 - 2 \left( \frac{1}{d} \sum_{j=1}^{d} \mathbf{m}_{t,j}^2 \right) + \frac{1}{d} \sum_{j=1}^{d} \mathbf{m}_{t,j}^2$$
Exact Implicit Row Variance
$$V_{\text{row}} = 1 - \frac{1}{d} \sum_{j=1}^{d} \mathbf{m}_{t,j}^2$$

This yields a useful property: under the assumptions above, the expected variance of the normalized gradient is a closed-form function of the spatial mean of the squared momentum elements. We can therefore read the variance off the momentum buffer we already track, with zero explicit variance state.
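A quick numerical check of the identity, under the derivation's assumptions: every row has RMS exactly 1, and $\mathbf{m}$ is the exact sample mean of those rows rather than an EMA. In that idealized setting the decomposition holds up to floating-point rounding; with an EMA momentum it holds only approximately. `normalize_rms` is a hypothetical helper, and the noise level and sizes are arbitrary.

```python
import math
import random


def normalize_rms(v):
    """Rescale a vector to RMS 1 (the per-row Sinkhorn constraint)."""
    r = math.sqrt(sum(x * x for x in v) / len(v))
    return [x / r for x in v]


rng = random.Random(1)
d, n = 8, 1000
direction = [rng.gauss(0, 1) for _ in range(d)]
# Noisy samples around a shared direction, each projected to RMS 1.
samples = [normalize_rms([a + 2.0 * rng.gauss(0, 1) for a in direction]) for _ in range(n)]

# The exact sample mean plays the role of m = E[g_hat].
m = [sum(s[j] for s in samples) / n for j in range(d)]

# Left-hand side: the definition of V (mean squared deviation around m).
v_direct = sum(sum((s[j] - m[j]) ** 2 for j in range(d)) / d for s in samples) / n
# Right-hand side: the implicit formula V = 1 - mean(m^2).
v_implicit = 1.0 - sum(mj * mj for mj in m) / d

print(v_direct, v_implicit)  # agree up to rounding
```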

§ 3 The Geometric Concept

We can visualize this exact equality using basic geometry.

Because NtM Sinkhorn normalization enforces an RMS norm of 1, every incoming normalized gradient row $\hat{\mathbf{g}}_t$ is constrained to the surface of a hypersphere in $\mathbb{R}^d$ of radius $R = \sqrt{d}$ (a unit hypersphere in RMS terms).

The momentum buffer $\mathbf{m}_t$ is an exponentially weighted average of these surface vectors. With zero initialization, $\mathbf{m}_t$ is a convex combination of the origin and points on the sphere, so by the convexity of the norm (Jensen's inequality) it always lies inside or on the hypersphere. The bias-variance decomposition, a Pythagorean identity for expectations, then fixes the relationship between the origin, the momentum, and the gradient surface:

$$\text{Signal}^2 + \text{Variance} = \text{Total Energy}$$ $$\|\mathbf{m}\|_{\text{RMS}}^2 + V = \mathbb{E}\big[\|\hat{\mathbf{g}}\|_{\text{RMS}}^2\big] = 1.0 \qquad \implies \qquad V = 1.0 - \|\mathbf{m}\|_{\text{RMS}}^2$$
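Mathematical safety guarantee: An average of vectors with RMS norm 1 cannot have an RMS norm greater than 1, so in exact arithmetic with fully converged Sinkhorn normalization, `1.0 - mean(m^2)` is guaranteed to be $\ge 0$ and negative variances cannot occur. In practice, a finite number of Sinkhorn iterations enforces the RMS constraint only approximately, and floating-point rounding adds small errors, which is why the implementation still clamps the variance to a tiny positive floor.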
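The row-wise part of this guarantee can be stress-tested directly. A minimal sketch, assuming rows normalized exactly to RMS 1 and a zero-initialized EMA momentum; `normalize_rms` is a hypothetical helper, and the mostly consistent gradients are chosen so that the momentum approaches the sphere from the inside.

```python
import math
import random


def normalize_rms(v):
    """Rescale a vector to RMS 1, i.e. project it onto the sphere of radius sqrt(d)."""
    r = math.sqrt(sum(x * x for x in v) / len(v))
    return [x / r for x in v]


rng = random.Random(2)
d, beta = 16, 0.95
direction = [rng.gauss(0, 1) for _ in range(d)]

m = [0.0] * d  # zero-initialized momentum row
worst = 0.0    # largest mean(m^2) seen so far
for _ in range(2000):
    # Mostly consistent gradients, so m approaches the sphere from the inside.
    g_hat = normalize_rms([a + 0.1 * rng.gauss(0, 1) for a in direction])
    m = [beta * mj + (1 - beta) * gj for mj, gj in zip(m, g_hat)]
    worst = max(worst, sum(mj * mj for mj in m) / d)

print(worst, 1.0 - worst)  # implicit variance 1 - mean(m^2) stays non-negative
```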
Fig. 1 — Geometric Constraint
Because the total energy (RMS²) is fixed at 1.0, as the signal (squared momentum RMS) increases, the variance must decrease by exactly the same amount.
Fig. 2 — Update Magnitude Scaling & atan2 Bounding
Comparing Pure NtM vs. SINK-V Preconditioning. Using atan2 with the $4/\pi$ multiplier makes the update a smooth, saturating function of the signal-to-noise ratio, bounding every output element to $[-2, 2]$.

§ 4 The Implementation

In standard adaptive optimizers such as Adam, the update is computed by element-wise division, $\mathbf{m} / \sqrt{v}$. This division is unstable when the variance estimate approaches zero, so a small constant $\epsilon$ must be added to the denominator, and it often needs tuning to keep the update from exploding.

In SINK-V, we replace the division entirely with the two-argument arctangent, $\text{atan2}(\mathbf{m}, \sqrt{v})$. When $|\mathbf{m}| \ll \sqrt{v}$, $\text{atan2}(\mathbf{m}, \sqrt{v}) \approx \mathbf{m}/\sqrt{v}$, matching the behavior of division; as the ratio grows, the output saturates instead of exploding. Because the denominator is non-negative, each raw element is bounded to $[-\pi/2, \pi/2]$, and multiplying the output by $4/\pi$ maps the update to $[-2, 2]$. This removes the division entirely: there is no division-by-zero failure mode and no $\epsilon$ to tune.
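The contrast between division and atan2 can be seen on a single element as the variance shrinks. This is an illustrative scalar sketch, not the post's implementation: `divide_update`, `atan2_update`, and `SCALE` are hypothetical names, and the Adam-style form omits bias correction.

```python
import math

SCALE = 4 / math.pi  # maps atan2's range [-pi/2, pi/2] onto [-2, 2]


def divide_update(m, v, eps=0.0):
    """Adam-style preconditioning without bias correction: m / (sqrt(v) + eps)."""
    return m / (math.sqrt(v) + eps)


def atan2_update(m, v):
    """atan2-style preconditioning as in SINK-V: (4/pi) * atan2(m, sqrt(v))."""
    return SCALE * math.atan2(m, math.sqrt(v))


m = 0.5
for v in [1.0, 1e-2, 1e-8, 0.0]:
    try:
        divided = divide_update(m, v)
    except ZeroDivisionError:
        divided = float("inf")  # plain division has no answer at v = 0
    bounded = atan2_update(m, v)
    print(f"v={v:g}  m/sqrt(v)={divided:.4g}  (4/pi)*atan2={bounded:.4f}")
```

For small ratios, $\text{atan}(x) \approx x$, so the atan2 form behaves like roughly $1.27 \cdot \mathbf{m}/\sqrt{v}$; as the ratio grows it saturates at 2 instead of diverging.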

Algorithm — Full NtM + SINK-V Optimizer Step
import math
import torch

# 1. Normalization-then-Momentum (NtM); apply_sr_sinkhorn is defined elsewhere
g_norm = apply_sr_sinkhorn(grad, iters=5)
m_t.lerp_(g_norm, 1 - beta)

# 2. SINK-V: view the momentum as a 2D matrix and square it element-wise
m_2d = m_t.view(m_t.shape[0], -1)
m_2d_sq = m_2d.square()

# 3. SINK-V: implicit spatial variance per row and per column, 1 - mean(m^2).
#    The clamp guards against small negative values caused by finite Sinkhorn
#    iterations and floating-point rounding.
vt_row = (1.0 - m_2d_sq.mean(dim=-1)).clamp_min_(1e-30)
vt_col = (1.0 - m_2d_sq.mean(dim=-2)).clamp_min_(1e-30)

# 4. Variance preconditioning: atan2 replaces standard division (m / sqrt(v)).
#    denom >= 0, so atan2 lies in [-pi/2, pi/2]; multiplying by 4/pi gives [-2, 2].
#    Computed on the 2D view, then reshaped so it also works for >2D parameters.
denom = torch.sqrt(vt_row.view(-1, 1) * vt_col.view(1, -1))
update = torch.atan2(m_2d, denom).mul_(4 / math.pi).view_as(param)

# 5. Update weights
param.sub_(update * lr)
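For readers without the original helper, here is a self-contained, dependency-free sketch of one NtM + SINK-V step on a single 2D parameter. Assumptions: `sinkhorn_normalize` and `sink_v_step` are hypothetical names, the normalization is assumed to alternate row and column RMS rescaling (standing in for `apply_sr_sinkhorn`), and lists of lists stand in for tensors. The clamp floor and the row × column combination of variances follow the listing above.

```python
import math
import random


def sinkhorn_normalize(g, iters=5, eps=1e-30):
    """Hypothetical stand-in for apply_sr_sinkhorn: rows and columns to RMS 1."""
    n_rows, n_cols = len(g), len(g[0])
    x = [list(row) for row in g]
    for _ in range(iters):
        for i in range(n_rows):  # row step
            r = max(math.sqrt(sum(v * v for v in x[i]) / n_cols), eps)
            x[i] = [v / r for v in x[i]]
        for j in range(n_cols):  # column step
            c = max(math.sqrt(sum(x[i][j] ** 2 for i in range(n_rows)) / n_rows), eps)
            for i in range(n_rows):
                x[i][j] /= c
    return x


def sink_v_step(param, grad, m, lr=1e-3, beta=0.9, iters=5, floor=1e-30):
    """One NtM + SINK-V step on a 2D parameter. Updates m and param in place."""
    rows, cols = len(param), len(param[0])
    # 1. NtM: normalize the raw gradient, then accumulate it into momentum.
    g_hat = sinkhorn_normalize(grad, iters)
    for i in range(rows):
        for j in range(cols):
            m[i][j] = beta * m[i][j] + (1 - beta) * g_hat[i][j]
    # 2-3. Implicit row and column variances, 1 - mean(m^2), clamped at a floor.
    vt_row = [max(1.0 - sum(v * v for v in m[i]) / cols, floor) for i in range(rows)]
    vt_col = [max(1.0 - sum(m[i][j] ** 2 for i in range(rows)) / rows, floor) for j in range(cols)]
    # 4. atan2 preconditioning, scaled by 4/pi so every element lies in [-2, 2].
    update = [[(4 / math.pi) * math.atan2(m[i][j], math.sqrt(vt_row[i] * vt_col[j]))
               for j in range(cols)] for i in range(rows)]
    # 5. Descent step.
    for i in range(rows):
        for j in range(cols):
            param[i][j] -= lr * update[i][j]
    return update


rng = random.Random(0)
param = [[rng.gauss(0, 1) for _ in range(5)] for _ in range(3)]
m = [[0.0] * 5 for _ in range(3)]
for _ in range(20):
    grad = [[rng.gauss(0, 1) for _ in range(5)] for _ in range(3)]
    update = sink_v_step(param, grad, m)
print(max(abs(u) for row in update for u in row))  # never exceeds 2
```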

System Advantages