# ---------- imports ----------
import math

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, TensorDataset

# ---------- Newton–Schulz orthogonaliser ----------
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5) -> Tensor:
    """Approximately orthogonalise G with a quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.to(torch.bfloat16)
    if X.size(-2) > X.size(-1):                  # iterate on the wide orientation
        X = X.mT
    # scale so the spectral norm is at most ~1, which the iteration needs to converge
    X /= X.norm(dim=(-2, -1), keepdim=True).clamp(min=1e-7)
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    # undo the transpose for tall inputs and restore the original dtype
    return (X.mT if G.size(-2) > G.size(-1) else X).to(G.dtype)
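# ---------- (optional) sanity check for the orthogonaliser ----------
# Illustrative addition, not part of the original snippet: with these quintic
# coefficients the iteration only *approximately* orthogonalises, so the
# singular values of its output should cluster near (but not exactly at) 1.
_demo = zeropower_via_newtonschulz5(torch.randn(16, 32))
print("NS5 singular values:", torch.linalg.svdvals(_demo))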
# ---------- single-device Muon ----------
class SimpleMuon(torch.optim.Optimizer):
    """Single-device Muon: SGD-momentum whose matrix updates are orthogonalised."""

    def __init__(self, params, lr=0.02, momentum=0.95,
                 weight_decay=0.01, nesterov=True, ns_steps=5):
        super().__init__(params, dict(lr=lr, momentum=momentum,
                                      weight_decay=weight_decay,
                                      nesterov=nesterov, ns_steps=ns_steps))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for g in self.param_groups:
            lr, mom, wd, nest, k = (g[key] for key in
                                    ("lr", "momentum", "weight_decay",
                                     "nesterov", "ns_steps"))
            for p in g["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.add(p, alpha=wd) if wd else p.grad
                buf = self.state.setdefault(p, {}).setdefault(
                    "momentum_buffer", torch.zeros_like(p))
                buf.mul_(mom).add_(grad)
                d_p = grad.add(buf, alpha=mom) if nest else buf
                if p.ndim == 4:                      # conv filters: flatten to 2-D first
                    flat = d_p.view(p.size(0), -1)
                    d_p = zeropower_via_newtonschulz5(flat, k).view_as(p)
                elif p.ndim >= 2:                    # plain matrices: orthogonalise directly
                    d_p = zeropower_via_newtonschulz5(d_p, k)
                p.add_(d_p, alpha=-lr)
        return loss
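# ---------- (optional) parameter-splitting helper ----------
# Hedged sketch, not part of the original snippet: for larger models the usual
# pattern is to route matrix-shaped weights (ndim >= 2) to Muon and everything
# else (biases, gains, scalars) to AdamW. `build_optimizers` is a hypothetical
# helper name introduced here for illustration.
def build_optimizers(model, muon_lr=0.02, adamw_lr=3e-4):
    muon_params  = [p for p in model.parameters() if p.ndim >= 2]
    other_params = [p for p in model.parameters() if p.ndim < 2]
    opts = []
    if muon_params:
        opts.append(SimpleMuon(muon_params, lr=muon_lr))
    if other_params:
        opts.append(torch.optim.AdamW(other_params, lr=adamw_lr,
                                      betas=(0.9, 0.95), weight_decay=0.01))
    return opts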
# ---------- synthetic binary-classification data ----------
torch.manual_seed(42)
N, D = 5_000, 100
true_w = torch.randn(D)
X = torch.randn(N, D)
y = torch.bernoulli(torch.sigmoid(X @ true_w)).float()   # labels from a logistic model
loader = DataLoader(TensorDataset(X, y), batch_size=128, shuffle=True)
# ---------- logistic-regression model ----------
class LogReg(nn.Module):
    def __init__(self, dim):
        super().__init__()
        s = int(math.isqrt(dim))
        assert s * s == dim
        self.W = nn.Parameter(torch.randn(s, s) * 0.01)   # 2-D weight → Muon path
        self.b = nn.Parameter(torch.zeros(()))

    def forward(self, x):
        return torch.sigmoid(x @ self.W.flatten() + self.b)
# ---------- training helper (now takes *list* of optimisers) ----------
def train(model, opts, epochs=15):
    loss_fn = nn.BCELoss()                     # classic but stable
    for ep in range(1, epochs + 1):
        for xb, yb in loader:
            loss = loss_fn(model(xb).squeeze(), yb)
            for o in opts: o.zero_grad(set_to_none=True)
            loss.backward()
            for o in opts: o.step()
        if ep % 5 == 0:
            with torch.no_grad():
                acc = ((model(X).squeeze() > 0.5) == y).float().mean()
            print(f"epoch {ep:2d} | loss={loss.item():.4f} | acc={acc:.3f}")
# ---------- run both experiments ----------
init = LogReg(D).state_dict()
print("===> SimpleMuon + AdamW(scalar)")
mu_model = LogReg(D); mu_model.load_state_dict(init)
mu_opt = SimpleMuon([mu_model.W])
sc_opt = torch.optim.AdamW([mu_model.b], lr=3e-4, betas=(0.9, 0.95),
                           weight_decay=0.01)   # AdamW handles the scalar bias
train(mu_model, [mu_opt, sc_opt])
print("\n===> AdamW only")
ad_model = LogReg(D); ad_model.load_state_dict(init)
ad_opt = torch.optim.AdamW(ad_model.parameters(), lr=3e-4, betas=(0.9, 0.95),
                           weight_decay=0.01)
train(ad_model, [ad_opt])