diff --git a/SuperTagger/Linker/AtomEmbedding.py b/SuperTagger/Linker/AtomEmbedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7be599a0fa145f76a5646b83973a3501ed52d4d
--- /dev/null
+++ b/SuperTagger/Linker/AtomEmbedding.py
@@ -0,0 +1,12 @@
+import torch
+from torch.nn import Module, Embedding
+
+
+class AtomEmbedding(Module):
+    def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
+        super(AtomEmbedding, self).__init__()
+        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
+                             scale_grad_by_freq=True)
+
+    def forward(self, x):
+        return self.emb(x)
diff --git a/SuperTagger/Linker/AtomTokenizer.py b/SuperTagger/Linker/AtomTokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e400d4ef28a90fda4e8e1d5f13276720c1de9fe2
--- /dev/null
+++ b/SuperTagger/Linker/AtomTokenizer.py
@@ -0,0 +1,51 @@
+import torch
+
+
+class AtomTokenizer(object):
+    def __init__(self, atom_map, max_atoms_in_sentence):
+        self.atom_map = atom_map
+        self.max_atoms_in_sentence = max_atoms_in_sentence
+        self.inverse_atom_map = {v: k for k, v in self.atom_map.items()}
+        self.sep_token = '[SEP]'
+        self.pad_token = '[PAD]'
+        self.sos_token = '[SOS]'
+        self.sep_token_id = self.atom_map[self.sep_token]
+        self.pad_token_id = self.atom_map[self.pad_token]
+        self.sos_token_id = self.atom_map[self.sos_token]
+
+    def __len__(self):
+        return len(self.atom_map)
+
+    def convert_atoms_to_ids(self, atom):
+        return self.atom_map[str(atom)]
+
+    def convert_sents_to_ids(self, sentences):
+        return torch.as_tensor([self.convert_atoms_to_ids(atom) for atom in sentences])
+
+    def convert_batchs_to_ids(self, batchs_sentences):
+        return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences],
+                                            max_len=self.max_atoms_in_sentence, padding_value=self.pad_token_id))
+
+    def convert_ids_to_atoms(self, ids):
+        return [self.inverse_atom_map[int(i)] for i in ids]
+
+
+def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    max_size = sequences[0].size()
+    trailing_dims = max_size[1:]
+    if batch_first:
+        out_dims = (len(sequences), max_len) + trailing_dims
+    else:
+        out_dims = (max_len, len(sequences)) + trailing_dims
+
+    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    for i, tensor in enumerate(sequences):
+        length = tensor.size(0)
+        # use index notation to prevent duplicate references to the tensor
+        if batch_first:
+            out_tensor[i, :length, ...] = tensor
+        else:
+            out_tensor[:length, i, ...] = tensor
+
+    return out_tensor
+
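Note: the two new modules above are only exercised indirectly through Linker.forward further down. A minimal standalone sketch of how they are meant to compose; the dimension, the max_atoms_in_sentence value and the sample atom names are illustrative, not taken from the Configuration file:

    # Hypothetical usage sketch of AtomTokenizer + AtomEmbedding (values are illustrative).
    from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
    from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
    from SuperTagger.Linker.atom_map import atom_map

    tokenizer = AtomTokenizer(atom_map, max_atoms_in_sentence=10)
    embedding = AtomEmbedding(dim_linker=32, atom_vocab_size=len(atom_map),
                              padding_idx=tokenizer.pad_token_id)

    # Two "sentences" given as lists of atom names; shorter ones are padded with [PAD].
    ids = tokenizer.convert_batchs_to_ids([['np', 's', 'np'], ['n', 'np']])
    print(ids.shape)             # torch.Size([2, 10])
    print(embedding(ids).shape)  # torch.Size([2, 10, 32])
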
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index 65682306c2269b022e3ef23f1bd83da9aad19bf1..745b7d96083aca93f873c325cbcebd34939ecacf 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -4,9 +4,11 @@
 import torch
 from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
 from Configuration import Configuration
-
+from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
+from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, make_sinkhorn_inputs
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
 
 
 def FFN(d_model, d_ff, dropout_rate=0.1, d_out=None):
@@ -24,56 +26,67 @@ class Linker:
         self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
         self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
+        self.dim_linker = int(Configuration.modelDecoderConfig['dim_linker'])
+        self.max_atoms_in_sentence = int(Configuration.modelDecoderConfig['max_atoms_in_sentence'])
+        self.atom_vocab_size = int(Configuration.modelDecoderConfig['atom_vocab_size'])
         self.dropout = Dropout(0.1)
+        self.atom_map = atom_map
+        self.padding_id = self.atom_map['[PAD]']
+        self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        self.atom_embedding = AtomEmbedding(self.dim_linker, self.atom_vocab_size, self.padding_id)
+
+        # TODO : define an encoding
+        self.linker_encoder = FFN(self.dim_linker, self.dim_linker, 0.1)
+
         self.pos_transformation = Sequential(
-            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
-            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+            FFN(self.dim_decoder, self.dim_decoder, 0.1),
+            LayerNorm(self.dim_decoder, eps=1e-12)
         )
         self.neg_transformation = Sequential(
-            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
-            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+            FFN(self.dim_decoder, self.dim_decoder, 0.1),
+            LayerNorm(self.dim_decoder, eps=1e-12)
        )
 
-    def forward(self, symbols_batch, symbols_decoding):
+    def forward(self, category_batch):
         '''
         Parameters :
         symbols_decoding : batch of size (batch_size, sequence_length) = output of decoder
+        Returns :
+        link_weights : (batch_size, atom_vocab_size, ...)
         '''
-        # some sequential for linker with output of decoder and initial ato
-
-        # decompose into batch_size, max symbols in sentence
-        decompose_decoding = find_pos_neg_idexes(symbols_batch)
-
-        # get tensors of shape (batch_size, max_symbols_in_sentence/2)
-        pos_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if x], decompose_decoding))
-        neg_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if not x], decompose_decoding))
-
-        _positives = make_sinkhorn_inputs(symbols_decoding, pos_idxes_batch)
-        _negatives = make_sinkhorn_inputs(symbols_decoding, neg_idxes_batch)
+        # atoms embedding
+        atoms_batch = get_atoms_batch(category_batch)
+        atoms_batch = self.atom_tokenizer.convert_batchs_to_ids(atoms_batch)
+        atoms_embedding = self.atom_embedding(atoms_batch)
 
-        positives = [tensor for tensor in chain.from_iterable(_positives) if min(tensor.size()) != 0]
-        negatives = [tensor for tensor in chain.from_iterable(_negatives) if min(tensor.size()) != 0]
+        # MHA or LSTM over the BERT output
+        #
+        # TO DO
+        # atoms_encoding = self.linker_encoder(atoms_embedding)
+        #
+        atoms_encoding = atoms_embedding
 
-        distinct_shapes = {tensor.size()[0] for tensor in positives}
-        distinct_shapes = sorted(distinct_shapes)
+        # find atoms polarity : list (not tensor) (batch_size, max_atoms_in_sentence)
+        atoms_polarity = find_pos_neg_idexes(category_batch)
 
-        # going to match the pos and neg together
-        matches = []
+        link_weights = []
+        for sentence_idx in range(len(atoms_polarity)):
+            for atom_type in self.atom_map.keys():
+                pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                         x and atoms_batch[sentence_idx][i] == atom_type]
+                neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                         not x and atoms_batch[sentence_idx][i] == atom_type]
 
-        all_shape_positives = [self.pos_transformation(self.dropout(torch.stack([tensor for tensor in positives
-                                                                                 if tensor.size()[0] == shape])))
-                               for shape in distinct_shapes]
+                pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
+                neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
 
-        all_shape_negatives = [self.neg_transformation(self.dropout(torch.stack([tensor for tensor in negatives
-                                                                                 if tensor.size()[0] == shape])))
-                               for shape in distinct_shapes]
+                pos_encoding = self.pos_transformation(pos_encoding)
+                neg_encoding = self.neg_transformation(neg_encoding)
 
-        for this_shape_positives, this_shape_negatives in zip(all_shape_positives, all_shape_negatives):
-            weights = torch.bmm(this_shape_positives,
-                                this_shape_negatives.transpose(2, 1))
-            matches.append(sinkhorn(weights, iters=3))
+                weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
+                link_weights.append(sinkhorn(weights, iters=3))
 
-        return matches
+        return link_weights
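Note: turning the returned link_weights into discrete axiom links is not part of this diff; the sketch below, on a stand-in tensor, shows one plausible reading of the Sinkhorn output (argmax over the negative axis for each positive atom):

    import torch

    # Hypothetical decoding of one entry of Linker.forward's output (not in this diff).
    # Each entry has shape (1, n_pos, n_neg): Sinkhorn scores for one sentence and one atom type.
    weights = torch.randn(1, 3, 3)            # stand-in for one element of link_weights
    links = weights.squeeze(0).argmax(dim=1)  # positive atom i is linked to negative atom links[i]
    print(links)                              # e.g. tensor([2, 0, 1])
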
diff --git a/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb55c03f688748485e4452c56dda80e12c73c904
Binary files /dev/null and b/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc differ
diff --git a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afbe4b3ab22416e9fb3fa5b7e422587b47fe3c95
Binary files /dev/null and b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc differ
diff --git a/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc b/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95490adcd0f1a39a9d2181a1fe19b54b9ac7e4be
Binary files /dev/null and b/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc differ
diff --git a/SuperTagger/Linker/atom_map.py b/SuperTagger/Linker/atom_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..893fd00518bc8b58bb9b3448243286c8ca14e6b8
--- /dev/null
+++ b/SuperTagger/Linker/atom_map.py
@@ -0,0 +1,28 @@
+atom_map = \
+    {'cl_r': 0,
+     '\\': 1,
+     'n': 2,
+     'p': 3,
+     's_ppres': 4,
+     'dia': 5,
+     's_whq': 6,
+     'let': 7,
+     '/': 8,
+     's_inf': 9,
+     's_pass': 10,
+     'pp_a': 11,
+     'pp_par': 12,
+     'pp_de': 13,
+     'cl_y': 14,
+     'box': 15,
+     'txt': 16,
+     's': 17,
+     's_ppart': 18,
+     's_q': 19,
+     'np': 20,
+     'pp': 21,
+     '[SEP]': 22,
+     '[SOS]': 23,
+     '[START]': 24,
+     '[PAD]': 25
+     }
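Note: a quick sanity check on the map above; the numbers are simply what the dictionary currently contains, and atom_vocab_size in the configuration is expected to agree with them:

    from SuperTagger.Linker.atom_map import atom_map

    # The embedding table in AtomEmbedding must cover every id in this map,
    # so atom_vocab_size should be at least len(atom_map) (26 in this diff).
    assert len(atom_map) == 26
    assert atom_map['[PAD]'] == max(atom_map.values())  # the padding id is the last one here
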
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index 49e702c77c7b9bc1c400c57711049fdbac15bfe7..ddb8cb582d60625fb82b651663dc0a732bf6fb7a 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -1,7 +1,30 @@
 import re
 
+from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.atom_map import atom_map
 
-atoms_list = ['r', 'np']
+
+def get_atoms_from_category(category, category_to_atoms):
+    if category in atom_map.keys():
+        return [category]
+    else:
+        category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
+        left_side, right_side = category_cut.group(1), category_cut.group(2)
+
+        category_to_atoms += get_atoms_from_category(left_side, [])
+        category_to_atoms += get_atoms_from_category(right_side, [])
+
+        return category_to_atoms
+
+
+def get_atoms_batch(category_batch):
+    batch = []
+    for sentence in category_batch:
+        category_to_atoms = []
+        for category in sentence:
+            category_to_atoms = get_atoms_from_category(category, category_to_atoms)
+        batch.append(category_to_atoms)
+    return batch
 
 
 def cut_category_in_symbols(category):
@@ -11,10 +34,10 @@
     '''
     Parameters :
     category : str of kind AtomCat | CategoryCat(dr or dl)
     Returns :
     Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
     '''
-    category_to_weights = []
+    category_to_polarity = []
 
-    if category in atoms_list:
-        category_to_weights.append(True)
+    if category in atom_map.keys():
+        category_to_polarity.append(True)
     else:
 
         # dr = /
@@ -23,16 +46,16 @@
             left_side, right_side = category_cut.group(1), category_cut.group(2)
 
             # for the left side
-            if left_side in atoms_list:
-                category_to_weights.append(False)
+            if left_side in atom_map.keys():
+                category_to_polarity.append(False)
             else:
-                category_to_weights += cut_category_in_symbols(left_side)
+                category_to_polarity += cut_category_in_symbols(left_side)
 
             # for the right side
-            if right_side in atoms_list:
-                category_to_weights.append(True)
+            if right_side in atom_map.keys():
+                category_to_polarity.append(True)
             else:
-                category_to_weights += cut_category_in_symbols(right_side)
+                category_to_polarity += cut_category_in_symbols(right_side)
 
         # dl = \
@@ -40,21 +63,18 @@
             left_side, right_side = category_cut.group(1), category_cut.group(2)
 
             # for the left side
-            if left_side in atoms_list:
-                category_to_weights.append(True)
+            if left_side in atom_map.keys():
+                category_to_polarity.append(True)
             else:
-                category_to_weights += cut_category_in_symbols(left_side)
+                category_to_polarity += cut_category_in_symbols(left_side)
 
             # for the right side
-            if right_side in atoms_list:
-                category_to_weights.append(False)
+            if right_side in atom_map.keys():
+                category_to_polarity.append(False)
             else:
-                category_to_weights += cut_category_in_symbols(right_side)
-
-    return category_to_weights
-
+                category_to_polarity += cut_category_in_symbols(right_side)
 
-print( cut_category_in_symbols('dr(1,dr(1,r,np),np)'))
+    return category_to_polarity
 
 
 def find_pos_neg_idexes(batch_symbols):
@@ -65,18 +85,11 @@
     '''
     Parameters :
    batch_symbols : (batch_size, max_symbols_in_sentence) the symbols for each sentence
     Returns :
     (batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes
     '''
-    return None
-
-
-def make_sinkhorn_inputs(bsd_tensor, positional_ids):
-    """
-    :param bsd_tensor:
-        Tensor of shape (batch size, sequence length, feature dimensionality).
-    :param positional_ids:
-        A List (batch_size, max_atoms_in_sentence) .
-        Each positional_ids[b][a] indexes the location of atoms of type a in sentence b.
-    :return:
-    """
+    list_batch = []
+    for sentence in batch_symbols:
+        list_symbols = []
+        for category in sentence:
+            list_symbols.append(cut_category_in_symbols(category))
+        list_batch.append(list_symbols)
+    return list_batch
-
-    return [[bsd_tensor.select(0, index=i).index_select(0, index=atom) for atom in sentence]
-            for i, sentence in enumerate(positional_ids)]
\ No newline at end of file
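Note: a small, hedged sanity check of the new helpers; the category string 'dr(0,s,np)' is only an illustrative example, and the expected outputs assume the unchanged regex inside cut_category_in_symbols splits binary categories the same way get_atoms_from_category does:

    from SuperTagger.Linker.utils import get_atoms_from_category, cut_category_in_symbols, get_atoms_batch

    # 'dr' is the forward slash: the category is split into its left and right parts.
    print(get_atoms_from_category('dr(0,s,np)', []))  # expected: ['s', 'np']
    print(cut_category_in_symbols('dr(0,s,np)'))      # expected: [False, True]

    # Batched version, one list of categories per sentence.
    print(get_atoms_batch([['np', 'dr(0,s,np)']]))    # expected: [['np', 's', 'np']]
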
diff --git a/test.py b/test.py
index d6882f3fedce1202da1ebc5fa1b35d6b9cf7c409..f208027894f01b95d1509ccd2fafb58b12c2ac44 100644
--- a/test.py
+++ b/test.py
@@ -1,7 +1,27 @@
-l = [[False, True, True, False],
-     [True, False, True, False]]
+from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+import torch
 
-print(l)
-print([i for i, x in enumerate(l) if x])
+atoms_batch = [["np", "v", "np", "v", "np", "v", "np", "v"],
+               ["np", "np", "v", "v", "np", "np", "v", "v"]]
 
-print(list(map(lambda sub_list : [i for i, x in enumerate(sub_list) if x], l)))
\ No newline at end of file
+atoms_polarity = [[False, True, True, False, False, True, True, False],
+                  [True, False, True, False, True, False, True, False]]
+
+atoms_encoding = torch.randn((2, 8, 24))
+
+matches = []
+for sentence_idx in range(len(atoms_polarity)):
+
+    for atom_type in ["np", "v"]:
+        pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                 x and atoms_batch[sentence_idx][i] == atom_type]
+        neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                 not x and atoms_batch[sentence_idx][i] == atom_type]
+
+        pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
+        neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
+
+        weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
+        matches.append(sinkhorn(weights, iters=3))
+
+print(matches)
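Note: with this toy batch, each of the two sentences contributes one score matrix per atom type (2 positive and 2 negative occurrences each), so matches should hold four tensors of shape (1, 2, 2), assuming sinkhorn_fn_no_exp preserves the shape of its input. A short check that could be appended to the script above:

    # Two sentences x two atom types = four score matrices of shape (1, n_pos, n_neg) = (1, 2, 2).
    assert len(matches) == 4
    assert all(m.shape == (1, 2, 2) for m in matches)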