From df0c26ca1906cca165adda762cea2b78e6b20b71 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Fri, 13 May 2022 16:53:29 +0200
Subject: [PATCH] save and restore optimizer state in Linker checkpoints

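Persist the optimizer state in the checkpoint alongside the model
weights, so that training can resume with momentum buffers and
learning-rate state intact. Both directions go through state_dict():
the load path feeds the stored dict to load_state_dict(), and the
save path stores self.optimizer.state_dict() rather than the
optimizer object itself, which load_state_dict() could not consume.

A minimal sketch of the state_dict round-trip this relies on; the
model, the AdamW choice, and the file name are illustrative, not
part of the Linker code:

    import torch
    from torch.nn import Linear, Module
    from torch.optim import AdamW

    class SimpleNet(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 2)

        def forward(self, x):
            return self.layer(x)

    model = SimpleNet()
    optimizer = AdamW(model.parameters(), lr=1e-3)

    # save: call state_dict() on both objects, never pickle the
    # optimizer itself, so the checkpoint stays a plain dict
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict()}, 'checkpoint.pt')

    # load: load_state_dict() expects the dict produced by state_dict()
    params = torch.load('checkpoint.pt')
    model.load_state_dict(params['model'])
    optimizer.load_state_dict(params['optimizer'])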
---
 Linker/Linker.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/Linker/Linker.py b/Linker/Linker.py
index ca3a8dc..4eb4f75 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -47,7 +47,6 @@ class Linker(Module):
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
 
-        # to do : definit un encoding
         self.linker_encoder = AttentionDecoderLayer()
 
         self.pos_transformation = Sequential(
@@ -306,6 +305,7 @@ class Linker(Module):
             self.linker_encoder.load_state_dict(params['linker_encoder'])
             self.pos_transformation.load_state_dict(params['pos_transformation'])
             self.neg_transformation.load_state_dict(params['neg_transformation'])
+            self.optimizer.load_state_dict(params['optimizer'])
             print("\n The loading checkpoint was successful ! \n")
         except Exception as e:
             print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
@@ -320,6 +320,7 @@ class Linker(Module):
             'atoms_embedding': self.atoms_embedding.state_dict(),
             'linker_encoder': self.linker_encoder.state_dict(),
             'pos_transformation': self.pos_transformation.state_dict(),
-            'neg_transformation': self.neg_transformation.state_dict()
+            'neg_transformation': self.neg_transformation.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
         }, path)
         self.linker.to(self.device)
-- 
GitLab