diff --git a/Linker/AtomTokenizer.py b/Linker/AtomTokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5c1a1c95998b40390a6839485e680f5d79bacf --- /dev/null +++ b/Linker/AtomTokenizer.py @@ -0,0 +1,50 @@ +import torch +from utils import pad_sequence + + +class AtomTokenizer(object): + r""" + Tokenizer for the atoms with padding + """ + def __init__(self, atom_map, max_atoms_in_sentence): + self.atom_map = atom_map + self.max_atoms_in_sentence = max_atoms_in_sentence + self.inverse_atom_map = {v: k for k, v in self.atom_map.items()} + self.pad_token = '[PAD]' + self.pad_token_id = self.atom_map[self.pad_token] + + def __len__(self): + return len(self.atom_map) + + def convert_atoms_to_ids(self, atom): + r""" + Convert a atom to its id + :param atom: atom string + :return: atom id + """ + return self.atom_map[str(atom)] + + def convert_sents_to_ids(self, sentences): + r""" + Convert sentences to ids + :param sentences: List of atoms in a sentence + :return: List of atoms'ids + """ + return torch.as_tensor([self.convert_atoms_to_ids(atom) for atom in sentences]) + + def convert_batchs_to_ids(self, batchs_sentences): + r""" + Convert a batch of sentences of atoms to the ids + :param batchs_sentences: batch of sentences atoms + :return: list of list of atoms'ids + """ + return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences], + max_len=self.max_atoms_in_sentence, padding_value=self.pad_token_id)) + + def convert_ids_to_atoms(self, ids): + r""" + Translate id to atom + :param ids: atom id + :return: atom string + """ + return [self.inverse_atom_map[int(i)] for i in ids] diff --git a/Linker/Linker.py b/Linker/Linker.py new file mode 100644 index 0000000000000000000000000000000000000000..76596f0cda1fa404a958e693bd84d1b673264f80 --- /dev/null +++ b/Linker/Linker.py @@ -0,0 +1,502 @@ +import datetime +import math +import os +import sys +import time + +import torch +import torch.nn.functional as F +from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout, TransformerEncoderLayer, TransformerEncoder, \ + Embedding, GELU +from torch.optim import AdamW +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import TensorDataset, random_split +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +from Configuration import Configuration +from .AtomTokenizer import AtomTokenizer +from .PositionalEncoding import PositionalEncoding +from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn +from Linker.atom_map import atom_map, atom_map_redux +from Linker.eval import measure_accuracy, SinkhornLoss +from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_atoms_batch, \ + find_pos_neg_idexes, get_num_atoms_batch +from SuperTagger import SuperTagger +from utils import pad_sequence + + +def format_time(elapsed): + ''' + Takes a time in seconds and returns a string hh:mm:ss + ''' + # Round to the nearest second. 
+ elapsed_rounded = int(round(elapsed)) + + # Format as hh:mm:ss + return str(datetime.timedelta(seconds=elapsed_rounded)) + + +def output_create_dir(): + """ + Create le output dir for tensorboard and checkpoint + @return: output dir, tensorboard writter + """ + from datetime import datetime + outpout_path = 'TensorBoard' + training_dir = os.path.join(outpout_path, 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M')) + logs_dir = os.path.join(training_dir, 'logs') + writer = SummaryWriter(log_dir=logs_dir) + return training_dir, writer + + +def generate_square_subsequent_mask(sz): + """Generates an upper-triangular matrix of -inf, with zeros on diag.""" + return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1) + + +class Linker(Module): + def __init__(self, supertagger_path_model): + super(Linker, self).__init__() + + # region parameters + config = Configuration.read_config() + datasetConfig = config["DATASET_PARAMS"] + modelEncoderConfig = config["MODEL_ENCODER"] + modelLinkerConfig = config["MODEL_LINKER"] + modelTrainingConfig = config["MODEL_TRAINING"] + + dim_encoder = int(modelEncoderConfig['dim_encoder']) + # atom settings + atom_vocab_size = int(datasetConfig['atom_vocab_size']) + # Transformer + self.nhead = int(modelLinkerConfig['nhead']) + self.dim_emb_atom = int(modelLinkerConfig['dim_emb_atom']) + self.dim_feedforward_transformer = int(modelLinkerConfig['dim_feedforward_transformer']) + self.num_layers = int(modelLinkerConfig['num_layers']) + # torch cat + dropout = float(modelLinkerConfig['dropout']) + self.dim_cat_out = int(modelLinkerConfig['dim_cat_out']) + dim_intermediate_FFN = int(modelLinkerConfig['dim_intermediate_FFN']) + dim_pre_sinkhorn_transfo = int(modelLinkerConfig['dim_pre_sinkhorn_transfo']) + # sinkhorn + self.sinkhorn_iters = int(modelLinkerConfig['sinkhorn_iters']) + # settings + self.max_len_sentence = int(datasetConfig['max_len_sentence']) + self.max_atoms_in_sentence = int(datasetConfig['max_atoms_in_sentence']) + self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type']) + learning_rate = float(modelTrainingConfig['learning_rate']) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # endregion + + # SuperTagger for categories + supertagger = SuperTagger() + supertagger.load_weights(supertagger_path_model) + self.Supertagger = supertagger + self.Supertagger.model.to(self.device) + + # Atoms embedding + self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence) + self.atom_map_redux = atom_map_redux + self.padding_id = atom_map["[PAD]"] + self.sub_atoms_type_list = list(atom_map_redux.keys()) + self.atom_encoder = Embedding(atom_vocab_size, self.dim_emb_atom, padding_idx=self.padding_id) + self.atom_encoder.weight.data.uniform_(-0.1, 0.1) + self.position_encoder = PositionalEncoding(self.dim_emb_atom, dropout, max_len=self.max_atoms_in_sentence) + encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead, + dim_feedforward=self.dim_feedforward_transformer, dropout=dropout) + self.transformer = TransformerEncoder(encoder_layer, num_layers=self.num_layers) + + # Concatenation with word embedding + dim_cat = dim_encoder + self.dim_emb_atom + self.linker_encoder = Sequential( + Linear(dim_cat, self.dim_cat_out), + GELU(), + Dropout(dropout), + LayerNorm(self.dim_cat_out, eps=1e-8) + ) + + # Division into positive and negative + self.pos_transformation = Sequential( + FFN(self.dim_cat_out, dim_intermediate_FFN, dropout, d_out=dim_pre_sinkhorn_transfo), + 
LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8) + ) + self.neg_transformation = Sequential( + FFN(self.dim_cat_out, dim_intermediate_FFN, dropout, d_out=dim_pre_sinkhorn_transfo), + LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8) + ) + + # Learning + self.cross_entropy_loss = SinkhornLoss() + self.optimizer = AdamW(self.parameters(), + lr=learning_rate) + self.scheduler = StepLR(self.optimizer, step_size=2, gamma=0.5) + + self.to(self.device) + + def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1): + r""" + Args: + batch_size : int + df_axiom_links pandas DataFrame + validation_rate + Returns: + the training dataloader and the validation dataloader. They contains the list of atoms, their polarities, the axiom links, the sentences tokenized, sentence mask + """ + print("Start preprocess Data") + sentences_batch = df_axiom_links["X"].str.strip().tolist() + sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch) + + atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links) + atoms_polarity_batch = pad_sequence( + [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], + max_len=self.max_atoms_in_sentence, padding_value=0) + atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch) + + pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type) + neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type) + + truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch, + df_axiom_links["Y"]) + truth_links_batch = truth_links_batch.permute(1, 0, 2) + + # Construction tensor dataset + dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch, + sentences_tokens, sentences_mask) + + if validation_rate > 0.0: + train_size = int(0.9 * len(dataset)) + val_size = len(dataset) - train_size + train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) + validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + else: + validation_dataloader = None + train_dataset = dataset + + training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False) + print("End preprocess Data") + return training_dataloader, validation_dataloader + + def forward(self, batch_num_atoms_per_word, batch_atoms, batch_pos_idx, batch_neg_idx, sents_embedding): + r""" + Args: + batch_num_atoms_per_word : (batch_size, len_sentence) flattened categories + batch_atoms : atoms tok + batch_pos_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities + batch_neg_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities + sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context + Returns: + link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities + """ + # repeat embedding word for each atom in word with a +1 for sep + sents_embedding_repeat = pad_sequence( + [torch.repeat_interleave(input=sents_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0) + for i in range(len(sents_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0) + + # atoms emebedding + src_key_padding_mask = torch.eq(batch_atoms, self.padding_id) + src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device) + atoms_embedding 
= self.atom_encoder(batch_atoms) * math.sqrt(self.dim_emb_atom) + atoms_embedding = self.position_encoder(atoms_embedding) + atoms_embedding = atoms_embedding.permute(1, 0, 2) + atoms_embedding = self.transformer(atoms_embedding, src_mask, + src_key_padding_mask=src_key_padding_mask) + atoms_embedding = atoms_embedding.permute(1, 0, 2) + + # cat + atoms_sentences_encoding = torch.cat([sents_embedding_repeat, atoms_embedding], dim=2) + atoms_encoding = self.linker_encoder(atoms_sentences_encoding) + + # linking per atom type + batch_size, atom_vocab_size, _ = batch_pos_idx.shape + link_weights = torch.zeros(atom_vocab_size, batch_size, self.max_atoms_in_one_type // 2, + self.max_atoms_in_one_type // 2, device=self.device) + for atom_type in list(atom_map_redux.keys()): + pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type) + neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type) + + pos_encoding = self.pos_transformation(pos_encoding) + neg_encoding = self.neg_transformation(neg_encoding) + + weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1)) + link_weights[self.atom_map_redux[atom_type]] = sinkhorn(weights, iters=self.sinkhorn_iters) + + return F.log_softmax(link_weights, dim=3) + + def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, + batch_size=32, checkpoint=True, tensorboard=False): + r""" + Args: + df_axiom_links : pandas dataFrame containing the atoms anoted with _i + validation_rate : float + epochs : int + batch_size : int + checkpoint : boolean + tensorboard : boolean + Returns: + Final accuracy and final loss + """ + training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links, + validation_rate) + if checkpoint or tensorboard: + checkpoint_dir, writer = output_create_dir() + + for epoch_i in range(epochs): + print("") + print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs)) + print('Training...') + avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader) + + print("") + print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}') + print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%') + + if validation_rate > 0.0: + loss_test, accuracy_test = self.eval_epoch(validation_dataloader) + print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%') + + if checkpoint: + self.__checkpoint_save( + path=os.path.join("Output", 'linker.pt')) + + if tensorboard: + writer.add_scalars(f'Accuracy', { + 'Train': avg_accuracy_train}, epoch_i) + writer.add_scalars(f'Loss', { + 'Train': avg_train_loss}, epoch_i) + if validation_rate > 0.0: + writer.add_scalars(f'Accuracy', { + 'Validation': accuracy_test}, epoch_i) + writer.add_scalars(f'Loss', { + 'Validation': loss_test}, epoch_i) + + print('\n') + + def train_epoch(self, training_dataloader): + r""" Train epoch + + Args: + training_dataloader : DataLoader from torch , contains atoms, polarities, axiom_links, sents_tokenized, sents_masks + Returns: + accuracy on validation set + loss on train set + """ + self.train() + + # Reset the total loss for this epoch. + epoch_loss = 0 + accuracy_train = 0 + t0 = time.time() + + # For each batch of training data... 
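+            # Batch layout (from the TensorDataset built in __preprocess_data): num_atoms_per_word, atom token ids, positive atom indices, negative atom indices, gold axiom links, BERT sentence tokens, sentence mask.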
+ with tqdm(training_dataloader, unit="batch") as tepoch: + for batch in tepoch: + # Unpack this training batch from our dataloader + batch_num_atoms = batch[0].to(self.device) + batch_atoms_tok = batch[1].to(self.device) + batch_pos_idx = batch[2].to(self.device) + batch_neg_idx = batch[3].to(self.device) + batch_true_links = batch[4].to(self.device) + batch_sentences_tokens = batch[5].to(self.device) + batch_sentences_mask = batch[6].to(self.device) + + self.optimizer.zero_grad() + + # get sentence embedding from BERT which is already trained + output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) + + # Run the Linker on the atoms + logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, + output['word_embeding']) + + linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links) + # Perform a backward pass to calculate the gradients. + epoch_loss += float(linker_loss) + linker_loss.backward() + + # This is to help prevent the "exploding gradients" problem. + # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2) + + # Update parameters and take a step using the computed gradient. + self.optimizer.step() + + pred_axiom_links = torch.argmax(logits_predictions, dim=3) + accuracy_train += measure_accuracy(batch_true_links, pred_axiom_links) + + self.scheduler.step() + + # Measure how long this epoch took. + training_time = format_time(time.time() - t0) + avg_train_loss = epoch_loss / len(training_dataloader) + avg_accuracy_train = accuracy_train / len(training_dataloader) + + return avg_train_loss, avg_accuracy_train, training_time + + def eval_batch(self, batch): + batch_num_atoms = batch[0].to(self.device) + batch_atoms_tok = batch[1].to(self.device) + batch_pos_idx = batch[2].to(self.device) + batch_neg_idx = batch[3].to(self.device) + batch_true_links = batch[4].to(self.device) + batch_sentences_tokens = batch[5].to(self.device) + batch_sentences_mask = batch[6].to(self.device) + + output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask) + + logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[ + 'word_embeding']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type + axiom_links_pred = torch.argmax(logits_predictions, dim=3) # atom_vocab, batch_size, max atoms in one type + + print('\n') + print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100]) + print("Les prédictions : ", axiom_links_pred[2][1][:100]) + print('\n') + + accuracy = measure_accuracy(batch_true_links, axiom_links_pred) + loss = self.cross_entropy_loss(logits_predictions, batch_true_links) + + return loss, accuracy + + def eval_epoch(self, dataloader): + r"""Average the evaluation of all the batch. 
+ + Args: + dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols + """ + self.eval() + accuracy_average = 0 + loss_average = 0 + with torch.no_grad(): + for step, batch in enumerate(dataloader): + loss, accuracy = self.eval_batch(batch) + accuracy_average += accuracy + loss_average += float(loss) + + return loss_average / len(dataloader), accuracy_average / len(dataloader) + + def predict_with_categories(self, sentence, categories): + r""" Predict the links from a sentence and its categories + + Args : + sentence : list of words composing the sentence + categories : list of categories (tags) of each word + + Return : + links : links prediction + """ + self.eval() + with torch.no_grad(): + self.cpu() + self.device = torch.device("cpu") + sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence]) + nb_sentence, len_sentence = sentences_tokens.shape + + atoms = get_atoms_batch([categories]) + atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms) + + polarities = find_pos_neg_idexes([categories]) + polarities = pad_sequence( + [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], + max_len=self.max_atoms_in_sentence, padding_value=0) + + num_atoms_per_word = get_num_atoms_batch([categories], len_sentence) + + pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type) + neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type) + + output = self.Supertagger.forward(sentences_tokens, sentences_mask) + + logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embeding']) + axiom_links_pred = torch.argmax(logits_predictions, dim=3) + + return axiom_links_pred + + def predict_without_categories(self, sentence): + r""" Predict the links from a sentence + + Args : + sentence : list of words composing the sentence + + Return : + categories : the supertags predicted + links : links prediction + """ + self.eval() + with torch.no_grad(): + self.cpu() + self.device = torch.device("cpu") + sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence]) + nb_sentence, len_sentence = sentences_tokens.shape + + hidden_state, categories = self.Supertagger.predict(sentence) + + output = self.Supertagger.forward(sentences_tokens, sentences_mask) + atoms = get_atoms_batch(categories) + atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms) + + polarities = find_pos_neg_idexes(categories) + polarities = pad_sequence( + [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))], + max_len=self.max_atoms_in_sentence, padding_value=0) + + num_atoms_per_word = get_num_atoms_batch(categories, len_sentence) + + pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type) + neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type) + + logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embeding']) + axiom_links_pred = torch.argmax(logits_predictions, dim=3) + + return categories, axiom_links_pred + + def load_weights(self, model_file): + print("#" * 15) + try: + params = torch.load(model_file, map_location=self.device) + self.atom_encoder.load_state_dict(params['atom_encoder']) + self.position_encoder.load_state_dict(params['position_encoder']) + self.transformer.load_state_dict(params['transformer']) + self.linker_encoder.load_state_dict(params['linker_encoder']) + 
self.pos_transformation.load_state_dict(params['pos_transformation']) + self.neg_transformation.load_state_dict(params['neg_transformation']) + self.cross_entropy_loss.load_state_dict(params['cross_entropy_loss']) + self.optimizer.load_state_dict(params['optimizer']) + print("\n The loading checkpoint was successful ! \n") + except Exception as e: + print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr) + raise e + print("#" * 15) + + def __checkpoint_save(self, path='/linker.pt'): + """ + @param path: + """ + self.cpu() + + torch.save({ + 'atom_encoder': self.atom_encoder.state_dict(), + 'position_encoder': self.position_encoder, + 'transformer': self.transformer.state_dict(), + 'linker_encoder': self.linker_encoder.state_dict(), + 'pos_transformation': self.pos_transformation.state_dict(), + 'neg_transformation': self.neg_transformation.state_dict(), + 'cross_entropy_loss': self.cross_entropy_loss, + 'optimizer': self.optimizer, + }, path) + self.to(self.device) + + def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type): + """ + :param bsd_tensor: + Tensor of shape batch size \times sequence length \times feature dimensionality. + :param positional_ids: + A List of batch_size elements, each being a List of num_atoms LongTensors. + Each LongTensor in positional_ids[b][a] indexes the location of atoms of type a in sentence b. + :param atom_type: + :return: + """ + + return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device) + if atom != -1 else torch.zeros(self.dim_cat_out, device=self.device) + for atom in sentence]) + for i, sentence in enumerate(positional_ids[:, self.atom_map_redux[atom_type], :])]) diff --git a/Linker/PositionalEncoding.py b/Linker/PositionalEncoding.py new file mode 100644 index 0000000000000000000000000000000000000000..19e1b96c0bd17b9867d9d24bda52a619e7559e4e --- /dev/null +++ b/Linker/PositionalEncoding.py @@ -0,0 +1,25 @@ +import torch +from torch import nn +import math + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + """ + Args: + x: Tensor, shape [batch_size, seq_len, mbedding_dim] + """ + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) diff --git a/Linker/Sinkhorn.py b/Linker/Sinkhorn.py new file mode 100644 index 0000000000000000000000000000000000000000..9cf9b45607800c1f35efa98801c86e3326726a19 --- /dev/null +++ b/Linker/Sinkhorn.py @@ -0,0 +1,16 @@ +from torch import logsumexp + + +def norm(x, dim): + return x - logsumexp(x, dim=dim, keepdim=True) + + +def sinkhorn_step(x): + return norm(norm(x, dim=1), dim=2) + + +def sinkhorn_fn_no_exp(x, tau=1, iters=3): + x = x / tau + for _ in range(iters): + x = sinkhorn_step(x) + return x diff --git a/Linker/__init__.py b/Linker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0983f0bce2a67ac9fea8478389d3fb706ce820a1 --- /dev/null +++ b/Linker/__init__.py @@ -0,0 +1,5 @@ +from .Linker import Linker +from .atom_map import atom_map +from .AtomTokenizer import AtomTokenizer +from .PositionalEncoding import PositionalEncoding 
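+# Sinkhorn normalisation helpers (norm, sinkhorn_step, sinkhorn_fn_no_exp) are re-exported at package level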
+from .Sinkhorn import * \ No newline at end of file diff --git a/Linker/atom_map.py b/Linker/atom_map.py new file mode 100644 index 0000000000000000000000000000000000000000..0df2646a03e4a228eb9223a3eb5f167c4de2ca14 --- /dev/null +++ b/Linker/atom_map.py @@ -0,0 +1,30 @@ +atom_map = \ + {'cl_r': 0, + "pp": 1, + 'n': 2, + 's_ppres': 3, + 's_whq': 4, + 's_q': 5, + 'np': 6, + 's_inf': 7, + 's_pass': 8, + 'pp_a': 9, + 'pp_par': 10, + 'pp_de': 11, + 'cl_y': 12, + 'txt': 13, + 's': 14, + 's_ppart': 15, + "[SEP]":16, + '[PAD]': 17 + } + +atom_map_redux = { + 'cl_r': 0, + 'pp': 1, + 'n': 2, + 'np': 3, + 'cl_y': 4, + 'txt': 5, + 's': 6 +} diff --git a/Linker/eval.py b/Linker/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..b252e5e5b07fbfe1dd8cacf7e1486b0b5850c34b --- /dev/null +++ b/Linker/eval.py @@ -0,0 +1,34 @@ +import torch +from torch.nn import Module +from torch.nn.functional import nll_loss +from Linker.atom_map import atom_map, atom_map_redux + + +class SinkhornLoss(Module): + r""" + Loss for the linker + """ + def __init__(self): + super(SinkhornLoss, self).__init__() + + def forward(self, predictions, truths): + return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1) + for link, perm in zip(predictions, truths.permute(1, 0, 2))) + + +def measure_accuracy(batch_true_links, axiom_links_pred): + r""" + batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms + axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms + """ + padding = -1 + batch_true_links = batch_true_links.permute(1, 0, 2) + correct_links = torch.ones(axiom_links_pred.size()) + correct_links[axiom_links_pred != batch_true_links] = 0 + correct_links[batch_true_links == padding] = 1 + num_correct_links = correct_links.sum().item() + num_masked_atoms = len(batch_true_links[batch_true_links == padding]) + + # diviser par nombre de links + return (num_correct_links - num_masked_atoms) / ( + axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms) diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py new file mode 100644 index 0000000000000000000000000000000000000000..15b37f39d0abfda2f09b50e26de7b744ba0796b9 --- /dev/null +++ b/Linker/utils_linker.py @@ -0,0 +1,382 @@ +import re + +import pandas as pd +import regex +import torch +from torch.nn import Sequential, Linear, Dropout, GELU +from torch.nn import Module + +from Linker.atom_map import atom_map, atom_map_redux +from utils import pad_sequence + + +class FFN(Module): + "Implements FFN equation." 
+ + def __init__(self, d_model, d_ff, dropout=0.1, d_out=None): + super(FFN, self).__init__() + self.ffn = Sequential( + Linear(d_model, d_ff, bias=False), + GELU(), + Dropout(dropout), + Linear(d_ff, d_out if d_out is not None else d_model, bias=False) + ) + + def forward(self, x): + return self.ffn(x) + + +################################ Regex ######################################## +regex_categories_axiom_links = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)' +regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)' + + +# region get true axiom links +def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links): + r""" + Args: + max_atoms_in_one_type : configuration + atoms_polarity : (batch_size, max_atoms_in_sentence) + batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms + Returns: + batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms + """ + atoms_batch = get_atoms_links_batch(batch_axiom_links) + linking_plus_to_minus_all_types = [] + for atom_type in list(atom_map_redux.keys()): + # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity + l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i] + and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx + in range(len(atoms_batch))] + l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i] + and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx + in range(len(atoms_batch))] + + linking_plus_to_minus = pad_sequence( + [torch.as_tensor( + [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1 + for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long) + for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2, + padding_value=-1) + + linking_plus_to_minus_all_types.append(linking_plus_to_minus) + + return torch.stack(linking_plus_to_minus_all_types) + + +def category_to_atoms_axiom_links(category, categories_to_atoms): + r""" + Args: + category : str of kind AtomCat | CategoryCat(dr or dl) + categories_to_atoms : recursive list + Returns : + List of atoms inside the category in prefix order + """ + res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()] + if category.startswith("GOAL:"): + word, cat = category.split(':') + return category_to_atoms_axiom_links(cat, categories_to_atoms) + elif True in res: + return [category] + else: + category_cut = regex.match(regex_categories_axiom_links, category).groups() + category_cut = [cat for cat in category_cut if cat is not None] + for cat in category_cut: + categories_to_atoms += category_to_atoms_axiom_links(cat, []) + return categories_to_atoms + + +def get_atoms_links_batch(category_batch): + r""" + Args: + category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order + Returns : + (batch_size, max_atoms_in_sentence) flattened categories in prefix order + """ + batch = [] + for sentence in category_batch: + categories_to_atoms = [] + for category in sentence: + if category != "let" and not category.startswith("GOAL:"): + categories_to_atoms += category_to_atoms_axiom_links(category, []) + categories_to_atoms.append("[SEP]") + elif category.startswith("GOAL:"): + categories_to_atoms = category_to_atoms_axiom_links(category, []) + categories_to_atoms + 
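+        # (a GOAL category, when present, has its atoms prepended above so they come first in the sentence's atom sequence)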
batch.append(categories_to_atoms) + return batch + + +print("test to create links ", + get_axiom_links(20, torch.stack([torch.as_tensor( + [True, False, True, False, False, False, True, False, True, False, + False, True, False, False, False, True, False, False, True, False, + True, False, False, True, False, False, False, False, False, False])]), + [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)', + 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']])) + + +# endregion + +# region get atoms in sentence + +def category_to_atoms(category, categories_to_atoms): + r""" + Args: + category : str of kind AtomCat | CategoryCat(dr or dl) + categories_to_atoms : recursive list + Returns: + List of atoms inside the category in prefix order + """ + res = [(category == atom_type) for atom_type in atom_map.keys()] + if category.startswith("GOAL:"): + word, cat = category.split(':') + return category_to_atoms(cat, categories_to_atoms) + elif True in res: + return [category] + else: + category_cut = regex.match(regex_categories, category).groups() + category_cut = [cat for cat in category_cut if cat is not None] + for cat in category_cut: + categories_to_atoms += category_to_atoms(cat, []) + return categories_to_atoms + + +def get_atoms_batch(category_batch): + r""" + Args: + category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order + Returns: + (batch_size, max_atoms_in_sentence) flattened categories in prefix order + """ + batch = [] + for sentence in category_batch: + categories_to_atoms = [] + for category in sentence: + if category != "let": + categories_to_atoms += category_to_atoms(category, []) + categories_to_atoms.append("[SEP]") + batch.append(categories_to_atoms) + return batch + + +print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']", + get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']])) + + +# endregion + +# region calculate num atoms per category + +def category_to_num_atoms(category, categories_to_atoms): + r""" + Args: + category : str of kind AtomCat | CategoryCat(dr or dl) + categories_to_atoms : recursive int + Returns: + List of atoms inside the category in prefix order + """ + res = [(category == atom_type) for atom_type in atom_map.keys()] + if category.startswith("GOAL:"): + word, cat = category.split(':') + return category_to_num_atoms(cat, 0) + elif category == "let": + return 0 + elif True in res: + return 1 + else: + category_cut = regex.match(regex_categories, category).groups() + category_cut = [cat for cat in category_cut if cat is not None] + for cat in category_cut: + categories_to_atoms += category_to_num_atoms(cat, 0) + return categories_to_atoms + + +def get_num_atoms_batch(category_batch, max_len_sentence): + r""" + Args: + category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order + max_len_sentence : max_len_sentence parameter + Returns: + (batch_size, max_atoms_in_sentence) flattened categories in prefix order + """ + batch = [] + for sentence in category_batch: + num_atoms_sentence = [0] + for category in sentence: + num_atoms_in_word = category_to_num_atoms(category, 0) + # add 1 because for word we have SEP at the end + if category != "let": + num_atoms_in_word += 1 + num_atoms_sentence.append(num_atoms_in_word) + batch.append(torch.as_tensor(num_atoms_sentence)) + return pad_sequence(batch, max_len=max_len_sentence, padding_value=0) + + +print(" test for get number of 
atoms in categories on ['dr(0,s,np)', 'let']", + get_num_atoms_batch([["dr(0,s,np)", "let"]], 10)) + + +# endregion + +# region get polarity + +def category_to_atoms_polarity(category, polarity): + r""" + Args: + category : str of kind AtomCat | CategoryCat(dr or dl) + polarity : polarity according to recursivity + Returns: + Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes + """ + category_to_polarity = [] + res = [(category == atom_type) for atom_type in atom_map.keys()] + # mot final + if category.startswith("GOAL:"): + word, cat = category.split(':') + res = [bool(re.match(r'' + atom_type, cat)) for atom_type in atom_map.keys()] + if True in res: + category_to_polarity.append(True) + else: + category_to_polarity += category_to_atoms_polarity(cat, True) + # le mot a une category atomique + elif True in res: + category_to_polarity.append(not polarity) + # sinon c'est une formule longue + else: + # dr = / + if category.startswith("dr"): + category_cut = regex.match(regex_categories, category).groups() + category_cut = [cat for cat in category_cut if cat is not None] + left_side, right_side = category_cut[0], category_cut[1] + # for the left side + category_to_polarity += category_to_atoms_polarity(left_side, polarity) + # for the right side : change polarity for next right formula + category_to_polarity += category_to_atoms_polarity(right_side, not polarity) + + # dl = \ + elif category.startswith("dl"): + category_cut = regex.match(regex_categories, category).groups() + category_cut = [cat for cat in category_cut if cat is not None] + left_side, right_side = category_cut[0], category_cut[1] + # for the left side + category_to_polarity += category_to_atoms_polarity(left_side, not polarity) + # for the right side + category_to_polarity += category_to_atoms_polarity(right_side, polarity) + + # p + elif category.startswith("p"): + category_cut = regex.match(regex_categories, category).groups() + category_cut = [cat for cat in category_cut if cat is not None] + left_side, right_side = category_cut[0], category_cut[1] + # for the left side + category_to_polarity += category_to_atoms_polarity(left_side, not polarity) + # for the right side + category_to_polarity += category_to_atoms_polarity(right_side, polarity) + + # box + elif category.startswith("box"): + category_cut = regex.match(regex_categories, category).groups() + category_cut = [cat for cat in category_cut if cat is not None] + category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity) + + # dia + elif category.startswith("dia"): + category_cut = regex.match(regex_categories, category).groups() + category_cut = [cat for cat in category_cut if cat is not None] + category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity) + + return category_to_polarity + + +def find_pos_neg_idexes(atoms_batch): + r""" + Args: + atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order + Returns: + (batch_size, max_atoms_in_sentence) flattened categories'polarities in prefix order + """ + list_batch = [] + for sentence in atoms_batch: + list_atoms = [] + for category in sentence: + if category == "let": + pass + else: + for at in category_to_atoms_polarity(category, True): + list_atoms.append(at) + list_atoms.append(False) + list_batch.append(list_atoms) + return list_batch + + +print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 
'dr(0,np,np)', 'np'] \n", + find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', + 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']])) + + +# endregion + +# region get atoms and polarities with GOAL + +def get_GOAL(max_len_sentence, df_axiom_links): + categories_batch = df_axiom_links["Z"] + categories_with_goal = df_axiom_links["Y"] + polarities = find_pos_neg_idexes(categories_batch) + atoms_batch = get_atoms_batch(categories_batch) + num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence) + for s_idx in range(len(atoms_batch)): + goal = categories_with_goal[s_idx][-1] + polarities_goal = category_to_atoms_polarity(goal, True) + goal = re.search(r"(\w+)_\d+", goal).groups()[0] + atoms = category_to_atoms(goal, []) + + atoms_batch[s_idx] = atoms + atoms_batch[s_idx] # + ["[SEP]"] + polarities[s_idx] = polarities_goal + polarities[s_idx] # + False + num_atoms_batch[s_idx][0] += len(atoms) # +1 + + return atoms_batch, polarities, num_atoms_batch + + +df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', + 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']], + "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', + 'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', + 'GOAL:np_7']]}) +print(" test for get GOAL ", get_GOAL(10, df_axiom_links)) + + +# endregion + +# region get idx for pos and neg + +def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type): + pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if + bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and + atoms_polarity_batch[s_idx][i]]) + for s_idx, sentence in enumerate(atoms_batch)], + max_len=max_atoms_in_one_type // 2, padding_value=-1) + for atom_type in list(atom_map_redux.keys())] + + return torch.stack(pos_idx).permute(1, 0, 2) + + +def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type): + pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if + bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and + not atoms_polarity_batch[s_idx][i]]) + for s_idx, sentence in enumerate(atoms_batch)], + max_len=max_atoms_in_one_type // 2, padding_value=-1) + for atom_type in list(atom_map_redux.keys())] + + return torch.stack(pos_idx).permute(1, 0, 2) + + +print(" test for cut into pos neg on ['dr(0,s,np)', 's']", + get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']], + torch.as_tensor( + [[True, True, False, False, + True, False, False, False, + False, False, + False, False]]), 10)) + +# endregion \ No newline at end of file diff --git a/README.md b/README.md index d0df77978d55dd3b41298f8129d244e9f4daca7a..991b1d926d03d6721c816131db8d7526f365e387 100644 --- a/README.md +++ b/README.md @@ -17,11 +17,18 @@ Clone the project locally. ### Libraries installation -Run the init.sh script and install the Tagger project under SuperTagger name and the Linker directory in Linker project under Linker name. +Run the following script : -Upload the tagger.pt in models. (You may need to modify 'model_tagger' in train.py.) +```bash +python3 -m venv env +source env/bin/activate +pip install -r requirements.txt -You can upload a linker model, so there is no pretraining, you just need to give it to the Proof net initialization. +mkdir Output +mkdir TensorBoard +``` + +Optional : Upload the tagger.pt and linker.pt in models. 
(You may need to modify 'model_tagger' in train.py.) ### Structure diff --git a/requirements.txt b/requirements.txt index ce8002c1e34cda5a75cd5330dc1e4ab8df659555..18401ec4c66ea340b50e66e30e3ac46d2e191e84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ packaging==21.3 scikit-learn==1.0.2 scipy==1.8.0 sentencepiece==0.1.96 -tensorflow==2.9.1 tensorboard==2.8.0 torch==1.11.0 tqdm==4.64.0
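
For reference, here is a rough usage sketch of the `Linker` API introduced by this diff. The checkpoint path, the example sentence and the rows of the axiom-link DataFrame are illustrative assumptions; only the column names (`X`, `Y`, `Z`), the method signatures and the `Output`/`TensorBoard` directories come from the code and README changes above.

```python
import pandas as pd
from Linker import Linker

# Column layout expected by Linker.__preprocess_data / get_GOAL (names from the code above);
# the rows themselves are made up for illustration, a real dataset would contain many such rows.
df_axiom_links = pd.DataFrame({
    "X": ["un exemple"],                            # raw sentence
    "Y": [["dr(0,np_1,n_2)", "n_2", "GOAL:np_1"]],  # categories indexed with _i, ending with GOAL
    "Z": [["dr(0,np,n)", "n"]],                     # the same categories without indices
})

linker = Linker("models/tagger.pt")  # assumed path to a trained SuperTagger checkpoint

linker.train_linker(df_axiom_links,
                    validation_rate=0.1, epochs=20, batch_size=32,
                    checkpoint=True,    # saves Output/linker.pt after each epoch
                    tensorboard=True)   # writes logs under TensorBoard/

# Inference, either with gold supertags or letting the SuperTagger predict them:
links = linker.predict_with_categories("un exemple", ["dr(0,np,n)", "n"])
categories, links = linker.predict_without_categories("un exemple")
```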