From df0c26ca1906cca165adda762cea2b78e6b20b71 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Fri, 13 May 2022 16:53:29 +0200 Subject: [PATCH] adding training methods --- Linker/Linker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Linker/Linker.py b/Linker/Linker.py index ca3a8dc..4eb4f75 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -47,7 +47,6 @@ class Linker(Module): self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence) self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id) - # to do : definit un encoding self.linker_encoder = AttentionDecoderLayer() self.pos_transformation = Sequential( @@ -306,6 +305,7 @@ class Linker(Module): self.linker_encoder.load_state_dict(params['linker_encoder']) self.pos_transformation.load_state_dict(params['pos_transformation']) self.neg_transformation.load_state_dict(params['neg_transformation']) + self.optimizer.load_state_dict(params['optimizer']) print("\n The loading checkpoint was successful ! \n") except Exception as e: print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr) @@ -320,6 +320,7 @@ class Linker(Module): 'atoms_embedding': self.atoms_embedding.state_dict(), 'linker_encoder': self.linker_encoder.state_dict(), 'pos_transformation': self.pos_transformation.state_dict(), - 'neg_transformation': self.neg_transformation.state_dict() + 'neg_transformation': self.neg_transformation.state_dict(), + 'optimizer': self.optimizer, }, path) self.linker.to(self.device) -- GitLab