Skip to content
Snippets Groups Projects
Commit df0c26ca authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

adding training methods

parent fd524d4e
Branches
Tags
2 merge requests!6Linker with transformer,!5Linker with transformer
......@@ -47,7 +47,6 @@ class Linker(Module):
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
# to do : definit un encoding
self.linker_encoder = AttentionDecoderLayer()
self.pos_transformation = Sequential(
......@@ -306,6 +305,7 @@ class Linker(Module):
self.linker_encoder.load_state_dict(params['linker_encoder'])
self.pos_transformation.load_state_dict(params['pos_transformation'])
self.neg_transformation.load_state_dict(params['neg_transformation'])
self.optimizer.load_state_dict(params['optimizer'])
print("\n The loading checkpoint was successful ! \n")
except Exception as e:
print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
......@@ -320,6 +320,7 @@ class Linker(Module):
'atoms_embedding': self.atoms_embedding.state_dict(),
'linker_encoder': self.linker_encoder.state_dict(),
'pos_transformation': self.pos_transformation.state_dict(),
'neg_transformation': self.neg_transformation.state_dict()
'neg_transformation': self.neg_transformation.state_dict(),
'optimizer': self.optimizer,
}, path)
self.linker.to(self.device)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment