from torch import Tensor, logsumexp


def norm(x: Tensor, dim: int) -> Tensor:
    """Normalize ``x`` in log space along ``dim``.

    Subtracts ``logsumexp(x, dim)`` (kept as a broadcastable singleton dim),
    so that ``x.exp()`` of the result sums to 1 along ``dim``.

    Args:
        x: Tensor of log-domain values.
        dim: Dimension to normalize over.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return x - logsumexp(x, dim=dim, keepdim=True)


def sinkhorn_step(x: Tensor) -> Tensor:
    """Run one log-space Sinkhorn iteration: normalize dim 1, then dim 2.

    Dims 1 and 2 are hard-coded, so ``x`` must have at least 3 dimensions.
    NOTE(review): presumably laid out as (batch, rows, cols) -- confirm with
    callers. Because dim 2 is normalized last, ``result.exp()`` sums exactly
    (up to float error) to 1 along dim 2; dim 1 sums only approximately.
    """
    return norm(norm(x, dim=1), dim=2)


def sinkhorn_fn_no_exp(x: Tensor, tau: float = 1, iters: int = 3) -> Tensor:
    """Apply ``iters`` log-space Sinkhorn iterations to ``x / tau``.

    The result stays in log space (no ``exp`` is applied). Exponentiating it
    gives a matrix normalized along dim 2 and approximately along dim 1.

    Args:
        x: Log-domain scores with at least 3 dimensions (see ``sinkhorn_step``).
        tau: Temperature; ``x`` is divided by it once before iterating.
            Smaller values sharpen the result. Must be non-zero.
        iters: Number of Sinkhorn iterations; ``0`` returns ``x / tau``.

    Returns:
        Tensor with the same shape as ``x``, in log space.
    """
    x = x / tau  # temperature scaling is applied once, not per iteration
    for _ in range(iters):
        x = sinkhorn_step(x)
    return x