Commit c7baf521 authored by Caroline DE POURTALES

progress on linker

parent cd0e359d
import torch
from torch.nn import Module, Embedding
class AtomEmbedding(Module):
def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
super(AtomEmbedding, self).__init__()
self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
scale_grad_by_freq=True)
def forward(self, x):
return self.emb(x)
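A minimal usage sketch of AtomEmbedding (not part of this commit; the sizes below are illustrative, and the '[PAD]' id 25 comes from atom_map further down):
import torch
from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
# illustrative sizes: 26 atom ids (see atom_map), embedding dimension 8
emb = AtomEmbedding(dim_linker=8, atom_vocab_size=26, padding_idx=25)
ids = torch.tensor([[2, 20, 25, 25]])  # one sentence: 'n', 'np', then padding
vectors = emb(ids)
print(vectors.shape)  # torch.Size([1, 4, 8]); the padded positions embed to zero vectors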
import torch
class AtomTokenizer(object):
def __init__(self, atom_map, max_atoms_in_sentence):
self.atom_map = atom_map
self.max_atoms_in_sentence = max_atoms_in_sentence
self.inverse_atom_map = {v: k for k, v in self.atom_map.items()}
self.sep_token = '[SEP]'
self.pad_token = '[PAD]'
self.sos_token = '[SOS]'
self.sep_token_id = self.atom_map[self.sep_token]
self.pad_token_id = self.atom_map[self.pad_token]
self.sos_token_id = self.atom_map[self.sos_token]
def __len__(self):
return len(self.atom_map)
def convert_atoms_to_ids(self, atom):
return self.atom_map[str(atom)]
def convert_sents_to_ids(self, sentences):
return torch.as_tensor([self.convert_atoms_to_ids(atom) for atom in sentences])
def convert_batchs_to_ids(self, batchs_sentences):
return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences],
max_len=self.max_atoms_in_sentence, padding_value=self.pad_token_id))
def convert_ids_to_atoms(self, ids):
return [self.inverse_atom_map[int(i)] for i in ids]
def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
max_size = sequences[0].size()
trailing_dims = max_size[1:]
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims
out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
else:
out_tensor[:length, i, ...] = tensor
return out_tensor
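A small usage sketch of AtomTokenizer together with the custom pad_sequence above (not part of this commit; the max_atoms_in_sentence value is illustrative):
from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
from SuperTagger.Linker.atom_map import atom_map
tokenizer = AtomTokenizer(atom_map, max_atoms_in_sentence=6)
batch = [['np', 's', 'np'], ['n', 'np']]
ids = tokenizer.convert_batchs_to_ids(batch)  # tensor of shape (2, 6), padded with the '[PAD]' id 25
print(ids)
print(tokenizer.convert_ids_to_atoms(ids[0]))  # ['np', 's', 'np', '[PAD]', '[PAD]', '[PAD]']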
@@ -4,9 +4,11 @@ import torch
from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
from Configuration import Configuration
from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
from SuperTagger.Linker.atom_map import atom_map
from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
from SuperTagger.Linker.utils import find_pos_neg_idexes, make_sinkhorn_inputs
from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
def FFN(d_model, d_ff, dropout_rate=0.1, d_out=None):
@@ -24,56 +26,67 @@ class Linker:
self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
self.dim_linker = int(Configuration.modelDecoderConfig['dim_linker'])
self.max_atoms_in_sentence = int(Configuration.modelDecoderConfig['max_atoms_in_sentence'])
self.atom_vocab_size = int(Configuration.modelDecoderConfig['atom_vocab_size'])
self.dropout = Dropout(0.1)
self.atom_map = atom_map
self.padding_id = self.atom_map['[PAD]']
self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.atom_embedding = AtomEmbedding(self.dim_linker, self.atom_vocab_size, self.padding_id)
# TODO: define an encoding
self.linker_encoder = FFN(self.dim_linker, self.dim_linker, 0.1)
self.pos_transformation = Sequential(
FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
LayerNorm(self.dim_decoder // 2, eps=1e-12)
FFN(self.dim_decoder, self.dim_decoder, 0.1),
LayerNorm(self.dim_decoder, eps=1e-12)
)
self.neg_transformation = Sequential(
FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
LayerNorm(self.dim_decoder // 2, eps=1e-12)
FFN(self.dim_decoder, self.dim_decoder, 0.1),
LayerNorm(self.dim_decoder, eps=1e-12)
)
def forward(self, symbols_batch, symbols_decoding):
def forward(self, category_batch):
'''
Parameters :
category_batch : batch of size (batch_size, sequence_length) of categories = output of decoder
Returns :
link_weights : list of (1, n_pos_atoms, n_neg_atoms) tensors, one per sentence and atom type
'''
# some sequential layers for the linker, combining the decoder output and the initial atoms
# decompose into batch_size, max symbols in sentence
decompose_decoding = find_pos_neg_idexes(symbols_batch)
# get tensors of shape (batch_size, max_symbols_in_sentence/2)
pos_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if x], decompose_decoding))
neg_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if not x], decompose_decoding))
_positives = make_sinkhorn_inputs(symbols_decoding, pos_idxes_batch)
_negatives = make_sinkhorn_inputs(symbols_decoding, neg_idxes_batch)
# atoms embedding
atoms_batch = get_atoms_batch(category_batch)
atoms_batch = self.atom_tokenizer.convert_batchs_to_ids(atoms_batch)
atoms_embedding = self.atom_embedding(atoms_batch)
positives = [tensor for tensor in chain.from_iterable(_positives) if min(tensor.size()) != 0]
negatives = [tensor for tensor in chain.from_iterable(_negatives) if min(tensor.size()) != 0]
# MHA or LSTM on the BERT output
#
# TO DO
# atoms_encoding = self.linker_encoder(atoms_embedding)
#
atoms_encoding = atoms_embedding
distinct_shapes = {tensor.size()[0] for tensor in positives}
distinct_shapes = sorted(distinct_shapes)
# find atoms polarity : list (not a tensor) of shape (batch_size, max_atoms_in_sentence)
atoms_polarity = find_pos_neg_idexes(category_batch)
# match positive and negative atoms of the same type
matches = []
link_weights = []
for sentence_idx in range(len(atoms_polarity)):
for atom_type in self.atom_map.keys():
pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
x and atoms_batch[sentence_idx][i] == atom_type]
neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
not x and atoms_batch[sentence_idx][i] == atom_type]
all_shape_positives = [self.pos_transformation(self.dropout(torch.stack([tensor for tensor in positives
if tensor.size()[0] == shape])))
for shape in distinct_shapes]
pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
all_shape_negatives = [self.neg_transformation(self.dropout(torch.stack([tensor for tensor in negatives
if tensor.size()[0] == shape])))
for shape in distinct_shapes]
pos_encoding = self.pos_transformation(pos_encoding)
neg_encoding = self.neg_transformation(neg_encoding)
for this_shape_positives, this_shape_negatives in zip(all_shape_positives, all_shape_negatives):
weights = torch.bmm(this_shape_positives,
this_shape_negatives.transpose(2, 1))
matches.append(sinkhorn(weights, iters=3))
weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
link_weights.append(sinkhorn(weights, iters=3))
return matches
return link_weights
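A hedged sketch of how the returned link_weights could be read (not part of this commit, and assuming sinkhorn_fn_no_exp returns scores whose row-wise argmax identifies the matched negative atom):
# link_weights as returned by Linker.forward; each element has shape (1, n_pos_atoms, n_neg_atoms)
for weights in link_weights:
    axiom_links = weights.squeeze(0).argmax(dim=-1)  # for every positive atom, index of its matched negative atom
    print(axiom_links)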
File added
File added
File added
atom_map = \
{'cl_r': 0,
'\\': 1,
'n': 2,
'p': 3,
's_ppres': 4,
'dia': 5,
's_whq': 6,
'let': 7,
'/': 8,
's_inf': 9,
's_pass': 10,
'pp_a': 11,
'pp_par': 12,
'pp_de': 13,
'cl_y': 14,
'box': 15,
'txt': 16,
's': 17,
's_ppart': 18,
's_q': 19,
'np': 20,
'pp': 21,
'[SEP]': 22,
'[SOS]': 23,
'[START]': 24,
'[PAD]': 25
}
import re
from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
from SuperTagger.Linker.atom_map import atom_map
atoms_list = ['r', 'np']
def get_atoms_from_category(category, category_to_atoms):
if category in atom_map.keys():
# keep the atoms already accumulated for this sentence instead of discarding them
return category_to_atoms + [category]
else:
category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
left_side, right_side = category_cut.group(1), category_cut.group(2)
category_to_atoms += get_atoms_from_category(left_side, [])
category_to_atoms += get_atoms_from_category(right_side, [])
return category_to_atoms
def get_atoms_batch(category_batch):
batch = []
for sentence in category_batch:
category_to_atoms = []
for category in sentence:
category_to_atoms = get_atoms_from_category(category, category_to_atoms)
batch.append(category_to_atoms)
return batch
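A small sketch of the category parsing above (not part of this commit); the regex recursively splits composite categories until only atoms of atom_map remain:
print(get_atoms_from_category('np', []))                   # ['np']
print(get_atoms_from_category('dr(0,dl(0,np,s),np)', []))  # ['np', 's', 'np']
print(get_atoms_batch([['np', 'dr(0,s,np)'], ['s']]))      # [['np', 's', 'np'], ['s']]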
def cut_category_in_symbols(category):
@@ -11,10 +34,10 @@ def cut_category_in_symbols(category):
Returns :
Boolean list of length max_symbols_in_word, containing True for positive indexes and False for negative indexes
'''
category_to_weights = []
category_to_polarity = []
if category in atoms_list:
category_to_weights.append(True)
if category in atom_map.keys():
category_to_polarity.append(True)
else:
# dr = /
@@ -23,16 +46,16 @@ def cut_category_in_symbols(category):
left_side, right_side = category_cut.group(1), category_cut.group(2)
# for the left side
if left_side in atoms_list:
category_to_weights.append(False)
if left_side in atom_map.keys():
category_to_polarity.append(False)
else:
category_to_weights += cut_category_in_symbols(left_side)
category_to_polarity += cut_category_in_symbols(left_side)
# for the right side
if right_side in atoms_list:
category_to_weights.append(True)
if right_side in atom_map.keys():
category_to_polarity.append(True)
else:
category_to_weights += cut_category_in_symbols(right_side)
category_to_polarity += cut_category_in_symbols(right_side)
# dl = \
elif category.startswith("dl"):
@@ -40,21 +63,18 @@ def cut_category_in_symbols(category):
left_side, right_side = category_cut.group(1), category_cut.group(2)
# for the left side
if left_side in atoms_list:
category_to_weights.append(True)
if left_side in atom_map.keys():
category_to_polarity.append(True)
else:
category_to_weights += cut_category_in_symbols(left_side)
category_to_polarity += cut_category_in_symbols(left_side)
# for the right side
if right_side in atoms_list:
category_to_weights.append(False)
if right_side in atom_map.keys():
category_to_polarity.append(False)
else:
category_to_weights += cut_category_in_symbols(right_side)
return category_to_weights
category_to_polarity += cut_category_in_symbols(right_side)
print( cut_category_in_symbols('dr(1,dr(1,r,np),np)'))
return category_to_polarity
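A sketch of the polarities produced above (not part of this commit, and assuming the hunks hidden in this diff parse the 'dr' branch symmetrically to the 'dl' branch shown):
print(cut_category_in_symbols('np'))          # [True]
print(cut_category_in_symbols('dr(0,s,np)'))  # [False, True]   dr = / : left negative, right positive
print(cut_category_in_symbols('dl(0,np,s)'))  # [True, False]   dl = \ : left positive, right negative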
def find_pos_neg_idexes(batch_symbols):
@@ -65,18 +85,11 @@ def find_pos_neg_idexes(batch_symbols):
Returns :
list of shape (batch_size, max_symbols_in_sentence) of booleans indicating positive and negative indexes
'''
return None
def make_sinkhorn_inputs(bsd_tensor, positional_ids):
"""
:param bsd_tensor:
Tensor of shape (batch size, sequence length, feature dimensionality).
:param positional_ids:
A List (batch_size, max_atoms_in_sentence) .
Each positional_ids[b][a] indexes the location of atoms of type a in sentence b.
:return:
"""
list_batch = []
for sentence in batch_symbols:
list_symbols = []
for category in sentence:
list_symbols.append(cut_category_in_symbols(category))
list_batch.append(list_symbols)
return list_batch
return [[bsd_tensor.select(0, index=i).index_select(0, index=atom) for atom in sentence]
for i, sentence in enumerate(positional_ids)]
\ No newline at end of file
l = [[False, True, True, False],
[True, False, True, False]]
from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
import torch
print(l)
print([i for i, x in enumerate(l) if x])
atoms_batch = [["np", "v", "np", "v","np", "v", "np", "v"],
["np", "np", "v", "v","np", "np", "v", "v"]]
print(list(map(lambda sub_list : [i for i, x in enumerate(sub_list) if x], l)))
\ No newline at end of file
atoms_polarity = [[False, True, True, False,False, True, True, False],
[True, False, True, False,True, False, True, False]]
atoms_encoding = torch.randn((2, 8, 24))
matches = []
for sentence_idx in range(len(atoms_polarity)):
for atom_type in ["np", "v"]:
pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
x and atoms_batch[sentence_idx][i] == atom_type]
neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
not x and atoms_batch[sentence_idx][i] == atom_type]
pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
matches.append(sinkhorn(weights, iters=3))
print(matches)