from torch import logsumexp


def norm(x, dim):
    """Log-normalise ``x`` along ``dim``.

    Subtracts the log-sum-exp taken over ``dim``. The reduced axis is kept
    (size 1) so it broadcasts back against ``x``. Afterwards ``exp`` of the
    result sums to 1 along ``dim``. Mathematically this is a log-softmax.

    Args:
        x: tensor of (log-)scores.
        dim: axis to normalise over.

    Returns:
        Tensor with the same shape as ``x``.
    """
    log_partition = logsumexp(x, dim=dim, keepdim=True)
    return x - log_partition


def sinkhorn_step(x):
    """Perform one Sinkhorn normalisation pass in log space.

    Log-normalises ``x`` over axis 1 first and then over axis 2. The axis-2
    pass runs last, so ``exp`` of the result sums to 1 along axis 2 (up to
    float error). The axis-1 sums are only approximately 1 after a single
    pass.

    The hard-coded axes require ``x`` to have at least three dimensions.
    Presumably ``x`` is a batch of score matrices, (batch, N, M) -- TODO
    confirm with callers.
    """
    axis1_normed = norm(x, dim=1)
    return norm(axis1_normed, dim=2)


def sinkhorn_fn_no_exp(x, tau=1, iters=3):
    """Sinkhorn-normalise scores in log space, without a final ``exp``.

    The scores are first divided by the temperature ``tau``; smaller values
    sharpen the differences between entries. Then ``iters`` rounds of
    ``sinkhorn_step`` are applied. The result is left in log space. When the
    inner matrices are square, exponentiating it gives approximately
    doubly-stochastic matrices that improve as ``iters`` grows.

    Args:
        x: score tensor. ``sinkhorn_step`` normalises axes 1 and 2, so this
            presumably has shape (batch, N, N) -- TODO confirm with callers.
        tau: temperature (scalar, or anything that broadcasts against ``x``).
            ``tau == 0`` is not guarded and produces non-finite values.
        iters: number of normalisation rounds. ``0`` returns ``x / tau``
            without any normalisation.

    Returns:
        Log-space tensor, the same shape as ``x`` for a scalar ``tau``.
    """
    log_alpha = x / tau
    for _round in range(iters):
        log_alpha = sinkhorn_step(log_alpha)
    return log_alpha