Commit 9517173b authored by Caroline de Pourtalès

change for flaubert

parent d891f64e
@@ -98,7 +98,9 @@ class SuperTagger:
         self.num_label = len(self.index_to_tags)
         self.model = Tagging_bert_model(self.bert_name, self.num_label)
         self.tags_tokenizer = SymbolTokenizer(self.index_to_tags)
-        self.sent_tokenizer = SentencesTokenizer(transformers.AutoTokenizer.from_pretrained(self.bert_name,do_lower_case=True),
+        self.sent_tokenizer = SentencesTokenizer(transformers.AutoTokenizer.from_pretrained(
+            self.bert_name,
+            do_lower_case=True),
             self.max_len_sentence)
         self.model.load_state_dict(params['state_dict'])
         self.optimizer = params['optimizer']
@@ -138,9 +140,11 @@ class SuperTagger:
         self.index_to_tags = index_to_tags
         self.bert_name = bert_name
-        self.sent_tokenizer = SentencesTokenizer(AutoTokenizer.from_pretrained(bert_name,do_lower_case=True),
+        self.sent_tokenizer = SentencesTokenizer(AutoTokenizer.from_pretrained(
+            bert_name,
+            do_lower_case=True),
             self.max_len_sentence)
-        self.optimizer = Adam(params=self.model.parameters(), lr=2e-4, eps=1e-8)
+        self.optimizer = Adam(params=self.model.parameters(), lr=1e-3, eps=1e-8)
         self.tags_tokenizer = SymbolTokenizer(index_to_tags)
         self.trainable = True
         self.model_load = True
......
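The two SuperTagger hunks above only reflow the tokenizer construction over several lines and raise the Adam learning rate from 2e-4 to 1e-3. As a minimal sketch of what the reflowed call boils down to, assuming only the Hugging Face transformers API (SentencesTokenizer is the project's own wrapper and is not reproduced here; the model name is taken from the training-script hunks below):

import transformers

# bert_name is assumed from the other hunks in this commit; do_lower_case
# is forwarded by from_pretrained to the underlying tokenizer.
bert_name = "flaubert/flaubert_base_cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(
    bert_name,
    do_lower_case=True,
)
print(tokenizer.tokenize("Le chat dort."))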
@@ -13,7 +13,7 @@ if __name__== '__main__':
     # region model
     model_tagger = "models/flaubert_super_98_V2_50e.pt"
     neuralproofnet = NeuralProofNet(model_tagger)
-    model = "Output/saved_linker.pt"
+    model = "models/saved_linker.pt"
     neuralproofnet.linker.load_weights(model)
     # endregion
......
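The hunk above only corrects the linker checkpoint path from Output/ to models/, matching where the other scripts in this commit keep their weights. Purely as a hypothetical sketch of what a load_weights helper like the one called here typically does (the actual implementation is not part of this diff):

import torch

def load_weights(module, path):
    # Hypothetical helper: restore parameters from a saved checkpoint.
    params = torch.load(path, map_location="cpu")
    state = params["state_dict"] if isinstance(params, dict) and "state_dict" in params else params
    module.load_state_dict(state)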
@@ -7,7 +7,7 @@ torch.cuda.empty_cache()
 # region data
 file_path_axiom_links = 'Datasets/gold_dataset_links.csv'
-df_axiom_links = read_links_csv(file_path_axiom_links)[:32]
+df_axiom_links = read_links_csv(file_path_axiom_links)
 # endregion
@@ -16,7 +16,7 @@ print("#" * 20)
 print("#" * 20)
 model_tagger = "models/flaubert_super_98_V2_50e.pt"
 neural_proof_net = NeuralProofNet(model_tagger)
-neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0, epochs=5, pretrain_linker_epochs=5, batch_size=16,
+neural_proof_net.train_neuralproofnet(df_axiom_links, validation_rate=0.1, epochs=25, pretrain_linker_epochs=25, batch_size=16,
                                       checkpoint=True, tensorboard=True)
 print("#" * 20)
 print("#" * 20)
......
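The two hunks above switch the NeuralProofNet run from a debug setting (first 32 rows, no validation, 5 epochs) to a full run: the whole gold dataset, a 10% validation split, and 25 epochs each for linker pretraining and joint training. A rough sketch of what validation_rate=0.1 means for the data, assuming a simple tail split (the split logic inside train_neuralproofnet is not shown in this diff):

validation_rate = 0.1
n_val = max(1, int(len(df_axiom_links) * validation_rate))
df_train = df_axiom_links.iloc[:-n_val]  # 90% of the pairs for training
df_val = df_axiom_links.iloc[-n_val:]    # 10% held out for validation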
@@ -15,10 +15,10 @@ index_to_super = load_obj('SuperTagger/Datasets/index_to_super')
 # region model
 tagger = SuperTagger()
-tagger.create_new_model(len(index_to_super),'camembert-base',index_to_super)
+tagger.create_new_model(len(index_to_super),"flaubert/flaubert_base_cased",index_to_super)
 ## If you want to upload a pretrained model
 # tagger.load_weights("models/model_check.pt")
-tagger.train(texts, tags, epochs=40, batch_size=16, validation_rate=0.1,
+tagger.train(texts, tags, epochs=60, batch_size=16, validation_rate=0.1,
              tensorboard=True, checkpoint=True)
 # endregion
......
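This hunk is the change the commit message refers to: the supertagger backbone moves from CamemBERT to FlauBERT (both French BERT variants published on the Hugging Face Hub), and training is extended from 40 to 60 epochs. Since the Auto classes resolve the architecture from each checkpoint's config, the identifier string is the only thing that has to change, as this small check illustrates:

from transformers import AutoModel

# Both identifiers are public Hub checkpoints; the Auto classes pick
# CamembertModel vs. FlaubertModel from the downloaded config.
camembert = AutoModel.from_pretrained("camembert-base")
flaubert = AutoModel.from_pretrained("flaubert/flaubert_base_cased")
print(type(camembert).__name__, type(flaubert).__name__)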
@@ -20,8 +20,12 @@ def read_links_csv(csv_path, nrows=float('inf'), chunksize=100):
     print("\n" + "#" * 20)
     print("Loading csv...")
+    rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1  # minus the header
     chunk_list = []
+    if rows > nrows:
+        rows = nrows
     with tqdm(total=nrows, desc='Rows read: ') as bar:
         for chunk in pd.read_csv(csv_path, header=0, converters={'Y': pd.eval, 'Z': pd.eval},
                                  chunksize=chunksize, nrows=nrows):
@@ -45,7 +49,12 @@ def read_supertags_csv(csv_path, nrows=float('inf'), chunksize=100):
     print("\n" + "#" * 20)
     print("Loading csv...")
+    rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1  # minus the header
     chunk_list = []
+    if rows > nrows:
+        rows = nrows
     with tqdm(total=nrows, desc='Rows read: ') as bar:
         for chunk in pd.read_csv(csv_path, header=0, converters={'Y1': pd.eval, 'Y2': pd.eval, 'Z': pd.eval},
                                  chunksize=chunksize, nrows=nrows):
......
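Both read_*_csv hunks add the same preamble: count the data rows up front and cap the count at nrows, presumably so the tqdm bar can be given a finite total even when nrows keeps its float('inf') default (the visible context lines still pass total=nrows, so this may be completed in a later change). A self-contained sketch of the pattern with the computed count actually handed to tqdm, leaving out the project-specific pd.eval converters:

import pandas as pd
from tqdm import tqdm

def read_csv_with_progress(csv_path, nrows=float('inf'), chunksize=100):
    # Count data rows up front so the progress bar has a finite total.
    rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1  # minus the header
    if rows > nrows:
        rows = nrows
    chunk_list = []
    with tqdm(total=rows, desc='Rows read: ') as bar:
        for chunk in pd.read_csv(csv_path, header=0, chunksize=chunksize,
                                 nrows=None if nrows == float('inf') else int(nrows)):
            chunk_list.append(chunk)
            bar.update(len(chunk))
    return pd.concat(chunk_list, ignore_index=True)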