Skip to content
Snippets Groups Projects
Commit 9c45838b authored by Julien Rabault's avatar Julien Rabault
Browse files

git ignore

parent 29a9eac6
No related branches found
No related tags found
2 merge requests!6Linker with transformer,!5Linker with transformer
Utils/silver
Utils/gold
.idea
*.pt
Linker/__pycache__
Configuration/__pycache__
__pycache__
......@@ -28,7 +28,7 @@ sinkhorn_iters=3
[MODEL_TRAINING]
device=cpu
batch_size=16
epoch=20
seed_val=42
learning_rate=2e-5
......@@ -60,7 +60,6 @@ class Linker(Module):
self.cross_entropy_loss = SinkhornLoss()
self.optimizer = AdamW(self.parameters(),
weight_decay=1e-5,
lr=learning_rate)
self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
num_warmup_steps=0,
......@@ -106,7 +105,7 @@ class Linker(Module):
return training_dataloader, validation_dataloader
def make_decoder_mask(self, atoms_token):
    """Build a float attention mask for the decoder from padded atom-token ids.

    :param atoms_token: integer tensor of shape (batch, seq) of atom token ids,
        where entries equal to ``self.padding_id`` are padding.
    :return: float64 tensor of shape (batch * nhead, seq, seq); 1.0 for real
        tokens, 0.0 at padded positions, tiled once per attention head.
    """
    # The garbled merge left a second, device-less creation of this mask that
    # was immediately overwritten; only the device-aware version is kept.
    decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64, device=self.device)
    decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
    # (batch, seq) -> (batch, seq, seq), then tile along the batch axis once
    # per attention head, as expected by multi-head attention.
    return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
......
......@@ -32,12 +32,44 @@ def sub_tree_line(line_with_data: str):
return sentence, list(itertools.chain(*sub_trees))
def Txt_to_csv(file_name: str, result_name=None):
    """Parse a raw links .txt file and append its rows to the result CSV.

    Appends with ``mode='a'`` and no header row — use :func:`Txt_to_csv_header`
    for the first file so the CSV gets its column names.

    :param file_name: path of the input ``.txt`` file (one entry per line).
    :param result_name: basename of the output CSV. Defaults to the input
        file name without its 4-character extension, which keeps the old
        one-argument call working.
    """
    if result_name is None:
        # Backward-compatible default: old code used file_name[:-4].
        result_name = file_name[:-4]
    # 'with' guarantees the handle is closed (the original leaked it).
    with open(file_name, "r", encoding="utf8") as file:
        text = file.readlines()
    sub = [sub_tree_line(data) for data in text]
    df = pd.DataFrame(data=sub, columns=['Sentences', 'sub_tree'])
    df.to_csv("../Datasets/" + result_name + "_dataset_links.csv", mode='a', index=False, header=False)
def Txt_to_csv_header(file_name: str, result_name):
    """Parse a raw links .txt file and write the result CSV with a header row.

    Intended for the first file of a batch; subsequent files should be
    appended with :func:`Txt_to_csv`, which writes no header.

    :param file_name: path of the input ``.txt`` file (one entry per line).
    :param result_name: basename of the output CSV.
    """
    # 'with' guarantees the handle is closed (the original never closed it).
    with open(file_name, "r", encoding="utf8") as file:
        text = file.readlines()
    sub = [sub_tree_line(data) for data in text]
    df = pd.DataFrame(data=sub, columns=['Sentences', 'sub_tree'])
    df.to_csv("../Datasets/" + result_name + "_dataset_links.csv", index=False)
# import os
# i = 0
# path = "gold"
# for filename in os.listdir(path):
# if i == 0:
# Txt_to_csv_header(os.path.join(path, filename),path)
# else :
# Txt_to_csv(os.path.join(path, filename),path)
# i+=1
#
# i = 0
# path = "silver"
# for filename in os.listdir(path):
# if i == 0:
# Txt_to_csv_header(os.path.join(path, filename),path)
# else :
# Txt_to_csv(os.path.join(path, filename),path)
# i+=1
# Pass the result basename explicitly: the two-argument Txt_to_csv signature
# requires it, and relying on a default hides which CSV is being appended to.
Txt_to_csv("aa1_links.txt", "aa1_links")
# # reading csv files
# data1 = pd.read_csv('../Datasets/gold_dataset_links.csv')
# data2 = pd.read_csv('../Datasets/silver_dataset_links.csv')
#
# # using merge function by setting how='left'
# df = pd.merge(data1, data2,how='outer')
#
# df.to_csv("../Datasets/goldANDsilver_dataset_links.csv", index=False)
import torch
from Configuration import Configuration
from Linker import *
from Supertagger import *
from deepgrail_Tagger.SuperTagger.SuperTagger import SuperTagger
from utils import read_csv_pgbar
torch.cuda.empty_cache()

# ---- training-run configuration -------------------------------------------
# NOTE(review): batch_size is read from the configuration section above this
# chunk — confirm it is defined before this point.
nb_sentences = batch_size * 20
epochs = int(Configuration.modelTrainingConfig['epoch'])

# Gold-only dataset; the gold+silver merged CSV is produced by the
# commented-out merge script in utils and is not used here.
file_path_axiom_links = 'Datasets/gold_dataset_links.csv'
df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)

sentences_batch = df_axiom_links["Sentences"].tolist()

# Load the pretrained supertagger and tokenize the sentences once up front.
supertagger = SuperTagger()
supertagger.load_weights("models/flaubert_super_98%_V2_50e.pt")
sents_tokenized, sents_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)

print("Linker")
linker = Linker(supertagger)
linker = linker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

print("Linker Training")
linker.train_linker(df_axiom_links, sents_tokenized, sents_mask, validation_rate=0.1,
                    epochs=epochs, batch_size=batch_size, checkpoint=False, validate=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment