Commit 85e492e2 authored by Caroline DE POURTALES

starting train

parent 0eb359e3
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
Showing with 99 additions and 416 deletions
import random
import torch
import torch.nn.functional as F
from torch.nn import (Module, Dropout, Linear, LSTM)
from Configuration import Configuration
from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding
class RNNDecoderLayer(Module):
def __init__(self, symbols_map):
super(RNNDecoderLayer, self).__init__()
# init params
self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
dropout = float(Configuration.modelDecoderConfig['dropout'])
self.num_rnn_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
self.teacher_forcing = float(Configuration.modelDecoderConfig['teacher_forcing'])
self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
self.symbols_vocab_size = int(Configuration.datasetConfig['symbols_vocab_size'])
self.bidirectional = False
self.use_attention = True
self.symbols_map = symbols_map
self.symbols_padding_id = self.symbols_map["[PAD]"]
self.symbols_sep_id = self.symbols_map["[SEP]"]
self.symbols_start_id = self.symbols_map["[START]"]
self.symbols_sos_id = self.symbols_map["[SOS]"]
# Different layers
# Symbols Embedding
self.symbols_embedder = SymbolEmbedding(self.dim_decoder, self.symbols_vocab_size,
padding_idx=self.symbols_padding_id)
# For hidden_state
self.dropout = Dropout(dropout)
# rnn Layer
if self.use_attention:
self.rnn = LSTM(input_size=self.dim_encoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers,
dropout=dropout,
bidirectional=self.bidirectional, batch_first=True)
else:
self.rnn = LSTM(input_size=self.dim_decoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers,
dropout=dropout,
bidirectional=self.bidirectional, batch_first=True)
# Projection on vocab_size
if self.bidirectional:
self.proj = Linear(self.dim_encoder * 2, self.symbols_vocab_size)
else:
self.proj = Linear(self.dim_encoder, self.symbols_vocab_size)
self.attn = Linear(self.dim_decoder + self.dim_encoder, self.max_len_sentence)
self.attn_combine = Linear(self.dim_decoder + self.dim_encoder, self.dim_encoder)
def sos_mask(self, y):
return torch.eq(y, self.symbols_sos_id)
def forward(self, symbols_tokenized_batch, last_hidden_state, pooler_output):
r"""Training the translation from encoded sentences to symbols
Args:
symbols_tokenized_batch: [batch_size, max_len_sentence] the true symbols for each sentence.
last_hidden_state: [batch_size, max_len_sentence, dim_encoder] Sequence of hidden-states at the output of the last layer of the model.
pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task
"""
batch_size, sequence_length, hidden_size = last_hidden_state.shape
# y_hat: [batch_size, max_len_sentence, vocab_size], initialised with probability 1 on [PAD]
y_hat = torch.zeros(batch_size, self.max_len_sentence, self.symbols_vocab_size,
dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
y_hat[:, :, self.symbols_padding_id] = 1
decoded_i = torch.ones(batch_size, 1, dtype=torch.long,
device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id
sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu")
# initial hidden state: the pooler output repeated once per RNN layer (and direction)
hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1)
c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size,
dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
use_teacher_forcing = True if random.random() < self.teacher_forcing else False
# for each symbol
for i in range(self.max_len_sentence):
# embed the previous symbol (the gold symbol under teacher forcing, otherwise the model's own prediction)
symbols_embedding = self.symbols_embedder(decoded_i)
symbols_embedding = self.dropout(symbols_embedding)
output = symbols_embedding
if self.use_attention:
attn_weights = F.softmax(
self.attn(torch.cat((symbols_embedding, hidden_state[0].unsqueeze(1)), 2)), dim=2)
attn_applied = torch.bmm(attn_weights, last_hidden_state)
output = torch.cat((symbols_embedding, attn_applied), 2)
output = self.attn_combine(output)
output = F.relu(output)
# rnn layer
output, (hidden_state, c_state) = self.rnn(output, (hidden_state, c_state))
# Projection of the RNN output, dropping the last two classes ([START] and [PAD]) so they are never predicted
proj = self.proj(output)[:, :, :-2]
if use_teacher_forcing:
decoded_i = symbols_tokenized_batch[:, i].unsqueeze(1)
else:
decoded_i = torch.argmax(F.softmax(proj, dim=2), dim=2)
# Detect [SOS]; for unfinished sentences, overwrite the default [PAD] prediction with the projection
sos_mask_i = self.sos_mask(torch.argmax(F.softmax(proj, dim=2), dim=2)[:, -1])
y_hat[~sos_mask, i, self.symbols_padding_id] = 0
y_hat[~sos_mask, i, :-2] = proj[~sos_mask, -1, :]
sos_mask = sos_mask_i | sos_mask
# Stop early once every sentence has produced [SOS], otherwise continue up to the maximum length
if not torch.any(~sos_mask):
break
return y_hat
def predict_rnn(self, last_hidden_state, pooler_output):
r"""Predicts the symbols from the output of the encoder.
Args:
last_hidden_state: [batch_size, max_len_sentence, dim_encoder] the output of the encoder
pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task
"""
batch_size, sequence_length, hidden_size = last_hidden_state.shape
# contains the predictions
y_hat = torch.zeros(batch_size, self.max_len_sentence, self.symbols_vocab_size,
dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
y_hat[:, :, self.symbols_padding_id] = 1
# input of the embedder: a generated [START] vector, since no gold symbols are available at prediction time
decoded_i = torch.ones(batch_size, 1, dtype=torch.long,
device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id
sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu")
hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1)
c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size,
dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
for i in range(self.max_len_sentence):
symbols_embedding = self.symbols_embedder(decoded_i)
symbols_embedding = self.dropout(symbols_embedding)
output = symbols_embedding
if self.use_attention:
attn_weights = F.softmax(
self.attn(torch.cat((symbols_embedding, hidden_state[0].unsqueeze(1)), 2)), dim=2)
attn_applied = torch.bmm(attn_weights, last_hidden_state)
output = torch.cat((symbols_embedding, attn_applied), 2)
output = self.attn_combine(output)
output = F.relu(output)
output, (hidden_state, c_state) = self.rnn(output, (hidden_state, c_state))
proj_softmax = F.softmax(self.proj(output)[:, :, :-2], dim=2)
decoded_i = torch.argmax(proj_softmax, dim=2)
# Detect [SOS]; for unfinished sentences, overwrite the default [PAD] prediction with the softmaxed projection
sos_mask_i = self.sos_mask(decoded_i[:, -1])
y_hat[~sos_mask, i, self.symbols_padding_id] = 0
y_hat[~sos_mask, i, :-2] = proj_softmax[~sos_mask, -1, :]
sos_mask = sos_mask_i | sos_mask
# Stop early once every sentence has produced [SOS], otherwise continue up to the maximum length
if not torch.any(~sos_mask):
break
return y_hat
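To make the attention step above easier to follow, here is a self-contained sketch of the same computation with toy dimensions; the names mirror the layer's attributes, but all sizes are invented for illustration (the real ones come from Configuration).

import torch
import torch.nn.functional as F
from torch.nn import Linear

# Toy dimensions, for illustration only.
batch_size, max_len_sentence, dim_encoder, dim_decoder = 2, 10, 16, 8

attn = Linear(dim_decoder + dim_encoder, max_len_sentence)
attn_combine = Linear(dim_decoder + dim_encoder, dim_encoder)

symbols_embedding = torch.randn(batch_size, 1, dim_decoder)                  # current decoder input
hidden_state = torch.randn(1, batch_size, dim_encoder)                       # LSTM hidden state
last_hidden_state = torch.randn(batch_size, max_len_sentence, dim_encoder)   # encoder output

# One attention distribution over the encoder positions per batch element.
attn_weights = F.softmax(
    attn(torch.cat((symbols_embedding, hidden_state[0].unsqueeze(1)), 2)), dim=2)
# [batch_size, 1, max_len_sentence] x [batch_size, max_len_sentence, dim_encoder]
attn_applied = torch.bmm(attn_weights, last_hidden_state)                    # [batch_size, 1, dim_encoder]
output = F.relu(attn_combine(torch.cat((symbols_embedding, attn_applied), 2)))
print(output.shape)  # torch.Size([2, 1, 16])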
File deleted
import torch
class EncoderInput():
def __init__(self, tokenizer):
"""@params tokenizer (PretrainedTokenizer): Tokenizer that tokenizes text """
self.tokenizer = tokenizer
def fit_transform(self, sents):
return self.tokenizer(sents, padding=True,)
def fit_transform_tensors(self, sents):
temp = self.tokenizer(sents, padding=True, return_tensors='pt', )
return temp['input_ids'], temp['attention_mask']
def convert_ids_to_tokens(self, inputs_ids, skip_special_tokens=False):
return self.tokenizer.batch_decode(inputs_ids, skip_special_tokens=skip_special_tokens)
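A minimal usage sketch for EncoderInput, assuming the HuggingFace transformers package is installed; the camembert-base checkpoint is only an illustrative choice, not something this commit prescribes.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("camembert-base")  # illustrative checkpoint
encoder_input = EncoderInput(tokenizer)

sentences = ["le chat dort", "une phrase un peu plus longue"]
input_ids, attention_mask = encoder_input.fit_transform_tensors(sentences)
print(input_ids.shape, attention_mask.shape)  # both [batch_size, padded_len]
print(encoder_input.convert_ids_to_tokens(input_ids, skip_special_tokens=True))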
import sys
import torch
from torch import nn
from Configuration import Configuration
class EncoderLayer(nn.Module):
"""Encoder class, imput of supertagger"""
def __init__(self, model):
super(EncoderLayer, self).__init__()
self.name = "Encoder"
self.bert = model
self.hidden_size = self.bert.config.hidden_size
def forward(self, batch):
r"""
Args :
batch: [input_ids, attention_mask] for the tokenized sentences (the raw, untokenized sentences are converted beforehand by EncoderInput)
Returns :
last_hidden_state: [batch_size, max_len_sentence, dim_encoder] Sequence of hidden-states at the output of the last layer of the model.
pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task
"""
b_input_ids = batch[0]
b_input_mask = batch[1]
outputs = self.bert(
input_ids=b_input_ids, attention_mask=b_input_mask)
return outputs[0], outputs[1]
@staticmethod
def load(model_path: str):
r""" Load the model from a file.
Args :
model_path (str): path to model
Returns :
model (nn.Module): model with saved parameters
"""
params = torch.load(
model_path, map_location=lambda storage, loc: storage)
args = params['args']
model = EncoderLayer(**args)
model.load_state_dict(params['state_dict'])
return model
def save(self, path: str):
r""" Save the model to a file.
Args :
path (str): path to the model
"""
print('save model parameters to [%s]' % path, file=sys.stderr)
params = {
    # __init__ only takes the wrapped model, so store it directly so that load() can rebuild the layer
    'args': dict(model=self.bert),
    'state_dict': self.state_dict()
}
torch.save(params, path)
def to_dict(self):
return {}
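Putting the two pieces together, a hedged end-to-end sketch (the checkpoint is again an arbitrary example) shows the shapes EncoderLayer hands to the decoder and the linker.

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("camembert-base")  # illustrative checkpoint
encoder = EncoderLayer(AutoModel.from_pretrained("camembert-base"))

batch = tokenizer(["le chat dort", "une autre phrase"], padding=True, return_tensors="pt")
with torch.no_grad():
    last_hidden_state, pooler_output = encoder((batch["input_ids"], batch["attention_mask"]))
print(last_hidden_state.shape)  # [batch_size, max_len_sentence, dim_encoder]
print(pooler_output.shape)      # [batch_size, dim_encoder]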
File deleted
File deleted
import torch
+from SuperTagger.utils import pad_sequence

class AtomTokenizer(object):
    def __init__(self, atom_map, max_atoms_in_sentence):
@@ -28,24 +30,3 @@ class AtomTokenizer(object):
    def convert_ids_to_atoms(self, ids):
        return [self.inverse_atom_map[int(i)] for i in ids]
-
-def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
-    max_size = sequences[0].size()
-    trailing_dims = max_size[1:]
-    if batch_first:
-        out_dims = (len(sequences), max_len) + trailing_dims
-    else:
-        out_dims = (max_len, len(sequences)) + trailing_dims
-    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
-    for i, tensor in enumerate(sequences):
-        length = tensor.size(0)
-        # use index notation to prevent duplicate references to the tensor
-        if batch_first:
-            out_tensor[i, :length, ...] = tensor
-        else:
-            out_tensor[:length, i, ...] = tensor
-    return out_tensor
@@ -2,6 +2,7 @@ from itertools import chain
import torch
from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
+from torch.nn import Module

from Configuration import Configuration
from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
@@ -10,11 +11,12 @@ from SuperTagger.Linker.atom_map import atom_map
from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
from SuperTagger.Linker.AttentionLayer import FFN, AttentionLayer
+from SuperTagger.utils import pad_sequence

-class Linker:
+class Linker(Module):
    def __init__(self):
+        super(Linker, self).__init__()
        self.__init__()
        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
@@ -71,20 +73,25 @@ class Linker:
        atoms_polarity = find_pos_neg_idexes(category_batch)

        link_weights = []
-        for sentence_idx in range(len(atoms_polarity)):
-            for atom_type in self.atom_map.keys():
-                pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
-                                         x and atoms_batch[sentence_idx][i] == atom_type]
-                neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
-                                         not x and atoms_batch[sentence_idx][i] == atom_type]
-                pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
-                neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
-
-                pos_encoding = self.pos_transformation(pos_encoding)
-                neg_encoding = self.neg_transformation(neg_encoding)
-
-                weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
-                link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
+        for atom_type in self.atom_map.keys():
+            pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
+                                      x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
+            neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
+                                      not x and atoms_batch[s_idx][i] == atom_type] for s_idx in
+                                     range(len(atoms_polarity))]
+            # to do select with list of list
+            pos_encoding = pad_sequence(
+                [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+                 for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0)
+            neg_encoding = pad_sequence(
+                [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+                 for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0)
+
+            # pos_encoding = self.pos_transformation(pos_encoding)
+            # neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=3))

        return link_weights
from torch import logsumexp
...

No preview for this file type

@@ -92,4 +92,3 @@ def find_pos_neg_idexes(batch_symbols):
            list_symbols.append(cut_category_in_symbols(category))
        list_batch.append(list_symbols)
    return list_batch
import torch
from torch.nn import Module, Embedding
class SymbolEmbedding(Module):
def __init__(self, dim_decoder, atom_vocab_size, padding_idx):
super(SymbolEmbedding, self).__init__()
self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_decoder, padding_idx=padding_idx,
scale_grad_by_freq=True)
def forward(self, x):
return self.emb(x)
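A quick shape check for SymbolEmbedding, using the vocabulary size and padding index implied by the symbol map below; the decoder dimension here is a made-up value.

import torch

embedder = SymbolEmbedding(dim_decoder=8, atom_vocab_size=26, padding_idx=25)  # 25 = [PAD]
ids = torch.tensor([[24, 2, 1, 25], [24, 20, 25, 25]])  # [START], n, \, [PAD] / [START], np, [PAD], [PAD]
print(embedder(ids).shape)  # torch.Size([2, 4, 8])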
import torch
class SymbolTokenizer(object):
def __init__(self, symbol_map, max_symbols_in_sentence, max_len_sentence):
self.symbol_map = symbol_map
self.max_symbols_in_sentence = max_symbols_in_sentence
self.max_len_sentence = max_len_sentence
self.inverse_symbol_map = {v: k for k, v in self.symbol_map.items()}
self.sep_token = '[SEP]'
self.pad_token = '[PAD]'
self.sos_token = '[SOS]'
self.sep_token_id = self.symbol_map[self.sep_token]
self.pad_token_id = self.symbol_map[self.pad_token]
self.sos_token_id = self.symbol_map[self.sos_token]
def __len__(self):
return len(self.symbol_map)
def convert_symbols_to_ids(self, symbol):
return self.symbol_map[str(symbol)]
def convert_sents_to_ids(self, sentences):
return torch.as_tensor([self.convert_symbols_to_ids(symbol) for symbol in sentences])
def convert_batchs_to_ids(self, batchs_sentences):
return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences],
max_len=self.max_symbols_in_sentence, padding_value=self.pad_token_id))
def convert_ids_to_symbols(self, ids):
return [self.inverse_symbol_map[int(i)] for i in ids]
def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
max_size = sequences[0].size()
trailing_dims = max_size[1:]
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims
out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
else:
out_tensor[:length, i, ...] = tensor
return out_tensor
File deleted
File deleted
File deleted
symbol_map = \
{'cl_r': 0,
'\\': 1,
'n': 2,
'p': 3,
's_ppres': 4,
'dia': 5,
's_whq': 6,
'let': 7,
'/': 8,
's_inf': 9,
's_pass': 10,
'pp_a': 11,
'pp_par': 12,
'pp_de': 13,
'cl_y': 14,
'box': 15,
'txt': 16,
's': 17,
's_ppart': 18,
's_q': 19,
'np': 20,
'pp': 21,
'[SEP]': 22,
'[SOS]': 23,
'[START]': 24,
'[PAD]': 25
}
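Assuming SymbolTokenizer and this symbol_map are importable together, a short usage sketch (the maximum lengths are arbitrary here):

tokenizer = SymbolTokenizer(symbol_map, max_symbols_in_sentence=10, max_len_sentence=10)

batch = [['np', '\\', 's'], ['np', '/', 'n', '[SEP]']]
ids = tokenizer.convert_batchs_to_ids(batch)
print(ids.shape)                                 # torch.Size([2, 10])
print(tokenizer.convert_ids_to_symbols(ids[0]))  # ['np', '\\', 's', '[PAD]', '[PAD]', ...]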
import torch
from torch import Tensor
from torch.nn import Module
-from torch.nn.functional import cross_entropy
+from torch.nn.functional import nll_loss, cross_entropy

# Another from Kokos function to calculate the accuracy of our predictions vs labels
def measure_supertagging_accuracy(pred, truth, ignore_idx=0):
@@ -42,3 +41,12 @@ class NormCrossEntropy(Module):
    def forward(self, predictions, truths):
        return cross_entropy(predictions.flatten(0, -2), truths.flatten(), weight=self.weights,
                             reduction='sum', ignore_index=self.ignore_index) / count_sep(truths.flatten(), self.sep_id)
+
+class SinkhornLoss(Module):
+    def __init__(self):
+        super(SinkhornLoss, self).__init__()
+
+    def forward(self, predictions, truths):
+        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
+                   for link, perm in zip(predictions, truths))
\ No newline at end of file
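SinkhornLoss expects, for each atom type, a batch of per-row log-probabilities over the candidate negative atoms and the gold permutation indices; a toy check with made-up shapes (the real inputs come from the Linker and the dataset):

import torch
import torch.nn.functional as F

link = F.log_softmax(torch.randn(2, 3, 3), dim=2)  # [batch, n_pos, n_neg] log-probabilities
perm = torch.tensor([[0, 2, 1], [1, 0, 2]])        # gold index of the matched negative atom

criterion = SinkhornLoss()
print(criterion([link], [perm]))  # scalar loss, one term per atom type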
@@ -5,6 +5,26 @@ import torch
from tqdm import tqdm

+def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    max_size = sequences[0].size()
+    trailing_dims = max_size[1:]
+    if batch_first:
+        out_dims = (len(sequences), max_len) + trailing_dims
+    else:
+        out_dims = (max_len, len(sequences)) + trailing_dims
+    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    for i, tensor in enumerate(sequences):
+        length = tensor.size(0)
+        # use index notation to prevent duplicate references to the tensor
+        if batch_first:
+            out_tensor[i, :length, ...] = tensor
+        else:
+            out_tensor[:length, i, ...] = tensor
+    return out_tensor
+
def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
    print("\n" + "#" * 20)
    print("Loading csv...")
...
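Unlike torch.nn.utils.rnn.pad_sequence, which pads to the longest sequence in the batch, this helper pads every batch to a fixed max_len, which is what the Linker needs to build rectangular per-atom-type tensors; a small check:

import torch

a = torch.randn(3, 24)  # 3 atoms with 24-dimensional encodings
b = torch.randn(1, 24)  # 1 atom
out = pad_sequence([a, b], batch_first=True, padding_value=0, max_len=5)
print(out.shape)               # torch.Size([2, 5, 24])
print(out[1, 1:].abs().sum())  # tensor(0.) -- everything past the real length is padding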
from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
import torch

-atoms_batch = [["np", "v", "np", "v","np", "v", "np", "v"],
-               ["np", "np", "v", "v","np", "np", "v", "v"]]
-
-atoms_polarity = [[False, True, True, False,False, True, True, False],
-                  [True, False, True, False,True, False, True, False]]
-
-atoms_encoding = torch.randn((2, 8, 24))
-
-matches = []
-for sentence_idx in range(len(atoms_polarity)):
-    for atom_type in ["np", "v"]:
-        pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
-                                 x and atoms_batch[sentence_idx][i] == atom_type]
-        neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
-                                 not x and atoms_batch[sentence_idx][i] == atom_type]
-        pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
-        neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
-        weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
-        matches.append(sinkhorn(weights, iters=3))
+def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    max_size = sequences[0].size()
+    trailing_dims = max_size[1:]
+    if batch_first:
+        out_dims = (len(sequences), max_len) + trailing_dims
+    else:
+        out_dims = (max_len, len(sequences)) + trailing_dims
+    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    for i, tensor in enumerate(sequences):
+        length = tensor.size(0)
+        # use index notation to prevent duplicate references to the tensor
+        if batch_first:
+            out_tensor[i, :length, ...] = tensor
+        else:
+            out_tensor[:length, i, ...] = tensor
+    return out_tensor
+
+atoms_batch = [["np", "v", "np", "v", "np", "v", "np", "v"],
+               ["np", "np", "v", "v"]]
+
+atoms_polarity = [[False, True, True, False, False, True, True, False],
+                  [True, False, True, False]]
+
+atoms_encoding = torch.randn((2, 8, 24))
+
+matches = []
+for atom_type in ["np", "v"]:
+    pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
+                              x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
+    neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
+                              not x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
+    # to do select with list of list
+    pos_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+                                 for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3, padding_value=0)
+    neg_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+                                 for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3, padding_value=0)
+    print(neg_encoding.shape)
+    weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+    print(weights.shape)
+    print("sinkhorn")
+    print(sinkhorn(weights, iters=3).shape)
+    matches.append(sinkhorn(weights, iters=3))

print(matches)
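For context, sinkhorn_fn_no_exp is expected to perform Sinkhorn normalisation in the log domain (alternating row and column logsumexp subtractions), turning the raw bmm scores into an approximately doubly-stochastic matching matrix. The following is a generic sketch of that idea, not necessarily the exact implementation shipped in SuperTagger/Linker/Sinkhorn.py:

import torch
from torch import logsumexp

def sinkhorn_log_domain(log_alpha, iters=3):
    # Alternately normalise rows and columns of the score matrix in log space.
    for _ in range(iters):
        log_alpha = log_alpha - logsumexp(log_alpha, dim=2, keepdim=True)  # rows
        log_alpha = log_alpha - logsumexp(log_alpha, dim=1, keepdim=True)  # columns
    return log_alpha

weights = torch.randn(2, 3, 3)    # [batch, n_pos, n_neg] matching scores
log_perm = sinkhorn_log_domain(weights)
print(log_perm.exp().sum(dim=2))  # each row sums to roughly 1 after a few iterations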