Skip to content
Snippets Groups Projects
find_config.py 3.71 KiB
import configparser
import re

import torch

from Linker.atom_map import atom_map_redux
from Linker.utils_linker import get_GOAL, get_atoms_links_batch, get_atoms_batch
from SuperTagger.SuperTagger.SuperTagger import SuperTagger
from utils import read_links_csv, read_supertags_csv, pad_sequence, load_obj

def configurate_supertagger(dataset, index_to_super_path, model_tagger, nb_sentences=1000000000,
                            *, config_path='Configuration/config.ini'):
    """Compute ``max_len_sentence`` from a supertag dataset and save it to the INI config.

    Reads the INI file at ``config_path``, builds a new ``SuperTagger`` model,
    tokenizes every sentence of the dataset with its sentence tokenizer and
    writes the longest tokenized length to
    ``[DATASET_PARAMS] max_len_sentence``. The INI file is then rewritten in
    place (all other sections/options read from it are preserved).

    Args:
        dataset: forwarded to ``read_supertags_csv``; the returned frame must
            have an ``"X"`` column of sentence strings.
        index_to_super_path: forwarded to ``load_obj``; the length of the loaded
            object is used as the number of classes of the new model.
        model_tagger: forwarded to ``SuperTagger.create_new_model`` (presumably
            the name/path of the underlying pretrained model — TODO confirm).
        nb_sentences: forwarded to ``read_supertags_csv`` (presumably a cap on
            the number of rows read).
        config_path: INI file to read and rewrite. Keyword-only; defaults to
            the historical ``'Configuration/config.ini'``.
    """
    print("#" * 20)
    print("#" * 20)
    print("Configuration with dataset\n")
    config = configparser.ConfigParser()
    config.read(config_path)
    # config.set() raises NoSectionError if the section is absent (e.g. when
    # the file does not exist yet, which config.read() silently ignores).
    if not config.has_section('DATASET_PARAMS'):
        config.add_section('DATASET_PARAMS')

    df = read_supertags_csv(dataset, nb_sentences)
    index_to_super = load_obj(index_to_super_path)

    # Within this function the model is only used for its sentence tokenizer.
    supertagger = SuperTagger()
    supertagger.create_new_model(len(index_to_super), model_tagger, index_to_super)

    sentences_batch = df["X"].str.strip().tolist()
    sentences_tokens, _ = supertagger.sent_tokenizer.fit_transform(sentences_batch)
    max_len_sentence = max((len(sentence) for sentence in sentences_tokens), default=0)
    print("Configure parameter max len sentence to ", max_len_sentence)
    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))

    with open(config_path, 'w') as configfile:  # save
        config.write(configfile)

    print("#" * 20)
    print("#" * 20)

    
def _max_atoms_of_one_type(atoms_batch, atoms_polarity_batch, atom_types):
    """Return the largest per-sentence count of positive atoms of a single type.

    For each atom type and each sentence, count the positions whose atom label
    matches ``<atom_type>`` optionally followed by ``_<word chars>`` (anchored
    at both ends) and whose entry in ``atoms_polarity_batch`` is truthy.
    Returns the maximum of those counts, or 0 if there are none.
    """
    best = 0
    for atom_type in atom_types:
        # Compiled once per type. The suffix is a raw string so that \w and \Z
        # reach the regex engine instead of being invalid Python escapes.
        # NOTE(review): atom_type is inserted unescaped, as before; this assumes
        # the keys contain no regex metacharacters — confirm.
        pattern = re.compile(atom_type + r"(_\w+)?\Z")
        for s_idx, sentence in enumerate(atoms_batch):
            count = sum(1 for i, atom in enumerate(sentence)
                        if pattern.match(atom) and atoms_polarity_batch[s_idx][i])
            best = max(best, count)
    return best


def configurate_linker(dataset, model_tagger, nb_sentences=1000000000,
                       *, config_path='Configuration/config.ini'):
    """Compute linker size parameters from an axiom-links dataset and save them to the INI config.

    Writes three options under ``[DATASET_PARAMS]``:

    * ``max_len_sentence`` — longest tokenized sentence, using the sentence
      tokenizer of the ``SuperTagger`` loaded from ``model_tagger``;
    * ``max_atoms_in_sentence`` — longest atom sequence returned by ``get_GOAL``;
    * ``max_atoms_in_one_type`` — ``2 * n + 2``, where ``n`` is the largest
      per-sentence count of positive atoms of one type from ``atom_map_redux``.
      Only ``n`` is printed; the reason for the ``2 * n + 2`` headroom is not
      visible here (TODO confirm with the linker model).

    The INI file is then rewritten in place.

    Args:
        dataset: forwarded to ``read_links_csv``; the returned frame must have
            an ``"X"`` column of sentence strings and is also passed to
            ``get_GOAL``.
        model_tagger: forwarded to ``SuperTagger.load_weights`` (presumably a
            checkpoint path — TODO confirm).
        nb_sentences: forwarded to ``read_links_csv`` (presumably a cap on the
            number of rows read).
        config_path: INI file to read and rewrite. Keyword-only; defaults to
            the historical ``'Configuration/config.ini'``.
    """
    print("#" * 20)
    print("#" * 20)
    print("Configuration with dataset\n")
    config = configparser.ConfigParser()
    config.read(config_path)
    # config.set() raises NoSectionError if the section is absent (e.g. when
    # the file does not exist yet, which config.read() silently ignores).
    if not config.has_section('DATASET_PARAMS'):
        config.add_section('DATASET_PARAMS')

    df_axiom_links = read_links_csv(dataset, nb_sentences)

    supertagger = SuperTagger()
    supertagger.load_weights(model_tagger)

    sentences_batch = df_axiom_links["X"].str.strip().tolist()
    sentences_tokens, _ = supertagger.sent_tokenizer.fit_transform(sentences_batch)
    max_len_sentence = max((len(sentence) for sentence in sentences_tokens), default=0)
    print("Configure parameter max len sentence to ", max_len_sentence)
    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))

    atoms_batch, polarities, _ = get_GOAL(max_len_sentence, df_axiom_links)
    max_atoms_in_sentence = max((len(sentence) for sentence in atoms_batch), default=0)
    print("Configure parameter max atoms in categories to", max_atoms_in_sentence)
    config.set('DATASET_PARAMS', 'max_atoms_in_sentence', str(max_atoms_in_sentence))

    atoms_polarity_batch = pad_sequence(
        [torch.as_tensor(polarity, dtype=torch.bool) for polarity in polarities],
        max_len=max_atoms_in_sentence, padding_value=0)
    max_atoms_in_one_type = _max_atoms_of_one_type(atoms_batch, atoms_polarity_batch,
                                                   atom_map_redux.keys())
    print("Configure parameter max atoms of one type in one sentence to", max_atoms_in_one_type)
    config.set('DATASET_PARAMS', 'max_atoms_in_one_type', str(max_atoms_in_one_type * 2 + 2))

    with open(config_path, 'w') as configfile:  # save
        config.write(configfile)

    print("#" * 20)
    print("#" * 20)