Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
DeepGrail Linker
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
PNRIA
Global Helper
DeepGrail Linker
Commits
8b0f5bb5
Commit
8b0f5bb5
authored
3 years ago
by
Caroline DE POURTALES
Browse files
Options
Downloads
Patches
Plain Diff
adding comments
parent
f996b207
No related branches found
No related tags found
2 merge requests
!6
Linker with transformer
,
!5
Linker with transformer
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
SuperTagger/Linker/Linker.py
+6
-6
6 additions, 6 deletions
SuperTagger/Linker/Linker.py
SuperTagger/Linker/utils.py
+39
-8
39 additions, 8 deletions
SuperTagger/Linker/utils.py
SuperTagger/eval.py
+2
-2
2 additions, 2 deletions
SuperTagger/eval.py
with
47 additions
and
16 deletions
SuperTagger/Linker/Linker.py
+
6
−
6
View file @
8b0f5bb5
...
@@ -62,13 +62,13 @@ class Linker(Module):
...
@@ -62,13 +62,13 @@ class Linker(Module):
)
)
def
forward
(
self
,
atoms_batch_tokenized
,
atoms_polarity_batch
,
sents_embedding
):
def
forward
(
self
,
atoms_batch_tokenized
,
atoms_polarity_batch
,
sents_embedding
):
'''
r
'''
Parameters :
Parameters :
c
at
egory
_batch
: batch of size (batch_size, sequenc
e_
l
en
gth) = output of decoder
at
oms
_batch
_tokenized : (batch_size, max_atoms_in_on
e_
s
en
tence) flattened categories
sents_embedding
atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
sents_
mask
sents_
embedding : output of BERT for context
Ret
t
urns :
Returns :
link_weights :
batch-size, atom_vocab_size, ...
)
link_weights :
atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat
)
'''
'''
# atoms embedding
# atoms embedding
...
...
This diff is collapsed.
Click to expand it.
SuperTagger/Linker/utils.py
+
39
−
8
View file @
8b0f5bb5
...
@@ -16,6 +16,14 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
...
@@ -16,6 +16,14 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
def
get_axiom_links
(
max_atoms_in_one_type
,
atoms_polarity
,
batch_axiom_links
):
def
get_axiom_links
(
max_atoms_in_one_type
,
atoms_polarity
,
batch_axiom_links
):
r
'''
Parameters :
max_atoms_in_one_type : configuration
atoms_polarity : (batch_size, max_atoms_in_sentence)
batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
Returns :
batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
'''
atoms_batch
=
get_atoms_links_batch
(
batch_axiom_links
)
atoms_batch
=
get_atoms_links_batch
(
batch_axiom_links
)
linking_plus_to_minus_all_types
=
[]
linking_plus_to_minus_all_types
=
[]
for
atom_type
in
list
(
atom_map
.
keys
())[:
-
1
]:
for
atom_type
in
list
(
atom_map
.
keys
())[:
-
1
]:
...
@@ -37,6 +45,13 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
...
@@ -37,6 +45,13 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
def
category_to_atoms_axiom_links
(
category
,
categories_to_atoms
):
def
category_to_atoms_axiom_links
(
category
,
categories_to_atoms
):
r
'''
Parameters :
category
categories_to_atoms : recursive list
Returns :
List of atoms inside the category in prefix order
'''
res
=
[
bool
(
re
.
match
(
r
''
+
atom_type
+
"
_\d+
"
,
category
))
for
atom_type
in
atom_map
.
keys
()]
res
=
[
bool
(
re
.
match
(
r
''
+
atom_type
+
"
_\d+
"
,
category
))
for
atom_type
in
atom_map
.
keys
()]
if
category
.
startswith
(
"
GOAL:
"
):
if
category
.
startswith
(
"
GOAL:
"
):
word
,
cat
=
category
.
split
(
'
:
'
)
word
,
cat
=
category
.
split
(
'
:
'
)
...
@@ -52,6 +67,11 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
...
@@ -52,6 +67,11 @@ def category_to_atoms_axiom_links(category, categories_to_atoms):
def
get_atoms_links_batch
(
category_batch
):
def
get_atoms_links_batch
(
category_batch
):
r
"""
category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
Returns :
(batch_size, max_atoms_in_sentence) flattened categories in prefix order
"""
batch
=
[]
batch
=
[]
for
sentence
in
category_batch
:
for
sentence
in
category_batch
:
categories_to_atoms
=
[]
categories_to_atoms
=
[]
...
@@ -67,6 +87,13 @@ def get_atoms_links_batch(category_batch):
...
@@ -67,6 +87,13 @@ def get_atoms_links_batch(category_batch):
def
category_to_atoms
(
category
,
categories_to_atoms
):
def
category_to_atoms
(
category
,
categories_to_atoms
):
r
'''
Parameters :
category
categories_to_atoms : recursive list
Returns :
List of atoms inside the category in prefix order
'''
res
=
[
bool
(
re
.
match
(
r
''
+
atom_type
+
"
_\d+
"
,
category
))
for
atom_type
in
atom_map
.
keys
()]
res
=
[
bool
(
re
.
match
(
r
''
+
atom_type
+
"
_\d+
"
,
category
))
for
atom_type
in
atom_map
.
keys
()]
if
category
.
startswith
(
"
GOAL:
"
):
if
category
.
startswith
(
"
GOAL:
"
):
word
,
cat
=
category
.
split
(
'
:
'
)
word
,
cat
=
category
.
split
(
'
:
'
)
...
@@ -84,6 +111,11 @@ def category_to_atoms(category, categories_to_atoms):
...
@@ -84,6 +111,11 @@ def category_to_atoms(category, categories_to_atoms):
def
get_atoms_batch
(
category_batch
):
def
get_atoms_batch
(
category_batch
):
r
"""
category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
Returns :
(batch_size, max_atoms_in_sentence) flattened categories in prefix order
"""
batch
=
[]
batch
=
[]
for
sentence
in
category_batch
:
for
sentence
in
category_batch
:
categories_to_atoms
=
[]
categories_to_atoms
=
[]
...
@@ -98,9 +130,9 @@ def get_atoms_batch(category_batch):
...
@@ -98,9 +130,9 @@ def get_atoms_batch(category_batch):
#########################################################################################
#########################################################################################
def
category_to_atoms_polarity
(
category
,
polarity
):
def
category_to_atoms_polarity
(
category
,
polarity
):
'''
r
'''
Parameters :
Parameters :
category : str of kind AtomCat | CategoryCat
category : str of kind AtomCat | CategoryCat
(dr or dl)
Returns :
Returns :
Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
'''
'''
...
@@ -183,13 +215,12 @@ def category_to_atoms_polarity(category, polarity):
...
@@ -183,13 +215,12 @@ def category_to_atoms_polarity(category, polarity):
def
find_pos_neg_idexes
(
max_atoms_in_sentence
,
atoms_batch
):
def
find_pos_neg_idexes
(
max_atoms_in_sentence
,
atoms_batch
):
'''
r
"""
Parameters :
max_atoms_in_sentence : configuration
batch_symbols : (batch_size, sequence_length) the batch of symbols
atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
Returns :
Returns :
(batch_size, max_
symbol
s_in_sentence)
boolean tensor indiating pos and ne indexes
(batch_size, max_
atom
s_in_sentence)
flattened categories
'
polarities in prefix order
'''
"""
list_batch
=
[]
list_batch
=
[]
for
sentence
in
atoms_batch
:
for
sentence
in
atoms_batch
:
list_atoms
=
[]
list_atoms
=
[]
...
...
This diff is collapsed.
Click to expand it.
SuperTagger/eval.py
+
2
−
2
View file @
8b0f5bb5
...
@@ -20,8 +20,8 @@ class SinkhornLoss(Module):
...
@@ -20,8 +20,8 @@ class SinkhornLoss(Module):
def
mesure_accuracy
(
batch_true_links
,
axiom_links_pred
):
def
mesure_accuracy
(
batch_true_links
,
axiom_links_pred
):
r
"""
r
"""
batch_
axiom
_links : (batch_size,
...)
batch_
true
_links : (batch_size,
atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
axiom_links_pred : (batch_size,
max_atoms_type_polarity)
axiom_links_pred : (batch_size,
atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
"""
"""
correct_links
=
torch
.
ones
(
axiom_links_pred
.
size
())
correct_links
=
torch
.
ones
(
axiom_links_pred
.
size
())
correct_links
[
axiom_links_pred
!=
batch_true_links
]
=
0
correct_links
[
axiom_links_pred
!=
batch_true_links
]
=
0
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment