Skip to content
Snippets Groups Projects
Commit a702fd51 authored by Caroline DE POURTALES's avatar Caroline DE POURTALES
Browse files

starting train

parent c44879ab
Branches
Tags
2 merge requests!6Linker with transformer,!5Linker with transformer
...@@ -4,15 +4,16 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer ...@@ -4,15 +4,16 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
from SuperTagger.Linker.atom_map import atom_map from SuperTagger.Linker.atom_map import atom_map
def get_atoms_from_category(category, category_to_atoms): def category_to_atoms(category, category_to_atoms):
if category in atom_map.keys(): res = [i for i in atom_map.keys() if category in i]
if len(res) > 0:
return [category] return [category]
else: else:
category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category) category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
left_side, right_side = category_cut.group(1), category_cut.group(2) left_side, right_side = category_cut.group(1), category_cut.group(2)
category_to_atoms += get_atoms_from_category(left_side, []) category_to_atoms += category_to_atoms(left_side, [])
category_to_atoms += get_atoms_from_category(right_side, []) category_to_atoms += category_to_atoms(right_side, [])
return category_to_atoms return category_to_atoms
...@@ -22,12 +23,12 @@ def get_atoms_batch(category_batch): ...@@ -22,12 +23,12 @@ def get_atoms_batch(category_batch):
for sentence in category_batch: for sentence in category_batch:
category_to_atoms = [] category_to_atoms = []
for category in sentence: for category in sentence:
category_to_atoms = get_atoms_from_category(category, category_to_atoms) category_to_atoms = category_to_atoms(category, category_to_atoms)
batch.append(category_to_atoms) batch.append(category_to_atoms)
return batch return batch
def cut_category_in_symbols(category): def category_to_atoms_polarity(category):
''' '''
Parameters : Parameters :
category : str of kind AtomCat | CategoryCat category : str of kind AtomCat | CategoryCat
...@@ -49,13 +50,13 @@ def cut_category_in_symbols(category): ...@@ -49,13 +50,13 @@ def cut_category_in_symbols(category):
if left_side in atom_map.keys(): if left_side in atom_map.keys():
category_to_polarity.append(False) category_to_polarity.append(False)
else: else:
category_to_polarity += cut_category_in_symbols(left_side) category_to_polarity += category_to_atoms_polarity(left_side)
# for the right side # for the right side
if right_side in atom_map.keys(): if right_side in atom_map.keys():
category_to_polarity.append(True) category_to_polarity.append(True)
else: else:
category_to_polarity += cut_category_in_symbols(right_side) category_to_polarity += category_to_atoms_polarity(right_side)
# dl = \ # dl = \
elif category.startswith("dl"): elif category.startswith("dl"):
...@@ -66,18 +67,18 @@ def cut_category_in_symbols(category): ...@@ -66,18 +67,18 @@ def cut_category_in_symbols(category):
if left_side in atom_map.keys(): if left_side in atom_map.keys():
category_to_polarity.append(True) category_to_polarity.append(True)
else: else:
category_to_polarity += cut_category_in_symbols(left_side) category_to_polarity += category_to_atoms_polarity(left_side)
# for the right side # for the right side
if right_side in atom_map.keys(): if right_side in atom_map.keys():
category_to_polarity.append(False) category_to_polarity.append(False)
else: else:
category_to_polarity += cut_category_in_symbols(right_side) category_to_polarity += category_to_atoms_polarity(right_side)
return category_to_polarity return category_to_polarity
def find_pos_neg_idexes(batch_symbols): def find_pos_neg_idexes(atoms_batch):
''' '''
Parameters : Parameters :
batch_symbols : (batch_size, sequence_length) the batch of symbols batch_symbols : (batch_size, sequence_length) the batch of symbols
...@@ -86,11 +87,9 @@ def find_pos_neg_idexes(batch_symbols): ...@@ -86,11 +87,9 @@ def find_pos_neg_idexes(batch_symbols):
(batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes (batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes
''' '''
list_batch = [] list_batch = []
for sentence in batch_symbols: for sentence in atoms_batch:
list_symbols = [] list_atoms = []
for category in sentence: for category in sentence:
list_symbols.append(cut_category_in_symbols(category)) list_atoms.append(category_to_atoms_polarity(category))
list_batch.append(list_symbols) list_batch.append(list_atoms)
return list_batch return list_batch
...@@ -3,6 +3,8 @@ from torch import Tensor ...@@ -3,6 +3,8 @@ from torch import Tensor
from torch.nn import Module from torch.nn import Module
from torch.nn.functional import nll_loss, cross_entropy from torch.nn.functional import nll_loss, cross_entropy
from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
class SinkhornLoss(Module): class SinkhornLoss(Module):
def __init__(self): def __init__(self):
...@@ -19,8 +21,10 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred): ...@@ -19,8 +21,10 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
axiom_links_pred : (batch_size, max_atoms_type_polarity) axiom_links_pred : (batch_size, max_atoms_type_polarity)
""" """
# Convert batch_axiom_links into list of atoms (batch_size, max_atoms_in_sentence) # Convert batch_axiom_links into list of atoms (batch_size, max_atoms_in_sentence)
atoms_batch = get_atoms_batch(batch_axiom_links)
# then convert into atom_vocab_size lists of (batch_size, max atom in one cat) with prefix parcours of graphe # then convert into atom_vocab_size lists of (batch_size, max atom in one cat) with prefix parcours of graphe
atoms_polarity = find_pos_neg_idexes(atoms_batch)
axiom_links_true = "" axiom_links_true = ""
...@@ -30,4 +34,4 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred): ...@@ -30,4 +34,4 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
correct_links[axiom_links_pred != axiom_links_true] = 0 correct_links[axiom_links_pred != axiom_links_true] = 0
num_correct_links = correct_links.sum().item() num_correct_links = correct_links.sum().item()
return num_correct_links return num_correct_links/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1])
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment