# Source: find_config.py (3.71 KiB as retrieved), authored by Caroline de Pourtalès.
import configparser
import re
import torch
from Linker.atom_map import atom_map_redux
from Linker.utils_linker import get_GOAL, get_atoms_links_batch, get_atoms_batch
from SuperTagger.SuperTagger.SuperTagger import SuperTagger
from utils import read_links_csv, read_supertags_csv, pad_sequence, load_obj
def configurate_supertagger(dataset, index_to_super_path, model_tagger, nb_sentences=1000000000):
    """Measure the longest tokenized sentence and store it in the config file.

    Reads ``Configuration/config.ini``, loads up to ``nb_sentences`` rows from
    ``dataset`` with ``read_supertags_csv``, creates a fresh ``SuperTagger``
    model and uses its sentence tokenizer on the stripped ``"X"`` column.
    The length of the longest tokenized sentence is written to
    ``DATASET_PARAMS/max_len_sentence`` and the config file is saved.

    Args:
        dataset: dataset location, passed to ``read_supertags_csv``.
        index_to_super_path: path passed to ``load_obj`` to get the
            index-to-supertag mapping.
        model_tagger: passed to ``SuperTagger.create_new_model``.
        nb_sentences: maximum number of sentences, passed to
            ``read_supertags_csv``.
    """
    banner = "#" * 20
    print(banner)
    print(banner)
    print("Configuration with dataset\n")

    config_path = 'Configuration/config.ini'
    settings = configparser.ConfigParser()
    settings.read(config_path)

    frame = read_supertags_csv(dataset, nb_sentences)
    index_to_super = load_obj(index_to_super_path)

    tagger = SuperTagger()
    tagger.create_new_model(len(index_to_super), model_tagger, index_to_super)

    stripped_sentences = frame["X"].str.strip().tolist()
    # The tokenizer returns (tokens, mask); only the tokens are needed here.
    tokens, _mask = tagger.sent_tokenizer.fit_transform(stripped_sentences)

    # Longest tokenized sentence; 0 when there are no sentences.
    longest = max((len(tokenized) for tokenized in tokens), default=0)

    print("Configure parameter max len sentence to ", longest)
    settings.set('DATASET_PARAMS', 'max_len_sentence', str(longest))

    with open(config_path, 'w') as handle:  # save
        settings.write(handle)

    print(banner)
    print(banner)
def configurate_linker(dataset, model_tagger, nb_sentences=1000000000):
    """Derive the linker's dataset size limits and save them to the config file.

    Reads ``Configuration/config.ini``, loads up to ``nb_sentences`` rows of
    axiom links from ``dataset`` and records three values in the
    ``DATASET_PARAMS`` section before writing the file back:

    * ``max_len_sentence`` -- length of the longest tokenized sentence of
      column ``"X"``;
    * ``max_atoms_in_sentence`` -- length of the longest atom list returned
      by ``get_GOAL``;
    * ``max_atoms_in_one_type`` -- ``2 * m + 2`` where ``m`` is the largest
      number of atoms of a single type (keys of ``atom_map_redux``) whose
      polarity flag is true within one sentence.

    Args:
        dataset: path of the axiom links CSV, passed to ``read_links_csv``.
        model_tagger: passed to ``SuperTagger.load_weights``.
        nb_sentences: maximum number of sentences, passed to
            ``read_links_csv``.
    """
    print("#" * 20)
    print("#" * 20)
    print("Configuration with dataset\n")

    config = configparser.ConfigParser()
    config.read('Configuration/config.ini')

    df_axiom_links = read_links_csv(dataset, nb_sentences)

    supertagger = SuperTagger()
    supertagger.load_weights(model_tagger)

    sentences_batch = df_axiom_links["X"].str.strip().tolist()
    # The tokenizer returns (tokens, mask); only the tokens are needed here.
    sentences_tokens, _ = supertagger.sent_tokenizer.fit_transform(sentences_batch)

    max_len_sentence = max((len(sentence) for sentence in sentences_tokens), default=0)
    print("Configure parameter max len sentence to ", max_len_sentence)
    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))

    # get_GOAL returns a third value that this function does not use.
    atoms_batch, polarities, _ = get_GOAL(max_len_sentence, df_axiom_links)

    max_atoms_in_sentence = max((len(atoms) for atoms in atoms_batch), default=0)
    print("Configure parameter max atoms in categories to", max_atoms_in_sentence)
    config.set('DATASET_PARAMS', 'max_atoms_in_sentence', str(max_atoms_in_sentence))

    atoms_polarity_batch = pad_sequence(
        [torch.as_tensor(polarity, dtype=torch.bool) for polarity in polarities],
        max_len=max_atoms_in_sentence, padding_value=0)

    # For every atom type, count per sentence the atoms that match the type
    # (optionally followed by "_<suffix>") and whose polarity flag is true;
    # keep the largest count seen.
    max_atoms_in_on_type = 0
    for atom_type in atom_map_redux:
        # Raw string so that \w and \Z reach the regex engine without relying
        # on Python's handling of invalid string escapes. Compiled once per
        # atom type instead of once per atom.
        # NOTE(review): atom_type is interpolated unescaped, as before --
        # confirm the keys contain no regex metacharacters.
        atom_pattern = re.compile(atom_type + r"(_{1}\w+)?\Z")
        for s_idx, sentence in enumerate(atoms_batch):
            count = sum(
                1
                for i, atom in enumerate(sentence)
                if atom_pattern.match(atom) and atoms_polarity_batch[s_idx][i]
            )
            if count > max_atoms_in_on_type:
                max_atoms_in_on_type = count

    print("Configure parameter max atoms of one type in one sentence to", max_atoms_in_on_type)
    # The stored value is 2 * max + 2, not the raw maximum printed above.
    config.set('DATASET_PARAMS', 'max_atoms_in_one_type', str(max_atoms_in_on_type * 2 + 2))

    with open('Configuration/config.ini', 'w') as configfile:  # save
        config.write(configfile)

    print("#" * 20)
    print("#" * 20)