diff --git a/Linker/Linker.py b/Linker/Linker.py index ca3a8dcf4b6f00d62a315a2645ee4c59296828f8..4eb4f75347efc497c88897648ddc425571980c60 100644 --- a/Linker/Linker.py +++ b/Linker/Linker.py @@ -47,7 +47,6 @@ class Linker(Module): self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence) self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id) - # to do : definit un encoding self.linker_encoder = AttentionDecoderLayer() self.pos_transformation = Sequential( @@ -306,6 +305,7 @@ class Linker(Module): self.linker_encoder.load_state_dict(params['linker_encoder']) self.pos_transformation.load_state_dict(params['pos_transformation']) self.neg_transformation.load_state_dict(params['neg_transformation']) + self.optimizer.load_state_dict(params['optimizer']) print("\n The loading checkpoint was successful ! \n") except Exception as e: print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr) @@ -320,6 +320,7 @@ class Linker(Module): 'atoms_embedding': self.atoms_embedding.state_dict(), 'linker_encoder': self.linker_encoder.state_dict(), 'pos_transformation': self.pos_transformation.state_dict(), - 'neg_transformation': self.neg_transformation.state_dict() + 'neg_transformation': self.neg_transformation.state_dict(), + 'optimizer': self.optimizer, }, path) self.linker.to(self.device)