Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
DeepGrail Linker
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
PNRIA
Global Helper
DeepGrail Linker
Commits
a702fd51
Commit
a702fd51
authored
3 years ago
by
Caroline DE POURTALES
Browse files
Options
Downloads
Patches
Plain Diff
starting train
parent
c44879ab
Branches
Branches containing commit
Tags
Tags containing commit
2 merge requests
!6
Linker with transformer
,
!5
Linker with transformer
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
SuperTagger/Linker/utils.py
+16
-17
16 additions, 17 deletions
SuperTagger/Linker/utils.py
SuperTagger/eval.py
+5
-1
5 additions, 1 deletion
SuperTagger/eval.py
with
21 additions
and
18 deletions
SuperTagger/Linker/utils.py
+
16
−
17
View file @
a702fd51
...
@@ -4,15 +4,16 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
...
@@ -4,15 +4,16 @@ from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
from
SuperTagger.Linker.atom_map
import
atom_map
from
SuperTagger.Linker.atom_map
import
atom_map
def
get_atoms_from_category
(
category
,
category_to_atoms
):
def
category_to_atoms
(
category
,
category_to_atoms
):
if
category
in
atom_map
.
keys
():
res
=
[
i
for
i
in
atom_map
.
keys
()
if
category
in
i
]
if
len
(
res
)
>
0
:
return
[
category
]
return
[
category
]
else
:
else
:
category_cut
=
re
.
search
(
r
'
\w*\(\d+,(.+),(.+)\)
'
,
category
)
category_cut
=
re
.
search
(
r
'
\w*\(\d+,(.+),(.+)\)
'
,
category
)
left_side
,
right_side
=
category_cut
.
group
(
1
),
category_cut
.
group
(
2
)
left_side
,
right_side
=
category_cut
.
group
(
1
),
category_cut
.
group
(
2
)
category_to_atoms
+=
get_atoms_from_category
(
left_side
,
[])
category_to_atoms
+=
category_to_atoms
(
left_side
,
[])
category_to_atoms
+=
get_atoms_from_category
(
right_side
,
[])
category_to_atoms
+=
category_to_atoms
(
right_side
,
[])
return
category_to_atoms
return
category_to_atoms
...
@@ -22,12 +23,12 @@ def get_atoms_batch(category_batch):
...
@@ -22,12 +23,12 @@ def get_atoms_batch(category_batch):
for
sentence
in
category_batch
:
for
sentence
in
category_batch
:
category_to_atoms
=
[]
category_to_atoms
=
[]
for
category
in
sentence
:
for
category
in
sentence
:
category_to_atoms
=
get_atoms_from_category
(
category
,
category_to_atoms
)
category_to_atoms
=
category_to_atoms
(
category
,
category_to_atoms
)
batch
.
append
(
category_to_atoms
)
batch
.
append
(
category_to_atoms
)
return
batch
return
batch
def
cut_
category_
in_symbols
(
category
):
def
category_
to_atoms_polarity
(
category
):
'''
'''
Parameters :
Parameters :
category : str of kind AtomCat | CategoryCat
category : str of kind AtomCat | CategoryCat
...
@@ -49,13 +50,13 @@ def cut_category_in_symbols(category):
...
@@ -49,13 +50,13 @@ def cut_category_in_symbols(category):
if
left_side
in
atom_map
.
keys
():
if
left_side
in
atom_map
.
keys
():
category_to_polarity
.
append
(
False
)
category_to_polarity
.
append
(
False
)
else
:
else
:
category_to_polarity
+=
cut_
category_
in_symbols
(
left_side
)
category_to_polarity
+=
category_
to_atoms_polarity
(
left_side
)
# for the right side
# for the right side
if
right_side
in
atom_map
.
keys
():
if
right_side
in
atom_map
.
keys
():
category_to_polarity
.
append
(
True
)
category_to_polarity
.
append
(
True
)
else
:
else
:
category_to_polarity
+=
cut_
category_
in_symbols
(
right_side
)
category_to_polarity
+=
category_
to_atoms_polarity
(
right_side
)
# dl = \
# dl = \
elif
category
.
startswith
(
"
dl
"
):
elif
category
.
startswith
(
"
dl
"
):
...
@@ -66,18 +67,18 @@ def cut_category_in_symbols(category):
...
@@ -66,18 +67,18 @@ def cut_category_in_symbols(category):
if
left_side
in
atom_map
.
keys
():
if
left_side
in
atom_map
.
keys
():
category_to_polarity
.
append
(
True
)
category_to_polarity
.
append
(
True
)
else
:
else
:
category_to_polarity
+=
cut_
category_
in_symbols
(
left_side
)
category_to_polarity
+=
category_
to_atoms_polarity
(
left_side
)
# for the right side
# for the right side
if
right_side
in
atom_map
.
keys
():
if
right_side
in
atom_map
.
keys
():
category_to_polarity
.
append
(
False
)
category_to_polarity
.
append
(
False
)
else
:
else
:
category_to_polarity
+=
cut_
category_
in_symbols
(
right_side
)
category_to_polarity
+=
category_
to_atoms_polarity
(
right_side
)
return
category_to_polarity
return
category_to_polarity
def
find_pos_neg_idexes
(
batch_symbols
):
def
find_pos_neg_idexes
(
atoms_batch
):
'''
'''
Parameters :
Parameters :
batch_symbols : (batch_size, sequence_length) the batch of symbols
batch_symbols : (batch_size, sequence_length) the batch of symbols
...
@@ -86,11 +87,9 @@ def find_pos_neg_idexes(batch_symbols):
...
@@ -86,11 +87,9 @@ def find_pos_neg_idexes(batch_symbols):
(batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes
(batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes
'''
'''
list_batch
=
[]
list_batch
=
[]
for
sentence
in
batch_symbols
:
for
sentence
in
atoms_batch
:
list_
symbol
s
=
[]
list_
atom
s
=
[]
for
category
in
sentence
:
for
category
in
sentence
:
list_
symbol
s
.
append
(
cut_
category_
in_symbols
(
category
))
list_
atom
s
.
append
(
category_
to_atoms_polarity
(
category
))
list_batch
.
append
(
list_
symbol
s
)
list_batch
.
append
(
list_
atom
s
)
return
list_batch
return
list_batch
This diff is collapsed.
Click to expand it.
SuperTagger/eval.py
+
5
−
1
View file @
a702fd51
...
@@ -3,6 +3,8 @@ from torch import Tensor
...
@@ -3,6 +3,8 @@ from torch import Tensor
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.nn.functional
import
nll_loss
,
cross_entropy
from
torch.nn.functional
import
nll_loss
,
cross_entropy
from
SuperTagger.Linker.utils
import
get_atoms_batch
,
find_pos_neg_idexes
class
SinkhornLoss
(
Module
):
class
SinkhornLoss
(
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -19,8 +21,10 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
...
@@ -19,8 +21,10 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
axiom_links_pred : (batch_size, max_atoms_type_polarity)
axiom_links_pred : (batch_size, max_atoms_type_polarity)
"""
"""
# Convert batch_axiom_links into list of atoms (batch_size, max_atoms_in_sentence)
# Convert batch_axiom_links into list of atoms (batch_size, max_atoms_in_sentence)
atoms_batch
=
get_atoms_batch
(
batch_axiom_links
)
# then convert into atom_vocab_size lists of (batch_size, max atom in one cat) with prefix parcours of graphe
# then convert into atom_vocab_size lists of (batch_size, max atom in one cat) with prefix parcours of graphe
atoms_polarity
=
find_pos_neg_idexes
(
atoms_batch
)
axiom_links_true
=
""
axiom_links_true
=
""
...
@@ -30,4 +34,4 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
...
@@ -30,4 +34,4 @@ def mesure_accuracy(batch_axiom_links, axiom_links_pred):
correct_links
[
axiom_links_pred
!=
axiom_links_true
]
=
0
correct_links
[
axiom_links_pred
!=
axiom_links_true
]
=
0
num_correct_links
=
correct_links
.
sum
().
item
()
num_correct_links
=
correct_links
.
sum
().
item
()
return
num_correct_links
return
num_correct_links
/
(
axiom_links_pred
.
size
()[
0
]
*
axiom_links_pred
.
size
()[
1
])
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment