diff --git a/Configuration/config.ini b/Configuration/config.ini
index c3ccbc2d4759512eb96cb3ae2db0bef7a6d5a939..fedb31786ee8e94c9842dda2a24209c43827183e 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -5,16 +5,17 @@ transformers = 4.16.2
 symbols_vocab_size=26
 atom_vocab_size=18
 max_len_sentence=290
-max_atoms_in_sentence=874
+max_atoms_in_sentence=875
 max_atoms_in_one_type=324
 
 [MODEL_ENCODER]
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=4
+nhead=16
 dim_emb_atom = 256
-num_layers=2
+dim_feedforward_transformer = 512
+num_layers=3
 dim_cat_inter=512
 dim_cat_out=256
 dim_intermediate_FFN=128
@@ -23,7 +24,7 @@ dropout=0.1
 sinkhorn_iters=5
 
 [MODEL_TRAINING]
-batch_size=32
-epoch=25
+batch_size=16
+epoch=30
 seed_val=42
 learning_rate=2e-3
\ No newline at end of file
diff --git a/Linker/Linker.py b/Linker/Linker.py
index ee8842514e241cfe0ec47b9cf68730cebac92050..3012c7e1e58f9693d7556acbccae6611e85b2ec0 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -22,7 +22,7 @@ from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
 from Supertagger import SuperTagger
 from utils import pad_sequence
 
@@ -69,6 +69,7 @@ class Linker(Module):
         # Transformer
         self.nhead = int(Configuration.modelLinkerConfig['nhead'])
         self.dim_emb_atom = int(Configuration.modelLinkerConfig['dim_emb_atom'])
+        self.dim_feedforward_transformer = int(Configuration.modelLinkerConfig['dim_feedforward_transformer'])
         self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
         # torch cat
         self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_out'])
@@ -78,7 +79,6 @@ class Linker(Module):
         # sinkhorn
         self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
         # settings
-        self.batch_size = int(Configuration.modelTrainingConfig['batch_size'])
         self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
@@ -95,11 +95,13 @@ class Linker(Module):
         # Atoms embedding
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.atom_map_redux = atom_map_redux
+        self.padding_id = atom_map["[PAD]"]
         self.sub_atoms_type_list = list(atom_map_redux.keys())
-        self.atom_encoder = Embedding(self.max_atoms_in_sentence, self.dim_emb_atom, padding_idx=atom_map["[PAD]"])
+        self.atom_encoder = Embedding(atom_vocab_size, self.dim_emb_atom, padding_idx=self.padding_id)
         self.atom_encoder.weight.data.uniform_(-0.1, 0.1)
         self.position_encoder = PositionalEncoding(self.dim_emb_atom, 0.1, max_len=self.max_atoms_in_sentence)
-        encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead)
+        encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead,
+                                                dim_feedforward=self.dim_feedforward_transformer, dropout=0.1)
         self.transformer = TransformerEncoder(encoder_layer, num_layers=self.num_layers)
 
         # Concatenation with word embedding
@@ -146,8 +148,8 @@ class Linker(Module):
 
         num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
 
-        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
-        neg_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
+        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
+        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
 
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                             df_axiom_links["Y"])
@@ -170,12 +172,11 @@ class Linker(Module):
         print("End preprocess Data")
         return training_dataloader, validation_dataloader
 
-    def forward(self, batch_num_atoms_per_word, batch_atoms, src_mask, batch_pos_idx, batch_neg_idx, sents_embedding):
+    def forward(self, batch_num_atoms_per_word, batch_atoms, batch_pos_idx, batch_neg_idx, sents_embedding):
         r"""
         Args:
             batch_num_atoms_per_word : (batch_size, len_sentence) flattened categories
             batch_atoms : atoms tok
-            src_mask : atoms mask
             batch_pos_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
             batch_neg_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
             sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
@@ -187,10 +188,14 @@ class Linker(Module):
             [torch.repeat_interleave(input=sents_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
              for i in range(len(sents_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
 
+        # atoms emebedding
+        src_key_padding_mask = torch.eq(batch_atoms, self.padding_id)
+        src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
         atoms_embedding = self.atom_encoder(batch_atoms) * math.sqrt(self.dim_emb_atom)
         atoms_embedding = self.position_encoder(atoms_embedding)
         atoms_embedding = atoms_embedding.permute(1, 0, 2)
-        atoms_embedding = self.transformer(atoms_embedding, src_mask)
+        atoms_embedding = self.transformer(atoms_embedding, src_mask,
+                                           src_key_padding_mask=src_key_padding_mask)
         atoms_embedding = atoms_embedding.permute(1, 0, 2)
 
         # cat
@@ -280,7 +285,6 @@ class Linker(Module):
 
         # For each batch of training data...
         with tqdm(training_dataloader, unit="batch") as tepoch:
-            src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
            for batch in tepoch:
                 # Unpack this training batch from our dataloader
                 batch_num_atoms = batch[0].to(self.device)
@@ -297,10 +301,10 @@ class Linker(Module):
                 output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
                 # Run the Linker on the atoms
-                logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx,
+                logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx,
                                           output['word_embeding'])
 
-                linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
+                linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
                 # Perform a backward pass to calculate the gradients.
                 epoch_loss += float(linker_loss)
                 linker_loss.backward()
@@ -334,19 +338,17 @@ class Linker(Module):
 
         output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
-        src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
-        logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx, output[
+        logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[
             'word_embeding'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
         axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # atom_vocab, batch_size, max atoms in one type
 
         print('\n')
-        print("Tokens de la phrase : ", batch_sentences_tokens[1])
         print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
         print("Les prédictions : ", axiom_links_pred[2][1][:100])
         print('\n')
 
         accuracy = mesure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
-        loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
+        loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
 
         return loss, accuracy
diff --git a/Linker/eval.py b/Linker/eval.py
index 2c8c578687bec168d04fd1ee81e0357ec2f1dac2..05c096639ee2d12f9b6fa38f44833067b4169440 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -1,14 +1,15 @@
 import torch
 from torch.nn import Module
 from torch.nn.functional import nll_loss
+from Linker.atom_map import atom_map, atom_map_redux
 
 
 class SinkhornLoss(Module):
     def __init__(self):
         super(SinkhornLoss, self).__init__()
 
-    def forward(self, predictions, truths):
-        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
+    def forward(self, predictions, truths, max_atoms_in_one_type):
+        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
                    for link, perm in zip(predictions, truths.permute(1, 0, 2)))
 
 
@@ -17,7 +18,7 @@ def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
     batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
-    padding = max_atoms_in_one_type // 2 - 1
+    padding = -1
     batch_true_links = batch_true_links.permute(1, 0, 2)
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index f2f418ff079dc8a17d5fb18094535566c75d58ec..8bb55d1673f82ce8863b8177ce537e26249d52cb 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -45,18 +45,18 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
     for atom_type in list(atom_map_redux.keys()):
         # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity
         l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i]
-                            and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
-                           range(len(atoms_batch))]
+                            and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx
+                           in range(len(atoms_batch))]
         l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i]
-                             and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
-                            range(len(atoms_batch))]
+                             and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx
+                            in range(len(atoms_batch))]
 
         linking_plus_to_minus = pad_sequence(
             [torch.as_tensor(
-                [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else max_atoms_in_one_type // 2 - 1
+                [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1
                  for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
                 for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
-            padding_value=max_atoms_in_one_type // 2 - 1)
+            padding_value=-1)
 
         linking_plus_to_minus_all_types.append(linking_plus_to_minus)
@@ -108,8 +108,12 @@ def get_atoms_links_batch(category_batch):
 
 
 print("test to create links ",
-      get_axiom_links(20, torch.stack([torch.as_tensor([False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False, False, True, False, True, False, False, True, False, False, False, True])]),
-      [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
+      get_axiom_links(20, torch.stack([torch.as_tensor(
+          [False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False,
+           False, True, False, True, False, False, True, False, False, False, True])]),
+                      [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)',
+                        'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
+
 
 # endregion
 
@@ -305,8 +309,10 @@ def find_pos_neg_idexes(atoms_batch):
     return list_batch
 
 
-print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
-      find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
+print(
+    " test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
+    find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
+                          'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
 
 # endregion
 
@@ -349,11 +355,12 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)",
 
 # region get idx for pos and neg
 
-def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
+def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
     atoms_batch_for_polarities = list(
         map(lambda sentence: sentence.split(" "), atoms_batch))
-    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
-        re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and
+    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
+                                                            atoms_batch_for_polarities[s_idx][i])) and
                                               atoms_polarity_batch[s_idx][i]])
                              for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
@@ -362,11 +369,12 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
     return torch.stack(pos_idx).permute(1, 0, 2)
 
 
-def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
+def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
     atoms_batch_for_polarities = list(
         map(lambda sentence: sentence.split(" "), atoms_batch))
-    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
-        re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and not
+    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
+                                                            atoms_batch_for_polarities[s_idx][i])) and not
                                               atoms_polarity_batch[s_idx][i]])
                              for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
@@ -380,6 +388,6 @@ print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg
                                             [[False, True, False, False, False, False, True, True, False, True,
-                                              False, False]]), 10))
+                                              False, False]]), 10, 50))
 
 
 # endregion
diff --git a/bash_GPU.sh b/bash_GPU.sh
index 500c7326f767af683a3c0a31e2e0026f3f2e74d3..99692203e0a64519649244caa479801da0500a2a 100644
--- a/bash_GPU.sh
+++ b/bash_GPU.sh
@@ -1,6 +1,6 @@
 #!/bin/sh
 #SBATCH --job-name=Deepgrail_Linker
-#SBATCH --partition=GPUNodes
+#SBATCH --partition=RTX6000Node
 #SBATCH --gres=gpu:1
 #SBATCH --mem=32000
 #SBATCH --gres-flags=enforce-binding
diff --git a/logs/logs/Accuracy_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.1 b/logs/logs/Accuracy_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.1
deleted file mode 100644
index 09582fa3c93ec31068f014c18ba3bd65ffe59935..0000000000000000000000000000000000000000
Binary files a/logs/logs/Accuracy_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.1 and /dev/null differ
diff --git a/logs/logs/Accuracy_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.1 b/logs/logs/Accuracy_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.1
new file mode 100644
index 0000000000000000000000000000000000000000..836da89cfb7f563cf1ff1108b2eed3dd66ef9ff8
Binary files /dev/null and b/logs/logs/Accuracy_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.1 differ
diff --git a/logs/logs/Accuracy_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.3 b/logs/logs/Accuracy_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.3
deleted file mode 100644
index 5442af218aa6bb356f2080bc5dd84cdf9ecfb6cc..0000000000000000000000000000000000000000
Binary files a/logs/logs/Accuracy_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.3 and /dev/null differ
diff --git a/logs/logs/Accuracy_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.3 b/logs/logs/Accuracy_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.3
new file mode 100644
index 0000000000000000000000000000000000000000..1945341e16f8ba0d0acfb9bd7fa891e2b53b9552
Binary files /dev/null and b/logs/logs/Accuracy_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.3 differ
diff --git a/logs/logs/Loss_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.2 b/logs/logs/Loss_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.2
deleted file mode 100644
index c9b5e3bf3131b8fa559236aa17551dc3f0055839..0000000000000000000000000000000000000000
Binary files a/logs/logs/Loss_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.2 and /dev/null differ
diff --git a/logs/logs/Loss_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.2 b/logs/logs/Loss_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.2
new file mode 100644
index 0000000000000000000000000000000000000000..89acad93d838280f7f49a8320892c5c3d1b5440c
Binary files /dev/null and b/logs/logs/Loss_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.2 differ
diff --git a/logs/logs/Loss_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.4 b/logs/logs/Loss_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.4
deleted file mode 100644
index c3af9b258413271b78351acffec0f41824f7cec3..0000000000000000000000000000000000000000
Binary files a/logs/logs/Loss_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.4 and /dev/null differ
diff --git a/logs/logs/Loss_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.4 b/logs/logs/Loss_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.4
new file mode 100644
index 0000000000000000000000000000000000000000..27f3fb8fb1adb2f8cb3f308002e5aceb72ce8ff3
Binary files /dev/null and b/logs/logs/Loss_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.4 differ
diff --git a/logs/logs/events.out.tfevents.1655739317.co2-slurm-ng04.19806.0 b/logs/logs/events.out.tfevents.1655739317.co2-slurm-ng04.19806.0
deleted file mode 100644
index d8f3c0aa95cc484017700deac105cacffc3cb025..0000000000000000000000000000000000000000
Binary files a/logs/logs/events.out.tfevents.1655739317.co2-slurm-ng04.19806.0 and /dev/null differ
diff --git a/logs/logs/events.out.tfevents.1655998955.co2-slurm-ngrtx01.97600.0 b/logs/logs/events.out.tfevents.1655998955.co2-slurm-ngrtx01.97600.0
new file mode 100644
index 0000000000000000000000000000000000000000..b701efbbdba2cec059713e10b820c830a690ea70
Binary files /dev/null and b/logs/logs/events.out.tfevents.1655998955.co2-slurm-ngrtx01.97600.0 differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdf3936a593eaf3ceecc042167694b57caec06d2
--- /dev/null
+++ b/train.py
@@ -0,0 +1,17 @@
+import torch
+from Configuration import Configuration
+from Linker import *
+from utils import read_csv_pgbar
+
+torch.cuda.empty_cache()
+batch_size = int(Configuration.modelTrainingConfig['batch_size'])
+nb_sentences = batch_size * 800
+epochs = int(Configuration.modelTrainingConfig['epoch'])
+
+file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
+df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
+
+print("Linker")
+linker = Linker("models/flaubert_super_98_V2_50e.pt")
+print("\nLinker Training\n")
+linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=False, tensorboard=True)
\ No newline at end of file