# Web-page residue captured together with this snippet (not part of the program):
#   Skip to content
#   Snippets Groups Projects
#   test.py 2.12 KiB
from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
import torch


def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
    """Pad (or truncate) variable-length tensors to a common length and stack them.

    Args:
        sequences: non-empty list of tensors whose first dimension is the
            sequence length. The trailing dimensions of the output are taken
            from ``sequences[0]``.
        batch_first: if True the output is laid out as
            ``(len(sequences), max_len, *trailing)``; otherwise as
            ``(max_len, len(sequences), *trailing)``.
        padding_value: value stored at positions past the end of a sequence.
        max_len: size of the padded dimension. Sequences longer than this are
            truncated to their first ``max_len`` entries. ``None`` means
            "use the length of the longest sequence".

    Returns:
        A new tensor with the dtype and device of ``sequences[0]``.

    Raises:
        IndexError: if ``sequences`` is empty (there is no first tensor to
            take dtype, device and trailing dimensions from).
    """
    first = sequences[0]
    trailing_dims = tuple(first.size()[1:])

    if max_len is None:
        max_len = max(seq.size(0) for seq in sequences)

    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    # new_full allocates on the same device / with the same dtype as `first`.
    out_tensor = first.new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        # Clamp so an over-long sequence is truncated instead of triggering a
        # shape-mismatch error on the slice assignment below.
        length = min(tensor.size(0), max_len)
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor[:length]
        else:
            out_tensor[:length, i, ...] = tensor[:length]

    return out_tensor


# Toy batch of two sentences: the atom category of every atom position.
atoms_batch = [["np", "v", "np", "v", "np", "v", "np", "v"],
               ["np", "np", "v", "v"]]

# One boolean per atom, aligned with atoms_batch (True = positive polarity).
atoms_polarity = [[False, True, True, False, False, True, True, False],
                  [True, False, True, False]]

# Random stand-in encodings: 2 sentences x 8 atom slots x 24 features.
atoms_encoding = torch.randn((2, 8, 24))

matches = []
for atom_type in ["np", "v"]:
    # For each sentence, positions of the atoms of this type, split by polarity.
    pos_idx_per_atom_type = [
        [i for i, (atom, is_positive) in enumerate(zip(atoms, polarities))
         if is_positive and atom == atom_type]
        for atoms, polarities in zip(atoms_batch, atoms_polarity)
    ]
    neg_idx_per_atom_type = [
        [i for i, (atom, is_positive) in enumerate(zip(atoms, polarities))
         if not is_positive and atom == atom_type]
        for atoms, polarities in zip(atoms_batch, atoms_polarity)
    ]

    # TODO: select with the list of index lists in one call instead of one
    # index_select per sentence.
    # dtype=torch.long keeps index_select valid even when a sentence has no
    # atom of this type/polarity (an empty list would otherwise become a
    # float tensor, which index_select rejects).
    # max_len=3: every sentence above has at most 2 matching atoms, so each
    # padded block keeps at least one all-zero row.
    pos_encoding = pad_sequence(
        [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence, dtype=torch.long))
         for i, sentence in enumerate(pos_idx_per_atom_type)],
        max_len=3, padding_value=0)
    neg_encoding = pad_sequence(
        [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence, dtype=torch.long))
         for i, sentence in enumerate(neg_idx_per_atom_type)],
        max_len=3, padding_value=0)

    print(neg_encoding.shape)

    # Pairwise dot-product scores between positive and negative atoms,
    # one (3 x 3) score matrix per sentence.
    # NOTE(review): zero-padded rows/columns are passed to sinkhorn unmasked;
    # confirm this is intended.
    weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
    print(weights.shape)
    print("sinkhorn")
    # Run sinkhorn once and reuse the result for both the print and the append.
    linking = sinkhorn(weights, iters=3)
    print(linking.shape)
    matches.append(linking)

print(matches)