# test.py
# Authored by Caroline DE POURTALES
from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
import torch
def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
    """Stack variable-length tensors into one tensor padded to ``max_len``.

    The padded length is always the fixed ``max_len``, not the length of the
    longest sequence in the batch.

    Args:
        sequences: Non-empty, indexable collection of tensors. The first
            dimension of each tensor is its length; the trailing dimensions
            of the output are taken from ``sequences[0]``.
        batch_first: If true the output is laid out as
            ``(batch, max_len, *trailing)``, otherwise as
            ``(max_len, batch, *trailing)``.
        padding_value: Value written to every position not covered by a
            sequence.
        max_len: Size of the padded length dimension.

    Returns:
        A new tensor, created from ``sequences[0]`` (same dtype and device),
        holding every sequence followed by ``padding_value``.

    Raises:
        ValueError: If ``sequences`` is empty, or if any sequence is longer
            than ``max_len``.
    """
    # ``len(...) == 0`` rather than ``not sequences`` so that a stacked tensor
    # passed as ``sequences`` is not evaluated for truthiness.
    if len(sequences) == 0:
        raise ValueError("pad_sequence expects at least one sequence")

    trailing_dims = tuple(sequences[0].size()[1:])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    out_tensor = sequences[0].new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        # Fail with an explicit message instead of the shape-mismatch error
        # the slice assignment below would otherwise raise.
        if length > max_len:
            raise ValueError(
                f"sequence {i} has length {length}, "
                f"which exceeds max_len={max_len}"
            )
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor
    return out_tensor
# Toy batch of two sentences: the atom category at each position and, in
# parallel, the polarity of that atom (True = positive, False = negative).
atoms_batch = [["np", "v", "np", "v", "np", "v", "np", "v"],
               ["np", "np", "v", "v"]]
atoms_polarity = [[False, True, True, False, False, True, True, False],
                  [True, False, True, False]]
# Random stand-in encodings, one 24-dim vector per (sentence, position).
atoms_encoding = torch.randn((2, 8, 24))

matches = []
for atom_type in ["np", "v"]:
    # Per sentence, the positions of the positive / negative atoms of this type.
    pos_idx_per_atom_type = [
        [i for i, (atom, positive) in enumerate(zip(atoms, polarities))
         if positive and atom == atom_type]
        for atoms, polarities in zip(atoms_batch, atoms_polarity)
    ]
    neg_idx_per_atom_type = [
        [i for i, (atom, positive) in enumerate(zip(atoms, polarities))
         if not positive and atom == atom_type]
        for atoms, polarities in zip(atoms_batch, atoms_polarity)
    ]

    # TODO: select with a list of lists instead of one index_select per sentence.
    # The index dtype is forced to long: for a sentence with no matching atom
    # the index list is empty, and torch.as_tensor([]) would otherwise produce
    # a float tensor, which index_select does not accept as an index.
    pos_encoding = pad_sequence(
        [atoms_encoding.select(0, index=s_idx).index_select(
            0, index=torch.as_tensor(indices, dtype=torch.long))
         for s_idx, indices in enumerate(pos_idx_per_atom_type)],
        max_len=3, padding_value=0)
    neg_encoding = pad_sequence(
        [atoms_encoding.select(0, index=s_idx).index_select(
            0, index=torch.as_tensor(indices, dtype=torch.long))
         for s_idx, indices in enumerate(neg_idx_per_atom_type)],
        max_len=3, padding_value=0)
    print(neg_encoding.shape)

    # Batched dot products between every positive and every negative atom
    # encoding of the same sentence.
    weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
    print(weights.shape)

    print("sinkhorn")
    # Evaluate sinkhorn once and reuse the result for both the shape print and
    # the stored match (assumes sinkhorn is deterministic; its implementation
    # lives in SuperTagger.Linker.Sinkhorn and is not visible here).
    linking = sinkhorn(weights, iters=3)
    print(linking.shape)
    matches.append(linking)

print(matches)