Skip to content
Snippets Groups Projects
Commit 7413c718 authored by Caroline de Pourtalès's avatar Caroline de Pourtalès :speech_balloon:
Browse files

merging

parent 99955583
No related branches found
No related tags found
No related merge requests found
import torch
from utils import pad_sequence
class AtomTokenizer(object):
    r"""
    Tokenizer translating atoms to vocabulary ids (and back), with padding for batches.
    """

    def __init__(self, atom_map, max_atoms_in_sentence):
        r"""
        :param atom_map: dict mapping each atom string (including '[PAD]') to its id
        :param max_atoms_in_sentence: length every sentence of a batch is padded to
        """
        self.atom_map = atom_map
        self.max_atoms_in_sentence = max_atoms_in_sentence
        # Reverse lookup used to decode ids back into atoms.
        self.inverse_atom_map = {idx: atom for atom, idx in self.atom_map.items()}
        self.pad_token = '[PAD]'
        self.pad_token_id = self.atom_map[self.pad_token]

    def __len__(self):
        r"""
        :return: size of the atom vocabulary
        """
        return len(self.atom_map)

    def convert_atoms_to_ids(self, atom):
        r"""
        Look up the id of a single atom
        :param atom: atom (converted with str() before the lookup)
        :return: atom id
        """
        key = str(atom)
        return self.atom_map[key]

    def convert_sents_to_ids(self, sentences):
        r"""
        Encode one sentence
        :param sentences: list of the atoms of a sentence
        :return: tensor holding the id of every atom
        """
        ids = []
        for atom in sentences:
            ids.append(self.convert_atoms_to_ids(atom))
        return torch.as_tensor(ids)

    def convert_batchs_to_ids(self, batchs_sentences):
        r"""
        Encode a batch of sentences and pad them to max_atoms_in_sentence
        :param batchs_sentences: batch of sentences, each a list of atoms
        :return: tensor of the padded atom ids (padding uses the '[PAD]' id)
        """
        encoded = [self.convert_sents_to_ids(sents) for sents in batchs_sentences]
        padded = pad_sequence(encoded, max_len=self.max_atoms_in_sentence,
                              padding_value=self.pad_token_id)
        return torch.as_tensor(padded)

    def convert_ids_to_atoms(self, ids):
        r"""
        Decode ids back to atoms
        :param ids: iterable of atom ids
        :return: list of atom strings
        """
        atoms = []
        for i in ids:
            atoms.append(self.inverse_atom_map[int(i)])
        return atoms
