Commit b0e85309 authored by Caroline DE POURTALES

update for osirim

parent e115a712
2 merge requests: !6 Linker with transformer, !5 Linker with transformer
File deleted
@@ -65,6 +65,8 @@ class Linker(Module):
                                                          num_warmup_steps=0,
                                                          num_training_steps=100)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     def __preprocess_data(self, batch_size, df_axiom_links, sentences_tokens, sentences_mask, validation_rate=0.0):
         r"""
         Args:
@@ -133,18 +135,18 @@ class Linker(Module):
                                                       atoms_batch_tokenized[s_idx][i] == self.atom_map[
                                                           atom_type] and
                                                       atoms_polarity_batch[s_idx][i])] + [
-                                                      torch.zeros(self.dim_embedding_atoms)])
+                                                      torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
                                          for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                                         max_len=self.max_atoms_in_one_type // 2)
+                                         max_len=self.max_atoms_in_one_type // 2).to(self.device)
         neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
                                                    if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
                                                        atoms_batch_tokenized[s_idx][i] == self.atom_map[
                                                            atom_type] and
                                                        not atoms_polarity_batch[s_idx][i])] + [
-                                                      torch.zeros(self.dim_embedding_atoms)])
+                                                      torch.zeros(self.dim_embedding_atoms, device=self.device)]).to(self.device)
                                          for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
-                                         max_len=self.max_atoms_in_one_type // 2)
+                                         max_len=self.max_atoms_in_one_type // 2).to(self.device)

         pos_encoding = self.pos_transformation(pos_encoding)
         neg_encoding = self.neg_transformation(neg_encoding)
@@ -175,6 +177,7 @@ class Linker(Module):
         training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
                                                                             sentences_tokens, sentences_mask,
                                                                             validation_rate)
+        self.to(self.device)

         for epoch_i in range(0, epochs):
             epoch_acc, epoch_loss = self.train_epoch(training_dataloader, validation_dataloader, checkpoint, validate)
@@ -197,16 +200,16 @@ class Linker(Module):
         # For each batch of training data...
         for step, batch in enumerate(training_dataloader):
             # Unpack this training batch from our dataloader
-            batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_sentences_tokens = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-            batch_sentences_mask = batch[4].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_atoms = batch[0].to(self.device)
+            batch_polarity = batch[1].to(self.device)
+            batch_true_links = batch[2].to(self.device)
+            batch_sentences_tokens = batch[3].to(self.device)
+            batch_sentences_mask = batch[4].to(self.device)

             self.optimizer.zero_grad()

             # get sentence embedding from BERT which is already trained
-            logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
+            logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)

             # Run the kinker on the categories predictions
             logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
@@ -273,18 +276,18 @@ class Linker(Module):
                                            atoms_tokenized[s_idx][i] == self.atom_map[
                                                atom_type] and
                                            polarities[s_idx][i])] + [
-                                           torch.zeros(self.dim_embedding_atoms)])
+                                           torch.zeros(self.dim_embedding_atoms, device=self.device)])
                              for s_idx in range(len(polarities))], padding_value=0,
-                             max_len=self.max_atoms_in_one_type // 2)
+                             max_len=self.max_atoms_in_one_type // 2).to(self.device)
         neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
                                                    if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
                                                        atoms_tokenized[s_idx][i] == self.atom_map[
                                                            atom_type] and
                                                        not polarities[s_idx][i])] + [
-                                                      torch.zeros(self.dim_embedding_atoms)])
+                                                      torch.zeros(self.dim_embedding_atoms, device=self.device)])
                                      for s_idx in range(len(polarities))], padding_value=0,
-                                     max_len=self.max_atoms_in_one_type // 2)
+                                     max_len=self.max_atoms_in_one_type // 2).to(self.device)

         pos_encoding = self.pos_transformation(pos_encoding)
         neg_encoding = self.neg_transformation(neg_encoding)
@@ -297,13 +300,13 @@ class Linker(Module):
         return axiom_links

     def eval_batch(self, batch, cross_entropy_loss):
-        batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_sentences_tokens = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_sentences_mask = batch[4].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_atoms = batch[0].to(self.device)
+        batch_polarity = batch[1].to(self.device)
+        batch_true_links = batch[2].to(self.device)
+        batch_sentences_tokens = batch[3].to(self.device)
+        batch_sentences_mask = batch[4].to(self.device)

-        logits, sentences_embedding = self.Supertagger.foward(batch_sentences_tokens, batch_sentences_mask)
+        logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
         logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
                                        batch_sentences_mask)
         axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
@@ -363,4 +366,4 @@ class Linker(Module):
             'neg_transformation': self.neg_transformation.state_dict(),
             'optimizer': self.optimizer,
         }, path)
-        #self.to(self.device)
+        self.to(self.device)
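The changes above replace the repeated `"cuda" if torch.cuda.is_available() else "cpu"` checks with a single `self.device` attribute, and move the module, its intermediate encodings, and every batch tensor to that device. As an aside, here is a minimal, illustrative sketch of that pattern; the class and method names below are placeholders, not the actual Linker API:

    import torch
    from torch.nn import Linear, Module


    class TinyModel(Module):
        def __init__(self):
            super().__init__()
            # Resolve the device once; later code reuses self.device instead of
            # re-checking torch.cuda.is_available() at every call site.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.proj = Linear(8, 2)

        def train_step(self, batch):
            # Move every tensor in the batch to the module's device before the forward pass.
            features = batch[0].to(self.device)
            targets = batch[1].to(self.device)
            return self.proj(features), targets


    model = TinyModel()
    model.to(model.device)  # parameters and buffers follow the stored device
    logits, targets = model.train_step((torch.randn(4, 8), torch.zeros(4, dtype=torch.long)))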
@@ -27,9 +27,8 @@ class AttentionDecoderLayer(Module):
             attention and feedforward operations, respectivaly. Otherwise it's done after.
             Default: ``False`` (after).
     """
-    __constants__ = ['batch_first', 'norm_first']

-    def __init__(self) -> None:
+    def __init__(self):
         super(AttentionDecoderLayer, self).__init__()

         # init params
@@ -42,18 +41,17 @@ class AttentionDecoderLayer(Module):
         # layers
         self.dropout = Dropout(dropout)
-        self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout, batch_first=True,
+        self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout,
                                             kdim=dim_decoder, vdim=dim_decoder)
         self.norm1 = LayerNorm(dim_decoder, eps=layer_norm_eps)
         self.multihead_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout,
-                                                 kdim=dim_encoder, vdim=dim_encoder,
-                                                 batch_first=True)
+                                                 kdim=dim_encoder, vdim=dim_encoder)
         self.norm2 = LayerNorm(dim_decoder, eps=layer_norm_eps)
         self.ffn = FFN(d_model=dim_decoder, d_ff=dim_feedforward, dropout=dropout)
         self.norm3 = LayerNorm(dim_decoder, eps=layer_norm_eps)

-    def forward(self, atoms_embedding: Tensor, sents_embedding: Tensor, encoder_mask: Tensor,
-                decoder_mask: Tensor) -> Tensor:
+    def forward(self, atoms_embedding, sents_embedding, encoder_mask,
+                decoder_mask):
         r"""Pass the inputs through the decoder layer.

         Args:
@@ -62,24 +60,27 @@ class AttentionDecoderLayer(Module):
             encoder_mask
             decoder_mask
         """
+        atoms_embedding = atoms_embedding.permute(1, 0, 2)
+        sents_embedding = sents_embedding.permute(1, 0, 2)
         x = atoms_embedding
         x = self.norm1(x + self._mask_mha_block(atoms_embedding, decoder_mask))
         x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask))
         x = self.norm3(x + self._ff_block(x))
-        return x
+        return x.permute(1, 0, 2)

     # self-attention block
-    def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor:
+    def _mask_mha_block(self, x, decoder_mask):
         x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0]
         return x

     # multihead attention block
-    def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor:
+    def _mha_block(self, x, sents_embs, encoder_mask):
         x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0]
         return x

     # feed forward block
-    def _ff_block(self, x: Tensor) -> Tensor:
+    def _ff_block(self, x):
         x = self.ffn.forward(x)
         return x
#!/bin/sh
#SBATCH --job-name=Deepgrail_Linker
#SBATCH --partition=RTX6000Node
#SBATCH --gres=gpu:1
#SBATCH --mem=32000
#SBATCH --gres-flags=enforce-binding
#SBATCH --error="error_rtx1.err"
#SBATCH --output="out_rtx1.out"
module purge
module load singularity/3.0.3
srun singularity exec /logiciels/containerCollections/CUDA11/pytorch-NGC-21-03-py3.sif python "train.py"
\ No newline at end of file
 numpy==1.22.2
 transformers==4.16.2
-torch==1.10.2
+torch==1.9.0
 huggingface-hub==0.4.0
 pandas==1.4.1
 sentencepiece
@@ -15,9 +15,12 @@ df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 sentences_batch = df_axiom_links["Sentences"].tolist()

 supertagger = SuperTagger()
 supertagger.load_weights("models/model_supertagger.pt")

 sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)

 print("Linker")
 linker = Linker(supertagger)
+linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 print("Linker Training")
 linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=True, validate=True)