diff --git a/SuperTagger/Decoder/RNNDecoderLayer.py b/SuperTagger/Decoder/RNNDecoderLayer.py
index 3f4f9d8a44d1006ce1030b2b833afc441b08cf6b..9c6c12b483752bf4797a9269b2c403521984f63d 100644
--- a/SuperTagger/Decoder/RNNDecoderLayer.py
+++ b/SuperTagger/Decoder/RNNDecoderLayer.py
@@ -2,18 +2,11 @@ import random
 
 import torch
 import torch.nn.functional as F
-from torch.nn import (Dropout, Module, Module, Sequential, LayerNorm, Dropout, GELU, Linear, LSTM, GRU)
+from torch.nn import (Module, Dropout, Linear, LSTM)
 
 from Configuration import Configuration
 from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding
 
 
-def FFN(d_model, d_ff, dropout_rate = 0.1, d_out = None) -> Module:
-    return Sequential(
-        Linear(d_model, d_ff, bias=False),
-        GELU(),
-        Dropout(dropout_rate),
-        Linear(d_ff, d_model if d_out is None else d_out, bias=False)
-    )
 
 class RNNDecoderLayer(Module):
     def __init__(self, symbols_map):
@@ -45,12 +38,12 @@ class RNNDecoderLayer(Module):
         # rnn Layer
         if self.use_attention:
             self.rnn = LSTM(input_size=self.dim_encoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers,
-                            dropout=dropout,
-                            bidirectional=self.bidirectional, batch_first=True)
-        else :
+                            dropout=dropout,
+                            bidirectional=self.bidirectional, batch_first=True)
+        else:
             self.rnn = LSTM(input_size=self.dim_decoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers,
-                            dropout=dropout,
-                            bidirectional=self.bidirectional, batch_first=True)
+                            dropout=dropout,
+                            bidirectional=self.bidirectional, batch_first=True)
 
         # Projection on vocab_size
         if self.bidirectional:
@@ -61,13 +54,6 @@ class RNNDecoderLayer(Module):
             self.attn = Linear(self.dim_decoder + self.dim_encoder, self.max_len_sentence)
             self.attn_combine = Linear(self.dim_decoder + self.dim_encoder, self.dim_encoder)
 
-        # linking and pos neg weights
-        self.linker = 
-        self.positive_transfo = Sequential(
-            FFN(self.dec_dim * 2, self.dec_dim, 0.1, self.dec_dim//2), LayerNorm(self.dec_dim//2, eps=1e-12))
-        self.negative_transfo = Sequential(
-            FFN(self.dec_dim * 2, self.dec_dim, 0.1, self.dec_dim // 2), LayerNorm(self.dec_dim//2, eps=1e-12))
-
     def sos_mask(self, y):
         return torch.eq(y, self.symbols_sos_id)
 
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
new file mode 100644
index 0000000000000000000000000000000000000000..65682306c2269b022e3ef23f1bd83da9aad19bf1
--- /dev/null
+++ b/SuperTagger/Linker/Linker.py
@@ -0,0 +1,79 @@
+from itertools import chain
+
+import torch
+from torch.nn import Module, Sequential, LayerNorm, Linear, Dropout, GELU
+
+from Configuration import Configuration
+
+from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from SuperTagger.Linker.utils import find_pos_neg_idexes, make_sinkhorn_inputs
+
+
+def FFN(d_model, d_ff, dropout_rate=0.1, d_out=None):
+    return Sequential(
+        Linear(d_model, d_ff, bias=False),
+        GELU(),
+        Dropout(dropout_rate),
+        Linear(d_ff, d_model if d_out is None else d_out, bias=False)
+    )
+
+
+class Linker(Module):
+    def __init__(self):
+        super().__init__()
+
+        self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
+        self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
+
+        self.dropout = Dropout(0.1)
+
+        self.pos_transformation = Sequential(
+            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
+            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+        )
+        self.neg_transformation = Sequential(
+            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
+            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+        )
+
+    def forward(self, symbols_batch, symbols_decoding):
+        '''
+        Parameters :
+        symbols_batch : (batch_size, sequence_length) the batch of symbols (categories) to decompose into atoms
+        symbols_decoding : (batch_size, sequence_length, dim) output of the decoder, indexed to build the sinkhorn inputs
+        '''
+
+        # decompose each sentence's categories into atom polarities
+        # (True = positive occurrence, False = negative occurrence)
+        decompose_decoding = find_pos_neg_idexes(symbols_batch)
+
+        # positions of the positive / negative atoms in each sentence
+        pos_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if x], decompose_decoding))
+        neg_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if not x], decompose_decoding))
+
+        _positives = make_sinkhorn_inputs(symbols_decoding, pos_idxes_batch)
+        _negatives = make_sinkhorn_inputs(symbols_decoding, neg_idxes_batch)
+
+        positives = [tensor for tensor in chain.from_iterable(_positives) if min(tensor.size()) != 0]
+        negatives = [tensor for tensor in chain.from_iterable(_negatives) if min(tensor.size()) != 0]
+
+        distinct_shapes = {tensor.size()[0] for tensor in positives}
+        distinct_shapes = sorted(distinct_shapes)
+
+        # match positive and negative atoms that occur in equal numbers
+        matches = []
+
+        all_shape_positives = [self.pos_transformation(self.dropout(torch.stack([tensor for tensor in positives
+                                                                                 if tensor.size()[0] == shape])))
+                               for shape in distinct_shapes]
+
+        all_shape_negatives = [self.neg_transformation(self.dropout(torch.stack([tensor for tensor in negatives
+                                                                                 if tensor.size()[0] == shape])))
+                               for shape in distinct_shapes]
+
+        for this_shape_positives, this_shape_negatives in zip(all_shape_positives, all_shape_negatives):
+            weights = torch.bmm(this_shape_positives,
+                                this_shape_negatives.transpose(2, 1))
+            matches.append(sinkhorn(weights, iters=3))
+
+        return matches
diff --git a/SuperTagger/Linker/Sinkhorn.py b/SuperTagger/Linker/Sinkhorn.py
new file mode 100644
index 0000000000000000000000000000000000000000..912abb4a0a070c7eae8af7dd4dd1cf3aafbc3a65
--- /dev/null
+++ b/SuperTagger/Linker/Sinkhorn.py
@@ -0,0 +1,17 @@
+
+from torch import logsumexp
+
+
+def norm(x, dim):
+    return x - logsumexp(x, dim=dim, keepdim=True)
+
+
+def sinkhorn_step(x):
+    return norm(norm(x, dim=1), dim=2)
+
+
+def sinkhorn_fn_no_exp(x, tau=1, iters=3):
+    x = x / tau
+    for _ in range(iters):
+        x = sinkhorn_step(x)
+    return x
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..49e702c77c7b9bc1c400c57711049fdbac15bfe7
--- /dev/null
+++ b/SuperTagger/Linker/utils.py
@@ -0,0 +1,82 @@
+import re
+
+
+atoms_list = ['r', 'np']
+
+
+def cut_category_in_symbols(category):
+    '''
+    Parameters :
+    category : str, an atomic category or a complex category of the form dr(i,X,Y) / dl(i,X,Y)
+    Returns :
+    list of booleans, one per atom in the category: True for positive occurrences, False for negative ones
+    '''
+    category_to_weights = []
+
+    if category in atoms_list:
+        category_to_weights.append(True)
+
+    else:
+        # dr = /
+        if category.startswith("dr"):
+            category_cut = re.search(r'dr\(\d+,(.+),(.+)\)', category)
+            left_side, right_side = category_cut.group(1), category_cut.group(2)
+
+            # for the left side
+            if left_side in atoms_list:
+                category_to_weights.append(False)
+            else:
+                category_to_weights += cut_category_in_symbols(left_side)
+
+            # for the right side
+            if right_side in atoms_list:
+                category_to_weights.append(True)
+            else:
+                category_to_weights += cut_category_in_symbols(right_side)
+
+        # dl = \
+        elif category.startswith("dl"):
+            category_cut = re.search(r'dl\(\d+,(.+),(.+)\)', category)
+            left_side, right_side = category_cut.group(1), category_cut.group(2)
+
+            # for the left side
+            if left_side in atoms_list:
+                category_to_weights.append(True)
+            else:
+                category_to_weights += cut_category_in_symbols(left_side)
+
+            # for the right side
+            if right_side in atoms_list:
+                category_to_weights.append(False)
+            else:
+                category_to_weights += cut_category_in_symbols(right_side)
+
+    return category_to_weights
+
+
+print(cut_category_in_symbols('dr(1,dr(1,r,np),np)'))
+
+
+def find_pos_neg_idexes(batch_symbols):
+    '''
+    Parameters :
+    batch_symbols : (batch_size, sequence_length) the batch of symbols
+
+    Returns :
+    (batch_size, max_symbols_in_sentence) boolean values indicating pos and neg indexes
+    '''
+    return None
+
+
+def make_sinkhorn_inputs(bsd_tensor, positional_ids):
+    """
+    :param bsd_tensor:
+        Tensor of shape (batch size, sequence length, feature dimensionality).
+    :param positional_ids:
+        A list of lists (batch_size, max_atoms_in_sentence).
+        Each positional_ids[b][a] indexes the location of atoms of type a in sentence b.
+    :return:
+    """
+
+    return [[bsd_tensor.select(0, index=i).index_select(0, index=atom) for atom in sentence]
+            for i, sentence in enumerate(positional_ids)]
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6882f3fedce1202da1ebc5fa1b35d6b9cf7c409
--- /dev/null
+++ b/test.py
@@ -0,0 +1,7 @@
+l = [[False, True, True, False],
+     [True, False, True, False]]
+
+print(l)
+print([i for i, x in enumerate(l) if x])
+
+print(list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if x], l)))
\ No newline at end of file
diff --git a/train.py b/train.py
index 7f595a9eb61ecd1bbcc092acbe1f6c4f38fbaaba..58ebe4523d2004d52862905df2c09d88aff9dd81 100644
--- a/train.py
+++ b/train.py
@@ -26,7 +26,6 @@ torch.cuda.empty_cache()
 
 # region ParamsModel
 
-max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
 max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])
 symbol_vocab_size = int(Configuration.modelDecoderConfig['symbols_vocab_size'])
 num_gru_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
@@ -74,7 +73,6 @@ print("##" * 15 + "\nConfiguration : \n")
 
 print("ParamsModel\n")
 
-print("\tmax_symbols_in_sentence :", max_symbols_in_sentence)
 print("\tsymbol_vocab_size :", symbol_vocab_size)
 print("\tbidirectional : ", False)
 print("\tnum_gru_layers : ", num_gru_layers)
@@ -117,7 +115,7 @@ BASE_TOKENIZER = AutoTokenizer.from_pretrained(
     'camembert-base',
     do_lower_case=True)
 BASE_MODEL = CamembertModel.from_pretrained("camembert-base")
-symbols_tokenizer = SymbolTokenizer(symbol_map, max_symbols_in_sentence, max_len_sentence)
+symbols_tokenizer = SymbolTokenizer(symbol_map, max_len_sentence, max_len_sentence)
 sents_tokenizer = EncoderInput(BASE_TOKENIZER)
 model = EncoderDecoder(BASE_TOKENIZER, BASE_MODEL, symbol_map)
 model = model.to("cuda" if torch.cuda.is_available() else "cpu")
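
Note: find_pos_neg_idexes is left as a stub in this patch even though Linker.forward already calls it. A minimal sketch of what it could look like, assuming each sentence is simply the concatenation of the polarity lists produced by cut_category_in_symbols; truncation or padding to max_symbols_in_sentence is not handled here and would have to be done by the caller:

    from SuperTagger.Linker.utils import cut_category_in_symbols

    def find_pos_neg_idexes(batch_symbols):
        # one boolean list per sentence: True marks a positive atom, False a negative one
        return [[polarity
                 for category in sentence
                 for polarity in cut_category_in_symbols(category)]
                for sentence in batch_symbols]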
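Note: sinkhorn_fn_no_exp works entirely in log space, so exponentiating its output should give an approximately doubly stochastic matching matrix. A quick sanity check in the spirit of test.py; the shapes below are illustrative only:

    import torch
    from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp

    weights = torch.rand(2, 4, 4)  # (batch, positive atoms, negative atoms)
    log_matches = sinkhorn_fn_no_exp(weights, tau=1, iters=10)

    # rows and columns of exp(log_matches) should each sum to roughly 1
    print(torch.exp(log_matches).sum(dim=1))
    print(torch.exp(log_matches).sum(dim=2))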