This diff is collapsed.
import torch
from torch import nn
import math
class PositionalEncoding(nn.Module):
    r"""
    Sinusoidal positional encoding added to batch-first embeddings, followed by dropout.

    Even feature columns hold sin(position * div_term), odd columns hold the cosine.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        r"""
        Args:
            d_model : size of the embedding dimension (even or odd)
            dropout : dropout probability applied after adding the encoding
            max_len : number of positions precomputed in the `pe` buffer
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine table is trimmed to the number of odd columns.
        pe[0, :, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        Returns:
            x plus the encoding of its first seq_len positions, passed through dropout
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
from torch import logsumexp
def norm(x, dim):
    r"""
    Normalise log-values along `dim`: subtract the log-sum-exp taken over that dimension,
    so that the exponentials along `dim` sum to one.
    """
    log_total = logsumexp(x, dim=dim, keepdim=True)
    return x - log_total
def sinkhorn_step(x):
    r"""
    One Sinkhorn iteration in log-space: normalise over dimension 1, then over dimension 2.
    """
    normalised_dim1 = norm(x, dim=1)
    return norm(normalised_dim1, dim=2)
def sinkhorn_fn_no_exp(x, tau=1, iters=3):
    r"""
    Run `iters` Sinkhorn steps on x / tau and return the result without exponentiating it.

    :param x: input scores
    :param tau: temperature the scores are divided by
    :param iters: number of Sinkhorn steps
    """
    scores = x / tau
    for _step in range(iters):
        scores = sinkhorn_step(scores)
    return scores
from .Linker import Linker
from .atom_map import atom_map
from .AtomTokenizer import AtomTokenizer
from .PositionalEncoding import PositionalEncoding
from .Sinkhorn import *
\ No newline at end of file
# Atom vocabulary: the id of each atom is its position in this tuple.
# '[SEP]' and '[PAD]' are the separator and padding tokens.
atom_map = {
    atom: index
    for index, atom in enumerate((
        'cl_r',
        'pp',
        'n',
        's_ppres',
        's_whq',
        's_q',
        'np',
        's_inf',
        's_pass',
        'pp_a',
        'pp_par',
        'pp_de',
        'cl_y',
        'txt',
        's',
        's_ppart',
        '[SEP]',
        '[PAD]',
    ))
}
# Reduced atom vocabulary (base atom types only): the id of each type is its position here.
atom_map_redux = {
    atom: index
    for index, atom in enumerate(('cl_r', 'pp', 'n', 'np', 'cl_y', 'txt', 's'))
}
import torch
from torch.nn import Module
from torch.nn.functional import nll_loss
from Linker.atom_map import atom_map, atom_map_redux
class SinkhornLoss(Module):
    r"""
    Loss for the linker: sum over atom types of the mean negative log-likelihood of the
    true links, ignoring padded targets (-1).
    """

    def __init__(self):
        super(SinkhornLoss, self).__init__()

    def forward(self, predictions, truths):
        r"""
        Args:
            predictions : iterable with one tensor of log-probabilities per atom type; its
                          first two dimensions are flattened before the loss is computed
            truths : tensor whose dimensions 0 and 1 are swapped so that it is iterated per
                     atom type, in step with `predictions`; -1 marks padding
        Returns:
            the summed loss (int 0 when there is nothing to iterate over)
        """
        total = 0
        for link, perm in zip(predictions, truths.permute(1, 0, 2)):
            target = perm.flatten()
            if bool((target != -1).any()):
                total = total + nll_loss(link.flatten(0, 1), target,
                                         reduction='mean', ignore_index=-1)
            else:
                # A mean over zero non-ignored targets is NaN and would poison the whole
                # sum; an atom type holding only padding contributes nothing instead.
                total = total + link.new_zeros(())
        return total
def measure_accuracy(batch_true_links, axiom_links_pred):
    r"""
    Proportion of correctly predicted links among the non-padded true links.

    Args:
        batch_true_links : true links; dimensions 0 and 1 are swapped before the comparison,
                           and -1 marks padding
        axiom_links_pred : predicted links, laid out like `batch_true_links` after that swap
    Returns:
        float in [0, 1]; 0.0 when every true link is padding
    """
    padding = -1
    batch_true_links = batch_true_links.permute(1, 0, 2)

    # Boolean masks stay on the device of the inputs (no CPU scratch tensor is needed).
    is_link = batch_true_links != padding
    num_links = int(is_link.sum().item())
    if num_links == 0:
        # Nothing to evaluate: avoid dividing by zero.
        return 0.0

    num_correct_links = int(((axiom_links_pred == batch_true_links) & is_link).sum().item())
    # divide by the number of (non-padded) links
    return num_correct_links / num_links
import re
import pandas as pd
import regex
import torch
from torch.nn import Sequential, Linear, Dropout, GELU
from torch.nn import Module
from Linker.atom_map import atom_map, atom_map_redux
from utils import pad_sequence
class FFN(Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear, both linear layers without bias."""

    def __init__(self, d_model, d_ff, dropout=0.1, d_out=None):
        """
        Args:
            d_model : input size
            d_ff : hidden size
            dropout : dropout probability applied after the activation
            d_out : output size; defaults to d_model when None
        """
        super(FFN, self).__init__()
        out_features = d_model if d_out is None else d_out
        layers = [
            Linear(d_model, d_ff, bias=False),
            GELU(),
            Dropout(dropout),
            Linear(d_ff, out_features, bias=False),
        ]
        self.ffn = Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to x."""
        return self.ffn(x)
################################ Regex ########################################
# Recursive pattern for a category such as 'dr(0,np,n)': a name, '(', an index, then
# sub-categories that are either a nested category (the `(?R)` recursion, which needs the
# third-party `regex` module rather than `re`) or a plain word. The capture groups hold the
# sub-categories. This variant is applied to categories whose atoms carry a link index ('np_1').
regex_categories_axiom_links = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
# Same pattern text, applied to categories without link indexes.
regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
# region get true axiom links
def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
    r"""
    Build, for every atom type of atom_map_redux, the true links from positive to negative atoms.

    Args:
        max_atoms_in_one_type : configuration; each row is padded to max_atoms_in_one_type // 2
        atoms_polarity : (batch_size, max_atoms_in_sentence), truthy for positive atoms;
                         indexed as atoms_polarity[s_idx, i]
        batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
    Returns:
        batch_true_links : one padded block per atom type, stacked along dimension 0. For each
        positive atom it holds the index (among the negative atoms of the same type and sentence)
        of the atom carrying the same label, or -1 when there is none / for padding.
        NOTE(review): presumably (atom_vocab_size, batch_size, max_atoms_in_one_type // 2);
        this depends on what pad_sequence returns - confirm.
    """
    atoms_batch = get_atoms_links_batch(batch_axiom_links)
    linking_plus_to_minus_all_types = []
    for atom_type in atom_map_redux:
        # Keep only the atoms of this type (optionally with a sub-type, e.g. 's_inf_3' for 's').
        # The regex part is a raw string and is compiled once per type.
        type_pattern = re.compile(re.escape(atom_type) + r"(_{1}\w+)?_\d+\Z")

        links_per_sentence = []
        for s_idx, atoms in enumerate(atoms_batch):
            # Split the atoms of this type according to their polarity.
            l_polarity_plus, l_polarity_minus = [], []
            for i, atom in enumerate(atoms):
                if type_pattern.match(atom):
                    if atoms_polarity[s_idx, i]:
                        l_polarity_plus.append(atom)
                    else:
                        l_polarity_minus.append(atom)

            # First position of every negative atom (same result as list.index).
            minus_position = {}
            for position, atom in enumerate(l_polarity_minus):
                minus_position.setdefault(atom, position)

            links_per_sentence.append(torch.as_tensor(
                [minus_position.get(atom, -1) for atom in l_polarity_plus], dtype=torch.long))

        linking_plus_to_minus_all_types.append(
            pad_sequence(links_per_sentence, max_len=max_atoms_in_one_type // 2, padding_value=-1))
    return torch.stack(linking_plus_to_minus_all_types)
def category_to_atoms_axiom_links(category, categories_to_atoms):
    r"""
    Flatten a category whose atoms carry a link index (e.g. 'dr(0,np_1,n_2)') into its atoms.

    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl), possibly prefixed with 'GOAL:'
        categories_to_atoms : recursive list; extended in place for a complex category
    Returns :
        List of atoms inside the category in prefix order
    """
    if category.startswith("GOAL:"):
        _word, cat = category.split(':')
        return category_to_atoms_axiom_links(cat, categories_to_atoms)

    # An atom is '<atom type>_<index>'. Atom names are escaped so that '[SEP]' / '[PAD]'
    # are matched literally instead of acting as regex character classes.
    if any(re.match(re.escape(atom_type) + r"_\d+", category) for atom_type in atom_map):
        return [category]

    category_cut = regex.match(regex_categories_axiom_links, category).groups()
    for cat in category_cut:
        if cat is not None:
            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
    return categories_to_atoms
def get_atoms_links_batch(category_batch):
    r"""
    Flatten every sentence of indexed categories into its list of atoms.

    Args:
        category_batch : (batch_size, len_sentence) categories, possibly containing "let"
                         and a 'GOAL:' category
    Returns :
        list (one entry per sentence) of atoms in prefix order; "[SEP]" follows the atoms of
        each word, "let" is skipped and the atoms of a 'GOAL:' category are put in front
    """
    batch = []
    for sentence in category_batch:
        sentence_atoms = []
        for category in sentence:
            if category.startswith("GOAL:"):
                # The goal goes first and gets no separator.
                sentence_atoms = category_to_atoms_axiom_links(category, []) + sentence_atoms
            elif category != "let":
                sentence_atoms += category_to_atoms_axiom_links(category, [])
                sentence_atoms.append("[SEP]")
        batch.append(sentence_atoms)
    return batch
# Smoke test executed at import time: prints the true links computed by get_axiom_links for
# one hand-written sentence (30 polarity flags, max_atoms_in_one_type = 20).
print("test to create links ",
get_axiom_links(20, torch.stack([torch.as_tensor(
[True, False, True, False, False, False, True, False, True, False,
False, True, False, False, False, True, False, False, True, False,
True, False, False, True, False, False, False, False, False, False])]),
[['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)',
'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
# endregion
# region get atoms in sentence
def category_to_atoms(category, categories_to_atoms):
    r"""
    Flatten a category (without link indexes) into its atoms.

    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl), possibly prefixed with 'GOAL:'
        categories_to_atoms : recursive list; extended in place for a complex category
    Returns:
        List of atoms inside the category in prefix order
    """
    if category.startswith("GOAL:"):
        _word, goal_category = category.split(':')
        return category_to_atoms(goal_category, categories_to_atoms)

    # Atomic category: it is one of the keys of atom_map.
    if category in atom_map:
        return [category]

    sub_categories = regex.match(regex_categories, category).groups()
    for sub_category in sub_categories:
        if sub_category is not None:
            categories_to_atoms += category_to_atoms(sub_category, [])
    return categories_to_atoms
def get_atoms_batch(category_batch):
    r"""
    Flatten every sentence of categories into its list of atoms.

    Args:
        category_batch : (batch_size, len_sentence) categories, possibly containing "let"
    Returns:
        list (one entry per sentence) of atoms in prefix order, with "[SEP]" after the atoms
        of each word; "let" is skipped
    """
    batch = []
    for sentence in category_batch:
        sentence_atoms = []
        for category in sentence:
            if category == "let":
                continue
            sentence_atoms.extend(category_to_atoms(category, []))
            sentence_atoms.append("[SEP]")
        batch.append(sentence_atoms)
    return batch
# Smoke test executed at import time: prints the atoms extracted from one small sentence.
print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']",
get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))
# endregion
# region calculate num atoms per category
def category_to_num_atoms(category, categories_to_atoms):
    r"""
    Count the atoms of a category.

    Args:
        category : str of kind AtomCat | CategoryCat(dr or dl), "let", or 'GOAL:'-prefixed
        categories_to_atoms : recursive int; running count added to the result for a
                              complex category
    Returns:
        Number of atoms inside the category (0 for "let", 1 for an atomic category)
    """
    if category.startswith("GOAL:"):
        _word, goal_category = category.split(':')
        return category_to_num_atoms(goal_category, 0)
    if category == "let":
        return 0
    if category in atom_map:
        return 1

    sub_categories = regex.match(regex_categories, category).groups()
    for sub_category in sub_categories:
        if sub_category is None:
            continue
        categories_to_atoms = categories_to_atoms + category_to_num_atoms(sub_category, 0)
    return categories_to_atoms
def get_num_atoms_batch(category_batch, max_len_sentence):
    r"""
    Count the atoms of every word of every sentence.

    Args:
        category_batch : (batch_size, len_sentence) categories, possibly containing "let"
        max_len_sentence : max_len_sentence parameter, passed to pad_sequence as max_len
    Returns:
        the per-sentence counts padded with 0 by pad_sequence. Each sentence starts with an
        extra 0 entry, followed by one count per word (atoms + 1 for the trailing SEP,
        and 0 for "let")
    """
    batch = []
    for sentence in category_batch:
        counts = [0]
        for category in sentence:
            word_count = category_to_num_atoms(category, 0)
            # every word except "let" is followed by a SEP, which counts as one more atom
            counts.append(word_count if category == "let" else word_count + 1)
        batch.append(torch.as_tensor(counts))
    return pad_sequence(batch, max_len=max_len_sentence, padding_value=0)
# Smoke test executed at import time: prints the per-word atom counts of one small sentence,
# padded to length 10.
print(" test for get number of atoms in categories on ['dr(0,s,np)', 'let']",
get_num_atoms_batch([["dr(0,s,np)", "let"]], 10))
# endregion
# region get polarity
def category_to_atoms_polarity(category, polarity):
    r"""
    Compute the polarity of every atom of a category, in prefix order.

    Args:
        category : str of kind AtomCat | CategoryCat(dr, dl, p, box or dia), possibly
                   prefixed with 'GOAL:'
        polarity : polarity according to recursivity
    Returns:
        List of booleans, one per atom: True for positive atoms, False for negative ones.
        An atomic category yields [not polarity]; an atomic goal yields [True]; a category
        with an unknown connective yields [].
    """
    # final word (the goal of the sentence)
    if category.startswith("GOAL:"):
        _word, cat = category.split(':')
        # Atom names are escaped so that '[SEP]' / '[PAD]' are matched literally instead
        # of acting as regex character classes.
        if any(re.match(re.escape(atom_type), cat) for atom_type in atom_map):
            return [True]
        return list(category_to_atoms_polarity(cat, True))

    # the word has an atomic category
    if category in atom_map:
        return [not polarity]

    # otherwise it is a complex formula: choose the polarity handed to each sub-formula
    if category.startswith("dr"):
        # dr = / : the left side keeps the polarity, the right side flips it
        side_polarities = (polarity, not polarity)
    elif category.startswith(("dl", "p")):
        # dl = \ (and p) : the left side flips the polarity, the right side keeps it
        side_polarities = (not polarity, polarity)
    elif category.startswith(("box", "dia")):
        # unary connectives: the single sub-formula keeps the polarity
        side_polarities = (polarity,)
    else:
        return []

    category_cut = [cat for cat in regex.match(regex_categories, category).groups()
                    if cat is not None]
    category_to_polarity = []
    for side, side_polarity in enumerate(side_polarities):
        category_to_polarity += category_to_atoms_polarity(category_cut[side], side_polarity)
    return category_to_polarity
def find_pos_neg_idexes(atoms_batch):
    r"""
    Compute the polarities of the atoms of every sentence.

    Args:
        atoms_batch : (batch_size, len_sentence) categories, possibly containing "let"
    Returns:
        list (one entry per sentence) of the atoms' polarities in prefix order; a False is
        added after the atoms of each word and "let" is skipped
    """
    list_batch = []
    for sentence in atoms_batch:
        sentence_polarities = []
        for category in sentence:
            if category != "let":
                sentence_polarities.extend(category_to_atoms_polarity(category, True))
                # one extra entry per word, lining up with the "[SEP]" that get_atoms_batch adds
                sentence_polarities.append(False)
        list_batch.append(sentence_polarities)
    return list_batch
# Smoke test executed at import time: prints the polarities computed for one sentence.
print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np'] \n",
find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
# endregion
# region get atoms and polarities with GOAL
def get_GOAL(max_len_sentence, df_axiom_links):
    r"""
    Compute atoms, polarities and atom counts of every sentence, with the goal put in front.

    Args:
        max_len_sentence : length the atom counts are padded to
        df_axiom_links : table with a column "Z" (categories of each sentence) and a column
                         "Y" (indexed categories whose last element is the goal)
    Returns:
        (atoms_batch, polarities, num_atoms_batch), where for each sentence the goal atoms and
        the goal polarities are prepended and the first count is increased by the number of
        goal atoms
    """
    categories_batch = df_axiom_links["Z"]
    categories_with_goal = df_axiom_links["Y"]

    polarities = find_pos_neg_idexes(categories_batch)
    atoms_batch = get_atoms_batch(categories_batch)
    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)

    for s_idx, sentence_atoms in enumerate(atoms_batch):
        goal_category = categories_with_goal[s_idx][-1]
        goal_polarities = category_to_atoms_polarity(goal_category, True)
        # strip the link index: the first '<name>_<digits>' found gives the goal atom
        goal_name = re.search(r"(\w+)_\d+", goal_category).group(1)
        goal_atoms = category_to_atoms(goal_name, [])

        # the goal is prepended without any separator
        atoms_batch[s_idx] = goal_atoms + sentence_atoms
        polarities[s_idx] = goal_polarities + polarities[s_idx]
        num_atoms_batch[s_idx][0] += len(goal_atoms)

    return atoms_batch, polarities, num_atoms_batch
# Smoke test executed at import time: one-sentence table with the plain categories ("Z") and
# the indexed categories ending with the goal ("Y"), then the result of get_GOAL is printed.
df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']],
"Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6',
'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9',
'GOAL:np_7']]})
print(" test for get GOAL ", get_GOAL(10, df_axiom_links))
# endregion
# region get idx for pos and neg
def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
    r"""
    Positions of the positive atoms of every atom type of atom_map_redux.

    Args:
        atoms_batch : per sentence, the list of its atoms
        atoms_polarity_batch : per sentence, the polarity of each atom (truthy = positive)
        max_atoms_in_one_type : configuration; each row is padded to max_atoms_in_one_type // 2
    Returns:
        for each sentence and atom type, the indices (in the sentence) of the positive atoms of
        that type, padded with -1; the per-type blocks are stacked and the first two
        dimensions swapped
    """
    pos_idx = []
    for atom_type in atom_map_redux:
        # atom of this type, optionally with a sub-type (e.g. 's_inf' for 's'); the regex part
        # is a raw string and is compiled once per type
        type_pattern = re.compile(re.escape(atom_type) + r"(_{1}\w+)?\Z")
        idx_per_sentence = [
            # explicit integer dtype: an empty index list would otherwise be a float tensor
            torch.as_tensor([i for i, atom in enumerate(sentence)
                             if type_pattern.match(atom) and atoms_polarity_batch[s_idx][i]],
                            dtype=torch.long)
            for s_idx, sentence in enumerate(atoms_batch)
        ]
        pos_idx.append(pad_sequence(idx_per_sentence, max_len=max_atoms_in_one_type // 2,
                                    padding_value=-1))
    return torch.stack(pos_idx).permute(1, 0, 2)
def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
    r"""
    Positions of the negative atoms of every atom type of atom_map_redux.

    Args:
        atoms_batch : per sentence, the list of its atoms
        atoms_polarity_batch : per sentence, the polarity of each atom (falsy = negative)
        max_atoms_in_one_type : configuration; each row is padded to max_atoms_in_one_type // 2
    Returns:
        for each sentence and atom type, the indices (in the sentence) of the negative atoms of
        that type, padded with -1; the per-type blocks are stacked and the first two
        dimensions swapped
    """
    neg_idx = []
    for atom_type in atom_map_redux:
        # atom of this type, optionally with a sub-type (e.g. 's_inf' for 's'); the regex part
        # is a raw string and is compiled once per type
        type_pattern = re.compile(re.escape(atom_type) + r"(_{1}\w+)?\Z")
        idx_per_sentence = [
            # explicit integer dtype: an empty index list would otherwise be a float tensor
            torch.as_tensor([i for i, atom in enumerate(sentence)
                             if type_pattern.match(atom) and not atoms_polarity_batch[s_idx][i]],
                            dtype=torch.long)
            for s_idx, sentence in enumerate(atoms_batch)
        ]
        neg_idx.append(pad_sequence(idx_per_sentence, max_len=max_atoms_in_one_type // 2,
                                    padding_value=-1))
    return torch.stack(neg_idx).permute(1, 0, 2)
# Smoke test executed at import time: prints the negative-atom indices of one sentence.
# NOTE(review): the label mentions ['dr(0,s,np)', 's'] but the call passes 8 atoms together
# with 12 polarity flags - confirm which input was intended.
print(" test for cut into pos neg on ['dr(0,s,np)', 's']",
get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
torch.as_tensor(
[[True, True, False, False,
True, False, False, False,
False, False,
False, False]]), 10))
# endregion
\ No newline at end of file
...@@ -17,11 +17,18 @@ Clone the project locally. ...@@ -17,11 +17,18 @@ Clone the project locally.
### Libraries installation ### Libraries installation
Run the init.sh script and install the Tagger project under SuperTagger name and the Linker directory in Linker project under Linker name. Run the following script :
Upload the tagger.pt in models. (You may need to modify 'model_tagger' in train.py.) ```bash
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
You can upload a linker model, so there is no pretraining, you just need to give it to the Proof net initialization. mkdir Output
mkdir TensorBoard
```
Optional: upload tagger.pt and linker.pt to the models directory. (You may need to modify 'model_tagger' in train.py.)
### Structure ### Structure
......
...@@ -5,7 +5,6 @@ packaging==21.3 ...@@ -5,7 +5,6 @@ packaging==21.3
scikit-learn==1.0.2 scikit-learn==1.0.2
scipy==1.8.0 scipy==1.8.0
sentencepiece==0.1.96 sentencepiece==0.1.96
tensorflow==2.9.1
tensorboard==2.8.0 tensorboard==2.8.0
torch==1.11.0 torch==1.11.0
tqdm==4.64.0 tqdm==4.64.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment