From a539008d26b6bff4ae5b3a51ddd03df1517342b3 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Wed, 22 Jun 2022 17:00:51 +0200
Subject: [PATCH 1/6] change padding handling

---
 Configuration/config.ini                      |   5 ++-
 Linker/Linker.py                              |  32 ++++++-------
 Linker/eval.py                                |   7 +--
 Linker/utils_linker.py                        |  42 +++++++++++-------
 bash_GPU.sh                                   |   2 +-
 ...tfevents.1655740922.co2-slurm-ng04.19806.1 | Bin 1142 -> 0 bytes
 ...tfevents.1655740922.co2-slurm-ng04.19806.3 | Bin 1142 -> 0 bytes
 ...tfevents.1655740922.co2-slurm-ng04.19806.2 | Bin 1046 -> 0 bytes
 ...tfevents.1655740922.co2-slurm-ng04.19806.4 | Bin 1046 -> 0 bytes
 ...tfevents.1655739317.co2-slurm-ng04.19806.0 | Bin 40 -> 0 bytes
 10 files changed, 50 insertions(+), 38 deletions(-)
 delete mode 100644 logs/logs/Accuracy_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.1
 delete mode 100644 logs/logs/Accuracy_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.3
 delete mode 100644 logs/logs/Loss_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.2
 delete mode 100644 logs/logs/Loss_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.4
 delete mode 100644 logs/logs/events.out.tfevents.1655739317.co2-slurm-ng04.19806.0

diff --git a/Configuration/config.ini b/Configuration/config.ini
index c3ccbc2..61872f4 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -5,15 +5,16 @@ transformers = 4.16.2
 symbols_vocab_size=26
 atom_vocab_size=18
 max_len_sentence=290
-max_atoms_in_sentence=874
+max_atoms_in_sentence=875
 max_atoms_in_one_type=324
 
 [MODEL_ENCODER]
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=4
+nhead=8
 dim_emb_atom = 256
+dim_feedforward_transformer = 768
 num_layers=2
 dim_cat_inter=512
 dim_cat_out=256
diff --git a/Linker/Linker.py b/Linker/Linker.py
index ee88425..15e775c 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -69,6 +69,7 @@ class Linker(Module):
         # Transformer
         self.nhead = int(Configuration.modelLinkerConfig['nhead'])
         self.dim_emb_atom = int(Configuration.modelLinkerConfig['dim_emb_atom'])
+        self.dim_feedforward_transformer = int(Configuration.modelLinkerConfig['dim_feedforward_transformer'])
         self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
         # torch cat
         self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_out'])
@@ -78,7 +79,6 @@ class Linker(Module):
         # sinkhorn
         self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
         # settings
-        self.batch_size = int(Configuration.modelTrainingConfig['batch_size'])
         self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
         self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
         self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
@@ -95,11 +95,13 @@ class Linker(Module):
         # Atoms embedding
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.atom_map_redux = atom_map_redux
+        self.padding_id = atom_map["[PAD]"]
         self.sub_atoms_type_list = list(atom_map_redux.keys())
-        self.atom_encoder = Embedding(self.max_atoms_in_sentence, self.dim_emb_atom, padding_idx=atom_map["[PAD]"])
+        self.atom_encoder = Embedding(atom_vocab_size, self.dim_emb_atom, padding_idx=self.padding_id)
         self.atom_encoder.weight.data.uniform_(-0.1, 0.1)
         self.position_encoder = PositionalEncoding(self.dim_emb_atom, 0.1, max_len=self.max_atoms_in_sentence)
-        encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead)
+        encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead,
+                                                dim_feedforward=self.dim_feedforward_transformer, dropout=0.1)
         self.transformer = TransformerEncoder(encoder_layer, num_layers=self.num_layers)
 
         # Concatenation with word embedding
@@ -146,8 +148,8 @@ class Linker(Module):
 
         num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
 
-        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
-        neg_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
+        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
+        neg_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
 
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                             df_axiom_links["Y"])
@@ -170,12 +172,11 @@ class Linker(Module):
         print("End preprocess Data")
         return training_dataloader, validation_dataloader
 
-    def forward(self, batch_num_atoms_per_word, batch_atoms, src_mask, batch_pos_idx, batch_neg_idx, sents_embedding):
+    def forward(self, batch_num_atoms_per_word, batch_atoms, batch_pos_idx, batch_neg_idx, sents_embedding):
         r"""
         Args:
             batch_num_atoms_per_word : (batch_size, len_sentence) flattened categories
             batch_atoms : atoms tok
-            src_mask : atoms mask
             batch_pos_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
             batch_neg_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
             sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
@@ -187,10 +188,14 @@ class Linker(Module):
             [torch.repeat_interleave(input=sents_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
              for i in range(len(sents_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
 
+        # atoms embedding
+        src_key_padding_mask = torch.eq(batch_atoms, self.padding_id)
+        src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
         atoms_embedding = self.atom_encoder(batch_atoms) * math.sqrt(self.dim_emb_atom)
         atoms_embedding = self.position_encoder(atoms_embedding)
         atoms_embedding = atoms_embedding.permute(1, 0, 2)
-        atoms_embedding = self.transformer(atoms_embedding, src_mask)
+        atoms_embedding = self.transformer(atoms_embedding, src_mask,
+                                           src_key_padding_mask=src_key_padding_mask)
         atoms_embedding = atoms_embedding.permute(1, 0, 2)
 
         # cat
@@ -280,7 +285,6 @@ class Linker(Module):
 
         # For each batch of training data...
         with tqdm(training_dataloader, unit="batch") as tepoch:
-            src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
             for batch in tepoch:
                 # Unpack this training batch from our dataloader
                 batch_num_atoms = batch[0].to(self.device)
@@ -297,10 +301,10 @@ class Linker(Module):
                 output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
                 # Run the Linker on the atoms
-                logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx,
+                logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx,
                                           output['word_embeding'])
 
-                linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
+                linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
                 # Perform a backward pass to calculate the gradients.
                 epoch_loss += float(linker_loss)
                 linker_loss.backward()
@@ -334,19 +338,17 @@ class Linker(Module):
 
                 output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
-                src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
-                logits_predictions = self(batch_num_atoms, batch_atoms_tok, src_mask, batch_pos_idx, batch_neg_idx, output[
+                logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[
                     'word_embeding'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
                 axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # atom_vocab, batch_size, max atoms in one type
 
                 print('\n')
-                print("Sentence tokens: ", batch_sentences_tokens[1])
                 print("True links for category n: ", batch_true_links[1][2][:100])
                 print("Predictions: ", axiom_links_pred[2][1][:100])
                 print('\n')
 
                 accuracy = mesure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
-                loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
+                loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
 
         return loss, accuracy
diff --git a/Linker/eval.py b/Linker/eval.py
index 2c8c578..05c0966 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -1,14 +1,15 @@
 import torch
 from torch.nn import Module
 from torch.nn.functional import nll_loss
+from Linker.atom_map import atom_map, atom_map_redux
 
 
 class SinkhornLoss(Module):
     def __init__(self):
         super(SinkhornLoss, self).__init__()
 
-    def forward(self, predictions, truths):
-        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
+    def forward(self, predictions, truths, max_atoms_in_one_type):
+        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
                    for link, perm in zip(predictions, truths.permute(1, 0, 2)))
 
 
@@ -17,7 +18,7 @@ def mesure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
         batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
         axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
-    padding = max_atoms_in_one_type // 2 - 1
+    padding = -1
     batch_true_links = batch_true_links.permute(1, 0, 2)
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index f2f418f..8bb55d1 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -45,18 +45,18 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
     for atom_type in list(atom_map_redux.keys()):
         # filter atoms_batch down to this atom type, then split it into plus/minus using the polarity indices
         l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i]
-                            and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
-                           range(len(atoms_batch))]
+                            and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx
+                           in range(len(atoms_batch))]
         l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i]
-                             and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx in
-                            range(len(atoms_batch))]
+                             and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx
+                            in range(len(atoms_batch))]
 
         linking_plus_to_minus = pad_sequence(
             [torch.as_tensor(
-                [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else max_atoms_in_one_type // 2 - 1
+                [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1
                  for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
                 for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
-            padding_value=max_atoms_in_one_type // 2 - 1)
+            padding_value=-1)
 
         linking_plus_to_minus_all_types.append(linking_plus_to_minus)
 
@@ -108,8 +108,12 @@ def get_atoms_links_batch(category_batch):
 
 
 print("test to create links ",
-      get_axiom_links(20, torch.stack([torch.as_tensor([False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False, False, True, False, True, False, False, True, False, False, False, True])]),
-      [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
+      get_axiom_links(20, torch.stack([torch.as_tensor(
+          [False, True, False, False, False, True, False, True, False, False, True, False, False, False, True, False,
+           False, True, False, True, False, False, True, False, False, False, True])]),
+                      [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)',
+                        'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
+
 
 # endregion
@@ -305,8 +309,10 @@ def find_pos_neg_idexes(atoms_batch):
     return list_batch
 
 
-print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
-      find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
+print(
+    " test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']",
+    find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
+                          'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
 
 # endregion
@@ -349,11 +355,12 @@ print(" test for get GOAL on ['dr(0,s,np)', 's']", get_GOAL(12, [["dr(0,s,np)",
 
 # region get idx for pos and neg
 
-def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
+def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
     atoms_batch_for_polarities = list(
         map(lambda sentence: sentence.split(" "), atoms_batch))
-    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
-        re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and
+    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
+                                                            atoms_batch_for_polarities[s_idx][i])) and
                                               atoms_polarity_batch[s_idx][i]])
                              for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
@@ -362,11 +369,12 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
 
     return torch.stack(pos_idx).permute(1, 0, 2)
 
 
-def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
+def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
     atoms_batch_for_polarities = list(
         map(lambda sentence: sentence.split(" "), atoms_batch))
-    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(
-        re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch_for_polarities[s_idx][i])) and not
+    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z",
+                                                            atoms_batch_for_polarities[s_idx][i])) and not
                                               atoms_polarity_batch[s_idx][i]])
                              for s_idx, sentence in enumerate(atoms_batch_for_polarities)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
@@ -380,6 +388,6 @@ print(" test for cut into pos neg on ['s np [SEP] s [SEP] np s s n n']", get_neg
 
                                                    [[False, True, False, False, False, False, True, True, False, True,
-                                                     False, False]]), 10))
+                                                     False, False]]), 10, 50))
 
 # endregion
diff --git a/bash_GPU.sh b/bash_GPU.sh
index 500c732..9969220 100644
--- a/bash_GPU.sh
+++ b/bash_GPU.sh
@@ -1,6 +1,6 @@
 #!/bin/sh
 #SBATCH --job-name=Deepgrail_Linker
-#SBATCH --partition=GPUNodes
+#SBATCH --partition=RTX6000Node
 #SBATCH --gres=gpu:1
 #SBATCH --mem=32000
 #SBATCH --gres-flags=enforce-binding
diff --git a/logs/logs/Accuracy_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.1 b/logs/logs/Accuracy_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.1
deleted file mode 100644
index 09582fa3c93ec31068f014c18ba3bd65ffe59935..0000000000000000000000000000000000000000
GIT binary patch
[binary delta omitted]
diff --git a/logs/logs/Accuracy_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.3 b/logs/logs/Accuracy_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.3
deleted file mode 100644
index 5442af218aa6bb356f2080bc5dd84cdf9ecfb6cc..0000000000000000000000000000000000000000
GIT binary patch
[binary delta omitted]
diff --git a/logs/logs/Loss_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.2 b/logs/logs/Loss_Train/events.out.tfevents.1655740922.co2-slurm-ng04.19806.2
deleted file mode 100644
index c9b5e3bf3131b8fa559236aa17551dc3f0055839..0000000000000000000000000000000000000000
GIT binary patch
[binary delta omitted]
diff --git a/logs/logs/Loss_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.4 b/logs/logs/Loss_Validation/events.out.tfevents.1655740922.co2-slurm-ng04.19806.4
deleted file mode 100644
index c3af9b258413271b78351acffec0f41824f7cec3..0000000000000000000000000000000000000000
GIT binary patch
[binary delta omitted]
diff --git a/logs/logs/events.out.tfevents.1655739317.co2-slurm-ng04.19806.0 b/logs/logs/events.out.tfevents.1655739317.co2-slurm-ng04.19806.0
deleted file mode 100644
index d8f3c0aa95cc484017700deac105cacffc3cb025..0000000000000000000000000000000000000000
GIT binary patch
[binary delta omitted]
-- 
GitLab
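
[Editor's note on PATCH 1/6] The padding rework combines two conventions: the atom transformer now receives a
src_key_padding_mask marking [PAD] tokens, and padded link targets are encoded as -1 so that the Sinkhorn loss
(nll_loss with ignore_index=-1) and the accuracy both skip them. A minimal self-contained sketch of the same
pattern, assuming a recent PyTorch; the ids, shapes and names below are illustrative, not the repository's
exact values:

    import math
    import torch
    from torch.nn import Embedding, TransformerEncoder, TransformerEncoderLayer
    from torch.nn.functional import nll_loss

    PAD_ID = 17                                              # stand-in for atom_map["[PAD]"]
    batch_atoms = torch.tensor([[3, 5, 9, PAD_ID, PAD_ID]])  # (batch, seq)

    atom_encoder = Embedding(18, 256, padding_idx=PAD_ID)
    layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=768, dropout=0.1)
    transformer = TransformerEncoder(layer, num_layers=2)

    # True where the token is padding; attention ignores those key positions.
    src_key_padding_mask = torch.eq(batch_atoms, PAD_ID)     # (batch, seq)

    x = atom_encoder(batch_atoms) * math.sqrt(256)
    x = x.permute(1, 0, 2)                                   # (seq, batch, dim): batch_first=False
    out = transformer(x, src_key_padding_mask=src_key_padding_mask)

    # Loss side: targets padded with -1 contribute nothing to the mean.
    log_probs = torch.log_softmax(torch.randn(1, 4, 4), dim=-1)
    targets = torch.tensor([[2, 0, -1, -1]])                 # -1 marks a padded link
    loss = nll_loss(log_probs.flatten(0, 1), targets.flatten(), ignore_index=-1)

The sketch leaves out the square subsequent mask that the patch still passes as src_mask; the key-padding
mask and the -1 target padding are the new parts.
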
From d1c8b81386fcf1e578c6db288af0ee5f2b335990 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Thu, 23 Jun 2022 17:30:00 +0200
Subject: [PATCH 2/6] change padding handling

---
 Configuration/config.ini     |  6 ++--
 Linker/DataParallelLinker.py | 64 ++++++++++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+), 3 deletions(-)
 create mode 100644 Linker/DataParallelLinker.py

diff --git a/Configuration/config.ini b/Configuration/config.ini
index 61872f4..b33d6df 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=18
-max_len_sentence=290
+max_len_sentence=83
 max_atoms_in_sentence=875
 max_atoms_in_one_type=324
 
@@ -12,10 +12,10 @@ max_atoms_in_one_type=324
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=8
+nhead=16
 dim_emb_atom = 256
 dim_feedforward_transformer = 768
-num_layers=2
+num_layers=3
 dim_cat_inter=512
 dim_cat_out=256
diff --git a/Linker/DataParallelLinker.py b/Linker/DataParallelLinker.py
new file mode 100644
index 0000000..5885845
--- /dev/null
+++ b/Linker/DataParallelLinker.py
@@ -0,0 +1,64 @@
+import datetime
+
+from torch.nn import DataParallel, Module
+from Linker import *
+
+
+class DataParallelModel(Module):
+
+    def __init__(self):
+        super().__init__()
+        self.linker = DataParallel(Linker("models/flaubert_super_98_V2_50e.pt"))
+
+    def forward(self, x):
+        x = self.linker(x)
+        return x
+
+    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
+                     batch_size=32, checkpoint=True, tensorboard=False):
+        r"""
+        Args:
+            df_axiom_links : pandas dataFrame containing the atoms annotated with _i
+            validation_rate : float
+            epochs : int
+            batch_size : int
+            checkpoint : boolean
+            tensorboard : boolean
+        Returns:
+            Final accuracy and final loss
+        """
+        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
+                                                                            validation_rate)
+        if checkpoint or tensorboard:
+            checkpoint_dir, writer = output_create_dir()
+
+        for epoch_i in range(epochs):
+            print("")
+            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
+            print('Training...')
+            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
+
+            print("")
+            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
+            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
+
+            if validation_rate > 0.0:
+                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
+                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
+
+            if checkpoint:
+                self.__checkpoint_save(
+                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
+
+            if tensorboard:
+                writer.add_scalars(f'Accuracy', {
+                    'Train': avg_accuracy_train}, epoch_i)
+                writer.add_scalars(f'Loss', {
+                    'Train': avg_train_loss}, epoch_i)
+                if validation_rate > 0.0:
+                    writer.add_scalars(f'Accuracy', {
+                        'Validation': accuracy_test}, epoch_i)
+                    writer.add_scalars(f'Loss', {
+                        'Validation': loss_test}, epoch_i)
+
+        print('\n')
\ No newline at end of file
-- 
GitLab
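
[Editor's note on PATCH 2/6] DataParallelModel wraps the Linker in torch.nn.DataParallel, which replicates the
module, scatters each batch across the visible GPUs, and gathers the outputs on the first device; the file is
removed again in PATCH 3/6. A minimal sketch of the wrapping, with a stand-in module instead of Linker:

    import torch
    from torch import nn

    model = nn.Linear(256, 256)         # stand-in for Linker("models/flaubert_super_98_V2_50e.pt")
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # forward() calls are now split across GPUs
    y = model(torch.randn(32, 256))

One caveat with this design: DataParallel only parallelises forward() and does not expose custom methods of the
wrapped module, so train_linker has to be re-implemented on the wrapper rather than delegated.
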
From 8c037d89c1eacd2f4b226b5f4f8b45a135ba6878 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Fri, 24 Jun 2022 09:07:35 +0200
Subject: [PATCH 3/6] correction

---
 Configuration/config.ini     |  6 ++--
 Linker/DataParallelLinker.py | 64 ------------------------------------
 Linker/Linker.py             |  4 +--
 3 files changed, 5 insertions(+), 69 deletions(-)
 delete mode 100644 Linker/DataParallelLinker.py

diff --git a/Configuration/config.ini b/Configuration/config.ini
index b33d6df..cd5dbae 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=18
-max_len_sentence=83
+max_len_sentence=290
 max_atoms_in_sentence=875
 max_atoms_in_one_type=324
 
@@ -12,7 +12,7 @@ max_atoms_in_one_type=324
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=16
+nhead=8
 dim_emb_atom = 256
 dim_feedforward_transformer = 768
 num_layers=3
@@ -25,6 +25,6 @@ sinkhorn_iters=5
 
 [MODEL_TRAINING]
 batch_size=32
-epoch=25
+epoch=30
 seed_val=42
 learning_rate=2e-3
\ No newline at end of file
diff --git a/Linker/DataParallelLinker.py b/Linker/DataParallelLinker.py
deleted file mode 100644
index 5885845..0000000
--- a/Linker/DataParallelLinker.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import datetime
-
-from torch.nn import DataParallel, Module
-from Linker import *
-
-
-class DataParallelModel(Module):
-
-    def __init__(self):
-        super().__init__()
-        self.linker = DataParallel(Linker("models/flaubert_super_98_V2_50e.pt"))
-
-    def forward(self, x):
-        x = self.linker(x)
-        return x
-
-    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
-                     batch_size=32, checkpoint=True, tensorboard=False):
-        r"""
-        Args:
-            df_axiom_links : pandas dataFrame containing the atoms annotated with _i
-            validation_rate : float
-            epochs : int
-            batch_size : int
-            checkpoint : boolean
-            tensorboard : boolean
-        Returns:
-            Final accuracy and final loss
-        """
-        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
-                                                                            validation_rate)
-        if checkpoint or tensorboard:
-            checkpoint_dir, writer = output_create_dir()
-
-        for epoch_i in range(epochs):
-            print("")
-            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
-            print('Training...')
-            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
-
-            print("")
-            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
-            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
-
-            if validation_rate > 0.0:
-                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
-                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
-
-            if checkpoint:
-                self.__checkpoint_save(
-                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
-
-            if tensorboard:
-                writer.add_scalars(f'Accuracy', {
-                    'Train': avg_accuracy_train}, epoch_i)
-                writer.add_scalars(f'Loss', {
-                    'Train': avg_train_loss}, epoch_i)
-                if validation_rate > 0.0:
-                    writer.add_scalars(f'Accuracy', {
-                        'Validation': accuracy_test}, epoch_i)
-                    writer.add_scalars(f'Loss', {
-                        'Validation': loss_test}, epoch_i)
-
-        print('\n')
\ No newline at end of file
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 15e775c..3012c7e 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -22,7 +22,7 @@ from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
 from Supertagger import SuperTagger
 from utils import pad_sequence
 
@@ -149,7 +149,7 @@ class Linker(Module):
         num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
 
         pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
-        neg_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
+        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
 
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                             df_axiom_links["Y"])
-- 
GitLab

From 808c2aa64b72ec556aa28fcc66853351a98fac34 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.home>
Date: Fri, 24 Jun 2022 10:18:24 +0200
Subject: [PATCH 4/6] test

---
 Configuration/config.ini | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Configuration/config.ini b/Configuration/config.ini
index cd5dbae..ae20d94 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -14,7 +14,7 @@ dim_encoder = 768
 [MODEL_LINKER]
 nhead=8
 dim_emb_atom = 256
-dim_feedforward_transformer = 768
+dim_feedforward_transformer = 512
 num_layers=3
 dim_cat_inter=512
 dim_cat_out=256
-- 
GitLab
From c58510004e59c2d710679298948c78d23c067883 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.home>
Date: Fri, 24 Jun 2022 14:10:03 +0200
Subject: [PATCH 5/6] test

---
 Configuration/config.ini                         |   4 ++--
 ...tfevents.1656000220.co2-slurm-ngrtx01.97600.1 | Bin 0 -> 1142 bytes
 ...tfevents.1656000220.co2-slurm-ngrtx01.97600.3 | Bin 0 -> 1142 bytes
 ...tfevents.1656000220.co2-slurm-ngrtx01.97600.2 | Bin 0 -> 1046 bytes
 ...tfevents.1656000220.co2-slurm-ngrtx01.97600.4 | Bin 0 -> 1046 bytes
 ...tfevents.1655998955.co2-slurm-ngrtx01.97600.0 | Bin 0 -> 40 bytes
 6 files changed, 2 insertions(+), 2 deletions(-)
 create mode 100644 logs/logs/Accuracy_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.1
 create mode 100644 logs/logs/Accuracy_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.3
 create mode 100644 logs/logs/Loss_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.2
 create mode 100644 logs/logs/Loss_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.4
 create mode 100644 logs/logs/events.out.tfevents.1655998955.co2-slurm-ngrtx01.97600.0

diff --git a/Configuration/config.ini b/Configuration/config.ini
index ae20d94..fedb317 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -12,7 +12,7 @@ max_atoms_in_one_type=324
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=8
+nhead=16
 dim_emb_atom = 256
 dim_feedforward_transformer = 512
 num_layers=3
@@ -24,7 +24,7 @@ dropout=0.1
 sinkhorn_iters=5
 
 [MODEL_TRAINING]
-batch_size=32
+batch_size=16
 epoch=30
 seed_val=42
 learning_rate=2e-3
\ No newline at end of file
diff --git a/logs/logs/Accuracy_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.1 b/logs/logs/Accuracy_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.1
new file mode 100644
index 0000000000000000000000000000000000000000..836da89cfb7f563cf1ff1108b2eed3dd66ef9ff8
GIT binary patch
[binary delta omitted]
diff --git a/logs/logs/Accuracy_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.3 b/logs/logs/Accuracy_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.3
new file mode 100644
index 0000000000000000000000000000000000000000..1945341e16f8ba0d0acfb9bd7fa891e2b53b9552
GIT binary patch
[binary delta omitted]
diff --git a/logs/logs/Loss_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.2 b/logs/logs/Loss_Train/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.2
new file mode 100644
index 0000000000000000000000000000000000000000..89acad93d838280f7f49a8320892c5c3d1b5440c
GIT binary patch
[binary delta omitted]
diff --git a/logs/logs/Loss_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.4 b/logs/logs/Loss_Validation/events.out.tfevents.1656000220.co2-slurm-ngrtx01.97600.4
new file mode 100644
index 0000000000000000000000000000000000000000..27f3fb8fb1adb2f8cb3f308002e5aceb72ce8ff3
GIT binary patch
[binary delta omitted]
diff --git a/logs/logs/events.out.tfevents.1655998955.co2-slurm-ngrtx01.97600.0 b/logs/logs/events.out.tfevents.1655998955.co2-slurm-ngrtx01.97600.0
new file mode 100644
index 0000000000000000000000000000000000000000..b701efbbdba2cec059713e10b820c830a690ea70
GIT binary patch
[binary delta omitted]
-- 
GitLab

From 503508f5ce45d98d6e974a8735143be6b2a11eb4 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.home>
Date: Fri, 24 Jun 2022 14:59:32 +0200
Subject: [PATCH 6/6] test

---
 train.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 create mode 100644 train.py

diff --git a/train.py b/train.py
new file mode 100644
index 0000000..fdf3936
--- /dev/null
+++ b/train.py
@@ -0,0 +1,17 @@
+import torch
+from Configuration import Configuration
+from Linker import *
+from utils import read_csv_pgbar
+
+torch.cuda.empty_cache()
+batch_size = int(Configuration.modelTrainingConfig['batch_size'])
+nb_sentences = batch_size * 800
+epochs = int(Configuration.modelTrainingConfig['epoch'])
+
+file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
+df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
+
+print("Linker")
+linker = Linker("models/flaubert_super_98_V2_50e.pt")
+print("\nLinker Training\n")
+linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=False, tensorboard=True)
\ No newline at end of file
-- 
GitLab
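
[Editor's note on the series] The -1 padding convention introduced in PATCH 1/6 is what mesure_accuracy in
Linker/eval.py relies on when it sets padding = -1: positions whose true link target is the pad value must not
count toward accuracy. A short sketch of padding-aware link accuracy under that convention (the function and
variable names here are illustrative, not the repository's exact implementation):

    import torch

    def link_accuracy(axiom_links_pred, batch_true_links, pad=-1):
        # Only positions whose true target is not the pad value count.
        mask = batch_true_links.ne(pad)
        correct = (axiom_links_pred == batch_true_links) & mask
        return correct.sum().item() / max(mask.sum().item(), 1)
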