Commit 82051bbb authored by Caroline DE POURTALES

training

parent 2408c5e6
@@ -4,7 +4,7 @@ transformers = 4.16.2
[DATASET_PARAMS]
symbols_vocab_size=26
atom_vocab_size=18
-max_len_sentence=290
+max_len_sentence=83
max_atoms_in_sentence=875
max_atoms_in_one_type=324
@@ -12,19 +12,20 @@ max_atoms_in_one_type=324
dim_encoder = 768
[MODEL_LINKER]
-nhead=16
+nhead=8
dim_emb_atom = 256
dim_feedforward_transformer = 512
num_layers=3
+dim_cat_bert_out=768
dim_cat_inter=512
dim_cat_out=256
dim_intermediate_FFN=128
dim_pre_sinkhorn_transfo=64
-dropout=0.1
+dropout=0.15
sinkhorn_iters=5
[MODEL_TRAINING]
-batch_size=16
+batch_size=32
epoch=30
seed_val=42
learning_rate=2e-3
\ No newline at end of file
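For context: this file is an INI configuration whose sections are read in the code below through Configuration.modelLinkerConfig[...] and Configuration.modelTrainingConfig[...]. A minimal sketch of how such a file can be loaded with Python's standard configparser; the file path and the Configuration wrapper itself are not visible in this diff, so both are assumptions here.

from configparser import ConfigParser

config = ConfigParser()
config.read("config.ini")  # hypothetical path; the real file name is not shown in this diff

linker_cfg = config["MODEL_LINKER"]
nhead = int(linker_cfg["nhead"])                           # 8 after this commit
dim_emb_atom = int(linker_cfg["dim_emb_atom"])             # 256
dropout = float(linker_cfg["dropout"])                     # 0.15 after this commit
batch_size = int(config["MODEL_TRAINING"]["batch_size"])   # 32 after this commit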
@@ -66,16 +66,19 @@ class Linker(Module):
dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
# atom settings
atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
+# out bert
+dim_cat_bert_out = int(Configuration.modelLinkerConfig['dim_cat_bert_out'])
# Transformer
-self.nhead = int(Configuration.modelLinkerConfig['nhead'])
+nhead = int(Configuration.modelLinkerConfig['nhead'])
self.dim_emb_atom = int(Configuration.modelLinkerConfig['dim_emb_atom'])
-self.dim_feedforward_transformer = int(Configuration.modelLinkerConfig['dim_feedforward_transformer'])
-self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
+dim_feedforward_transformer = int(Configuration.modelLinkerConfig['dim_feedforward_transformer'])
+num_layers = int(Configuration.modelLinkerConfig['num_layers'])
# torch cat
-self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_out'])
+dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_inter'])
self.dim_cat_out = int(Configuration.modelLinkerConfig['dim_cat_out'])
dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
+dropout = float(Configuration.modelLinkerConfig['dropout'])
# sinkhorn
self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
# settings
@@ -91,8 +94,10 @@ class Linker(Module):
supertagger.load_weights(supertagger_path_model)
self.Supertagger = supertagger
self.Supertagger.model.to(self.device)
-self.word_cat_encoder = Linear(dim_encoder*2,dim_encoder)
+self.word_cat_encoder = Sequential(
+Linear(dim_encoder * 2, dim_cat_bert_out),
+Dropout(dropout),
+LayerNorm(dim_cat_bert_out, eps=1e-8))
# Atoms embedding
self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
@@ -101,27 +106,24 @@ class Linker(Module):
self.sub_atoms_type_list = list(atom_map_redux.keys())
self.atom_encoder = Embedding(atom_vocab_size, self.dim_emb_atom, padding_idx=self.padding_id)
self.atom_encoder.weight.data.uniform_(-0.1, 0.1)
-self.position_encoder = PositionalEncoding(self.dim_emb_atom, 0.1, max_len=self.max_atoms_in_sentence)
-encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead,
-dim_feedforward=self.dim_feedforward_transformer, dropout=0.1)
-self.transformer = TransformerEncoder(encoder_layer, num_layers=self.num_layers)
+self.position_encoder = PositionalEncoding(self.dim_emb_atom, dropout, max_len=self.max_atoms_in_sentence)
+encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=nhead,
+dim_feedforward=dim_feedforward_transformer, dropout=dropout)
+self.transformer = TransformerEncoder(encoder_layer, num_layers=num_layers)
# Concatenation with word embedding
-dim_cat = dim_encoder + self.dim_emb_atom
+dim_cat = dim_cat_bert_out + self.dim_emb_atom
self.linker_encoder = Sequential(
-FFN(dim_cat, self.dim_cat_inter, 0.1, d_out=self.dim_cat_out),
-LayerNorm(self.dim_cat_out, eps=1e-8)
-)
+FFN(dim_cat, dim_cat_inter, dropout, d_out=self.dim_cat_out),
+LayerNorm(self.dim_cat_out, eps=1e-8))
# Division into positive and negative
self.pos_transformation = Sequential(
-FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
-LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
-)
+FFN(self.dim_cat_out, dim_intermediate_FFN, dropout, d_out=dim_pre_sinkhorn_transfo),
+LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8))
self.neg_transformation = Sequential(
-FFN(self.dim_cat_out, dim_intermediate_FFN, 0.1, d_out=dim_pre_sinkhorn_transfo),
-LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
-)
+FFN(self.dim_cat_out, dim_intermediate_FFN, dropout, d_out=dim_pre_sinkhorn_transfo),
+LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8))
# Learning
self.cross_entropy_loss = SinkhornLoss()
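The pos_transformation and neg_transformation heads above project each atom encoding down to dim_pre_sinkhorn_transfo before the Sinkhorn-based matching controlled by sinkhorn_iters=5. The repository's own Sinkhorn routine is not part of this diff; the sketch below only illustrates the idea of iterative row/column normalisation in log space, with hypothetical names.

import torch

def log_sinkhorn(scores: torch.Tensor, n_iters: int = 5) -> torch.Tensor:
    # scores: (batch, n, n) matching scores between positive and negative atoms
    log_alpha = scores
    for _ in range(n_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)  # normalise rows
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)  # normalise columns
    return log_alpha  # approximately doubly stochastic once exponentiated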
@@ -143,6 +145,7 @@ class Linker(Module):
print("Start preprocess Data")
sentences_batch = df_axiom_links["X"].str.strip().tolist()
sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+print(sentences_tokens)
atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(
@@ -174,7 +177,8 @@ class Linker(Module):
print("End preprocess Data")
return training_dataloader, validation_dataloader
-def forward(self, batch_num_atoms_per_word, batch_atoms, batch_pos_idx, batch_neg_idx, sents_embedding, cats_embedding):
+def forward(self, batch_num_atoms_per_word, batch_atoms, batch_pos_idx, batch_neg_idx, sents_embedding,
+cats_embedding):
r"""
Args:
batch_num_atoms_per_word : (batch_size, len_sentence) flattened categories
@@ -205,7 +209,7 @@ class Linker(Module):
src_key_padding_mask=src_key_padding_mask)
atoms_embedding = atoms_embedding.permute(1, 0, 2)
-# cat
+# concatenation of atom encoding and word/cat encoding
atoms_sentences_encoding = torch.cat([word_cat_encoding, atoms_embedding], dim=2)
atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
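Shape check for the concatenation above, using this commit's configuration values and assuming the middle axis is the atom sequence of length max_atoms_in_sentence: word_cat_encoding contributes dim_cat_bert_out = 768 features, atoms_embedding contributes dim_emb_atom = 256, so dim_cat = 1024 before linker_encoder reduces it to dim_cat_out = 256.

import torch

batch_size, max_atoms = 32, 875                               # batch_size and max_atoms_in_sentence
word_cat_encoding = torch.randn(batch_size, max_atoms, 768)   # dim_cat_bert_out
atoms_embedding = torch.randn(batch_size, max_atoms, 256)     # dim_emb_atom, after permute(1, 0, 2)
atoms_sentences_encoding = torch.cat([word_cat_encoding, atoms_embedding], dim=2)
print(atoms_sentences_encoding.shape)                         # torch.Size([32, 875, 1024])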
@@ -346,12 +350,12 @@ class Linker(Module):
output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[
-'word_embeding'], output['last_hidden_state']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type
+'word_embeding'], output['last_hidden_state']) # atom_vocab, batch_size, max atoms in one type, max atoms in one type
axiom_links_pred = torch.argmax(logits_predictions, dim=3) # atom_vocab, batch_size, max atoms in one type
print('\n')
print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
print("Les prédictions : ", axiom_links_pred[2][1][:100])
print("Les vrais liens de la catégorie txt : ", batch_true_links[1][5][:100])
print("Les prédictions : ", axiom_links_pred[5][1][:100])
print('\n')
accuracy = mesure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
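The inline comments give logits_predictions the layout (atom_vocab, batch_size, n, n) with n = max_atoms_in_one_type, so the argmax over dim=3 picks, for each positive atom, the index of the negative atom it links to. A dummy-tensor sketch of that shape reasoning; 18 and 324 come from the configuration above, and the first axis may in practice use a reduced atom map rather than the full vocabulary.

import torch

logits_predictions = torch.randn(18, 2, 324, 324)           # (atom_vocab, dummy batch of 2, n, n)
axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # one predicted link index per positive atom
print(axiom_links_pred.shape)                               # torch.Size([18, 2, 324])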
......
@@ -5,7 +5,7 @@ import math
class PositionalEncoding(nn.Module):
-def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
+def __init__(self, d_model: int, dropout: float = 0.15, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
......
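Only the __init__ signature of PositionalEncoding appears in this diff. For reference, a standard sinusoidal implementation compatible with that signature (essentially the PyTorch transformer-tutorial version, for even d_model); the class body in this repository may differ.

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.15, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)   # even feature indices
        pe[:, 0, 1::2] = torch.cos(position * div_term)   # odd feature indices
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, d_model), the sequence-first layout used by the transformer here
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)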
@@ -201,7 +201,7 @@ def get_num_atoms_batch(category_batch, max_len_sentence):
"""
batch = []
for sentence in category_batch:
-num_atoms_sentence = []
+num_atoms_sentence = [0]
for category in sentence:
num_atoms_in_word = category_to_num_atoms(category, 0)
# add 1 because for word we have SEP at the end
......
numpy==1.22.2
transformers==4.16.2
torch==1.9.0
huggingface-hub==0.4.0
pandas==1.4.1
sentencepiece
git+https://gitlab.irit.fr/pnria/global-helper/deepgrail-rnn/
\ No newline at end of file
Markdown==3.3.6
packaging==21.3
scikit-learn==1.0.2
scipy==1.8.0
sentencepiece==0.1.96
tensorflow==2.9.1
tensorboard==2.8.0
torch==1.11.0
tqdm==4.64.0
transformers==4.19.0
\ No newline at end of file
@@ -5,7 +5,7 @@ from utils import read_csv_pgbar
torch.cuda.empty_cache()
batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 800
+nb_sentences = batch_size * 4
epochs = int(Configuration.modelTrainingConfig['epoch'])
file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
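With the values in this commit, nb_sentences = 32 * 4 = 128 sentences are read from the CSV, compared with 16 * 800 = 12,800 before, which suggests a reduced run for exercising the training loop rather than a full training pass.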
......