diff --git a/Configuration/config.ini b/Configuration/config.ini
index b33d6df2ef05dec270bf4a24effef2321c5b83a9..cd5dbae0ed1ccd382ec1bf1c03580cf267829ba6 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -4,7 +4,7 @@ transformers = 4.16.2
 [DATASET_PARAMS]
 symbols_vocab_size=26
 atom_vocab_size=18
-max_len_sentence=83
+max_len_sentence=290
 max_atoms_in_sentence=875
 max_atoms_in_one_type=324
 
@@ -12,7 +12,7 @@ max_atoms_in_one_type=324
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=16
+nhead=8
 dim_emb_atom = 256
 dim_feedforward_transformer = 768
 num_layers=3
@@ -25,6 +25,6 @@ sinkhorn_iters=5
 
 [MODEL_TRAINING]
 batch_size=32
-epoch=25
+epoch=30
 seed_val=42
 learning_rate=2e-3
\ No newline at end of file
diff --git a/Linker/DataParallelLinker.py b/Linker/DataParallelLinker.py
deleted file mode 100644
index 5885845bdb38e7591ae95c42460dfcd2f88df7b1..0000000000000000000000000000000000000000
--- a/Linker/DataParallelLinker.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import datetime
-
-from torch.nn import DataParallel, Module
-from Linker import *
-
-
-class DataParallelModel(Module):
-
-    def __init__(self):
-        super().__init__()
-        self.linker = DataParallel(Linker("models/flaubert_super_98_V2_50e.pt"))
-
-    def forward(self, x):
-        x = self.linker(x)
-        return x
-
-    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
-                     batch_size=32, checkpoint=True, tensorboard=False):
-        r"""
-        Args:
-            df_axiom_links : pandas dataFrame containing the atoms anoted with _i
-            validation_rate : float
-            epochs : int
-            batch_size : int
-            checkpoint : boolean
-            tensorboard : boolean
-        Returns:
-            Final accuracy and final loss
-        """
-        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
-                                                                            validation_rate)
-        if checkpoint or tensorboard:
-            checkpoint_dir, writer = output_create_dir()
-
-        for epoch_i in range(epochs):
-            print("")
-            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
-            print('Training...')
-            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
-
-            print("")
-            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
-            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
-
-            if validation_rate > 0.0:
-                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
-                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
-
-            if checkpoint:
-                self.__checkpoint_save(
-                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
-
-            if tensorboard:
-                writer.add_scalars(f'Accuracy', {
-                    'Train': avg_accuracy_train}, epoch_i)
-                writer.add_scalars(f'Loss', {
-                    'Train': avg_train_loss}, epoch_i)
-                if validation_rate > 0.0:
-                    writer.add_scalars(f'Accuracy', {
-                        'Validation': accuracy_test}, epoch_i)
-                    writer.add_scalars(f'Loss', {
-                        'Validation': loss_test}, epoch_i)
-
-        print('\n')
\ No newline at end of file
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 15e775cc738eeb4e95004202af4d2d6baf30328f..3012c7e1e58f9693d7556acbccae6611e85b2ec0 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -22,7 +22,7 @@ from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
 from Supertagger import SuperTagger
 from utils import pad_sequence
 
@@ -149,7 +149,7 @@ class Linker(Module):
         num_atoms_per_word = get_num_atoms_batch(df_axiom_links["Z"], self.max_len_sentence)
         pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
-        neg_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
+        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["Y"])
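
Note on the Linker.py hunk: before this patch, `neg_idx` was built with `get_pos_idx`, so the linker received the same index set for both polarities. The sketch below illustrates what the one-token fix means; it is not the repo's implementation (the real `get_pos_idx` / `get_neg_idx` live in `Linker/utils_linker.py` and are not shown in this diff), and the helper name, tensor shape, and polarity encoding here are assumptions for illustration only.

```python
# Illustrative sketch only: the real get_pos_idx / get_neg_idx internals are
# not part of this diff. The point of the fix: the linker matches each
# positive atom occurrence to a negative occurrence of the same type, so the
# two index sets must be complementary, not identical.
import torch

def select_by_polarity(atoms_polarity_batch: torch.Tensor, positive: bool):
    """Return, per sentence, the positions holding atoms of one polarity.

    atoms_polarity_batch: bool tensor of shape (batch, max_atoms_in_sentence),
    True where the atom occurrence is positive (hypothetical encoding).
    """
    mask = atoms_polarity_batch if positive else ~atoms_polarity_batch
    return [row.nonzero(as_tuple=True)[0] for row in mask]

polarities = torch.tensor([[True, False, True, False]])
pos_idx = select_by_polarity(polarities, positive=True)   # [tensor([0, 2])]
neg_idx = select_by_polarity(polarities, positive=False)  # [tensor([1, 3])]

# The pre-fix code effectively did:
#   neg_idx = select_by_polarity(polarities, positive=True)  # == pos_idx (bug)
# so every "negative" slot pointed at a positive atom, and the axiom-link
# permutation learned by the Sinkhorn layer was trained against wrong targets.
```

The config.ini changes are value-only and need no code change; note that the new `nhead=8` still divides `dim_emb_atom=256` evenly, as PyTorch's multi-head attention requires.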