import configparser
import re

import torch

from Linker.atom_map import atom_map_redux
from Linker.utils_linker import get_GOAL, get_atoms_links_batch, get_atoms_batch
from SuperTagger.SuperTagger.SuperTagger import SuperTagger
from utils import read_links_csv, read_supertags_csv, pad_sequence, load_obj

# Default location of the INI file that both configuration helpers read and rewrite.
CONFIG_PATH = 'Configuration/config.ini'


def _print_banner():
    """Print the two-line ``#`` separator framing the configuration output."""
    print("#" * 20)
    print("#" * 20)


def _max_length(sequences):
    """Return the largest ``len()`` among *sequences*, or 0 when there are none."""
    return max((len(seq) for seq in sequences), default=0)


def _load_config(config_path):
    """Return a ConfigParser populated from *config_path*.

    ``ConfigParser.read`` silently ignores a missing file, so the parser may be empty.
    """
    config = configparser.ConfigParser()
    config.read(config_path)
    return config


def _save_config(config, config_path):
    """Overwrite *config_path* with the current contents of *config*."""
    with open(config_path, 'w') as configfile:
        config.write(configfile)


def _max_positive_atoms_of_one_type(atoms_batch, atoms_polarity_batch):
    """Return the largest number of atoms of a single type with a truthy polarity in one sentence.

    For every key of ``atom_map_redux`` and every sentence of *atoms_batch*, count the
    atoms whose string matches ``<atom_type>`` or ``<atom_type>_<word chars>`` and whose
    entry in *atoms_polarity_batch* is true; return the maximum count (0 if none).
    """
    # Compile each pattern once instead of once per atom.
    # NOTE(review): atom_type is interpolated unescaped, so any regex metacharacters in
    # the atom_map_redux keys are interpreted as regex syntax — confirm keys are plain names.
    patterns = [re.compile(atom_type + r"(_{1}\w+)?\Z") for atom_type in atom_map_redux]
    return max(
        (
            sum(1 for i, atom in enumerate(sentence)
                if pattern.match(atom) and atoms_polarity_batch[s_idx][i])
            for pattern in patterns
            for s_idx, sentence in enumerate(atoms_batch)
        ),
        default=0,
    )


def configurate_supertagger(dataset, index_to_super_path, model_tagger, nb_sentences=1000000000,
                            config_path=CONFIG_PATH):
    """Store the longest tokenized sentence length of *dataset* in the config file.

    Reads up to *nb_sentences* rows with ``read_supertags_csv``, builds a new
    ``SuperTagger`` model sized by the ``index_to_super`` object loaded from
    *index_to_super_path*, tokenizes the stripped ``"X"`` column with that tagger's
    ``sent_tokenizer``, and writes the maximum token-sequence length to
    ``[DATASET_PARAMS] max_len_sentence`` in *config_path*.

    Args:
        dataset: path passed to ``read_supertags_csv``.
        index_to_super_path: path passed to ``load_obj``; its length is the label count.
        model_tagger: forwarded to ``SuperTagger.create_new_model``
            (presumably a pretrained model name/path — confirm).
        nb_sentences: maximum number of rows to read.
        config_path: INI file to read and overwrite.

    Raises:
        configparser.NoSectionError: if the config has no ``[DATASET_PARAMS]`` section
            (including when *config_path* does not exist).
    """
    _print_banner()
    print("Configuration with dataset\n")
    config = _load_config(config_path)

    df = read_supertags_csv(dataset, nb_sentences)
    index_to_super = load_obj(index_to_super_path)

    supertagger = SuperTagger()
    supertagger.create_new_model(len(index_to_super), model_tagger, index_to_super)

    sentences_batch = df["X"].str.strip().tolist()
    sentences_tokens, _ = supertagger.sent_tokenizer.fit_transform(sentences_batch)

    max_len_sentence = _max_length(sentences_tokens)
    print("Configure parameter max len sentence to ", max_len_sentence)
    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))

    _save_config(config, config_path)
    _print_banner()


def configurate_linker(dataset, model_tagger, nb_sentences=1000000000, config_path=CONFIG_PATH):
    """Store the linker's dataset-derived size limits in the config file.

    Reads up to *nb_sentences* rows with ``read_links_csv``, loads a ``SuperTagger`` from
    *model_tagger*, and writes three values to the ``[DATASET_PARAMS]`` section of
    *config_path*:

    * ``max_len_sentence``: longest tokenized sentence of the stripped ``"X"`` column;
    * ``max_atoms_in_sentence``: longest atom list returned by ``get_GOAL``;
    * ``max_atoms_in_one_type``: ``2 * m + 2`` where ``m`` is the largest number of
      same-type atoms with true polarity in one sentence.

    Args:
        dataset: path passed to ``read_links_csv``.
        model_tagger: forwarded to ``SuperTagger.load_weights`` (presumably a checkpoint path).
        nb_sentences: maximum number of rows to read.
        config_path: INI file to read and overwrite.

    Raises:
        configparser.NoSectionError: if the config has no ``[DATASET_PARAMS]`` section
            (including when *config_path* does not exist).
    """
    _print_banner()
    print("Configuration with dataset\n")
    config = _load_config(config_path)

    df_axiom_links = read_links_csv(dataset, nb_sentences)

    supertagger = SuperTagger()
    supertagger.load_weights(model_tagger)

    sentences_batch = df_axiom_links["X"].str.strip().tolist()
    sentences_tokens, _ = supertagger.sent_tokenizer.fit_transform(sentences_batch)

    max_len_sentence = _max_length(sentences_tokens)
    print("Configure parameter max len sentence to ", max_len_sentence)
    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))

    atoms_batch, polarities, _ = get_GOAL(max_len_sentence, df_axiom_links)

    max_atoms_in_sentence = _max_length(atoms_batch)
    print("Configure parameter max atoms in categories to", max_atoms_in_sentence)
    config.set('DATASET_PARAMS', 'max_atoms_in_sentence', str(max_atoms_in_sentence))

    # Pad every sentence's polarity vector to the longest atom list so it can be indexed
    # in parallel with atoms_batch.
    atoms_polarity_batch = pad_sequence(
        [torch.as_tensor(polarity, dtype=torch.bool) for polarity in polarities],
        max_len=max_atoms_in_sentence, padding_value=0)

    max_atoms_in_one_type = _max_positive_atoms_of_one_type(atoms_batch, atoms_polarity_batch)
    print("Configure parameter max atoms of one type in one sentence to", max_atoms_in_one_type)
    # NOTE(review): the stored value is 2*max + 2 while the printed value is the raw max;
    # the reason for this margin is not visible here — confirm against the linker model.
    config.set('DATASET_PARAMS', 'max_atoms_in_one_type', str(max_atoms_in_one_type * 2 + 2))

    _save_config(config, config_path)
    _print_banner()