From 9517173bdfe9a1262eadf685854384f5979586dd Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <caroline.de-pourtales@irit.fr> Date: Mon, 27 Mar 2023 15:13:14 +0200 Subject: [PATCH] change for flaubert --- SuperTagger/SuperTagger/SuperTagger.py | 10 +++++++--- predict_links.py | 2 +- train_neuralproofnet.py | 4 ++-- train_supertagger.py | 4 ++-- utils.py | 9 +++++++++ 5 files changed, 21 insertions(+), 8 deletions(-) diff --git a/SuperTagger/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger/SuperTagger.py index 4f405a0..670849d 100644 --- a/SuperTagger/SuperTagger/SuperTagger.py +++ b/SuperTagger/SuperTagger/SuperTagger.py @@ -98,7 +98,9 @@ class SuperTagger: self.num_label = len(self.index_to_tags) self.model = Tagging_bert_model(self.bert_name, self.num_label) self.tags_tokenizer = SymbolTokenizer(self.index_to_tags) - self.sent_tokenizer = SentencesTokenizer(transformers.AutoTokenizer.from_pretrained(self.bert_name,do_lower_case=True), + self.sent_tokenizer = SentencesTokenizer(transformers.AutoTokenizer.from_pretrained( + self.bert_name, + do_lower_case=True), self.max_len_sentence) self.model.load_state_dict(params['state_dict']) self.optimizer = params['optimizer'] @@ -138,9 +140,11 @@ class SuperTagger: self.index_to_tags = index_to_tags self.bert_name = bert_name - self.sent_tokenizer = SentencesTokenizer(AutoTokenizer.from_pretrained(bert_name,do_lower_case=True), + self.sent_tokenizer = SentencesTokenizer(AutoTokenizer.from_pretrained( + bert_name, + do_lower_case=True), self.max_len_sentence) - self.optimizer = Adam(params=self.model.parameters(), lr=2e-4, eps=1e-8) + self.optimizer = Adam(params=self.model.parameters(), lr=1e-3, eps=1e-8) self.tags_tokenizer = SymbolTokenizer(index_to_tags) self.trainable = True self.model_load = True diff --git a/predict_links.py b/predict_links.py index b52982a..28336ea 100644 --- a/predict_links.py +++ b/predict_links.py @@ -13,7 +13,7 @@ if __name__== '__main__': # region model model_tagger = "models/flaubert_super_98_V2_50e.pt" neuralproofnet = NeuralProofNet(model_tagger) - model = "Output/saved_linker.pt" + model = "models/saved_linker.pt" neuralproofnet.linker.load_weights(model) # endregion diff --git a/train_neuralproofnet.py b/train_neuralproofnet.py index f3c2284..dd39d91 100644 --- a/train_neuralproofnet.py +++ b/train_neuralproofnet.py @@ -7,7 +7,7 @@ torch.cuda.empty_cache() # region data file_path_axiom_links = 'Datasets/gold_dataset_links.csv' -df_axiom_links = read_links_csv(file_path_axiom_links)[:32] +df_axiom_links = read_links_csv(file_path_axiom_links) # endregion @@ -16,7 +16,7 @@ print("#" * 20) print("#" * 20) model_tagger = "models/flaubert_super_98_V2_50e.pt" neural_proof_net = NeuralProofNet(model_tagger) -neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0, epochs=5, pretrain_linker_epochs=5, batch_size=16, +neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=25, pretrain_linker_epochs=25, batch_size=16, checkpoint=True, tensorboard=True) print("#" * 20) print("#" * 20) diff --git a/train_supertagger.py b/train_supertagger.py index 2ea2aee..571edc6 100644 --- a/train_supertagger.py +++ b/train_supertagger.py @@ -15,10 +15,10 @@ index_to_super = load_obj('SuperTagger/Datasets/index_to_super') # region model tagger = SuperTagger() -tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super) +tagger.create_new_model(len(index_to_super),"flaubert/flaubert_base_cased",index_to_super) ## If you want to upload a pretrained model # tagger.load_weights("models/model_check.pt") -tagger.train(texts, tags, epochs=40, batch_size=16, validation_rate=0.1, +tagger.train(texts, tags, epochs=60, batch_size=16, validation_rate=0.1, tensorboard=True, checkpoint=True) # endregion diff --git a/utils.py b/utils.py index 9640a0b..b1c9f85 100644 --- a/utils.py +++ b/utils.py @@ -20,8 +20,12 @@ def read_links_csv(csv_path, nrows=float('inf'), chunksize=100): print("\n" + "#" * 20) print("Loading csv...") + rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1 # minus the header chunk_list = [] + if rows > nrows: + rows = nrows + with tqdm(total=nrows, desc='Rows read: ') as bar: for chunk in pd.read_csv(csv_path, header=0, converters={'Y': pd.eval, 'Z': pd.eval}, chunksize=chunksize, nrows=nrows): @@ -45,7 +49,12 @@ def read_supertags_csv(csv_path, nrows=float('inf'), chunksize=100): print("\n" + "#" * 20) print("Loading csv...") + rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1 # minus the header chunk_list = [] + + if rows > nrows: + rows = nrows + with tqdm(total=nrows, desc='Rows read: ') as bar: for chunk in pd.read_csv(csv_path, header=0, converters={'Y1': pd.eval, 'Y2': pd.eval, 'Z': pd.eval}, chunksize=chunksize, nrows=nrows): -- GitLab