PNRIA / Global Helper / DeepGrail Linker / Commits

Commit b54804c0, authored 3 years ago by Caroline DE POURTALES

    starting train

Parent: 8dc363bd
No related branches or tags found.
2 merge requests: !6 "Linker with transformer", !5 "Linker with transformer"

Showing 4 changed files:
  SuperTagger/Linker/Linker.py    31 additions, 6 deletions
  SuperTagger/Linker/utils.py     7 additions, 0 deletions
  bash_GPU.sh                     0 additions, 13 deletions
  weighting.py                    0 additions, 51 deletions

38 additions and 70 deletions in total.
SuperTagger/Linker/Linker.py (+31, −6) @ b54804c0
@@ -9,7 +9,7 @@ from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
 from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
 from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, mesure_accuracy
 from SuperTagger.Linker.AttentionLayer import FFN, AttentionLayer
 from SuperTagger.utils import pad_sequence
@@ -98,11 +98,36 @@ class Linker(Module):
         return link_weights

-    def predict_axiom_links(self):
+    def predict_axiom_links(self, b_sents_tokenized, b_sents_mask):
         return None

-    def eval_batch(self):
-        return None
+    def eval_batch(self, batch, cross_entropy_loss):
+        b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+        b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+        b_category = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+
+        logits_axiom_links_pred = self.predict_axiom_links(b_sents_tokenized, b_sents_mask)
+
+        # Softmax and argmax
+        axiom_links_pred = torch.argmax(torch.nn.functional.softmax(logits_axiom_links_pred, dim=2), dim=2)
+
+        accuracy = mesure_accuracy(b_category, axiom_links_pred)
+        loss = float(cross_entropy_loss(axiom_links_pred, b_category))
+
+        return accuracy, loss

-    def eval_epoch(self):
-        return None
+    def eval_epoch(self, dataloader, cross_entropy_loss):
+        r"""
+        Average the evaluation of all the batch.
+
+        Args:
+            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
+        """
+        accuracy_average = 0
+        loss_average = 0
+        compt = 0
+        for step, batch in enumerate(dataloader):
+            compt += 1
+            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            accuracy_average += accuracy
+            loss_average += loss
+
+        return accuracy_average / compt, loss_average / compt
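For context, here is a minimal sketch of how the new evaluation entry point would be driven. The wiring below is an assumption, not part of this commit: a no-argument Linker constructor, dummy tensor shapes, and a DataLoader yielding (tokenized sentences, attention masks, gold axiom links) in the order eval_batch unpacks them.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from SuperTagger.Linker.Linker import Linker

    # Dummy data; shapes are illustrative only.
    sents = torch.randint(0, 100, (32, 20))       # batch[0]: tokenized sentences
    masks = torch.ones(32, 20, dtype=torch.long)  # batch[1]: attention masks
    links = torch.randint(0, 10, (32, 20))        # batch[2]: gold axiom links

    dataloader = DataLoader(TensorDataset(sents, masks, links), batch_size=8)

    linker = Linker()  # assumed constructor; see Linker.py for the real signature
    accuracy, loss = linker.eval_epoch(dataloader, torch.nn.CrossEntropyLoss())

One caveat on the committed wiring: torch.nn.CrossEntropyLoss expects raw logits as input, whereas eval_batch passes it the argmax indices, so the loss term would need the logits instead before this runs end to end.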
SuperTagger/Linker/utils.py (+7, −0) @ b54804c0
@@ -92,3 +92,10 @@ def find_pos_neg_idexes(batch_symbols):
         list_symbols.append(cut_category_in_symbols(category))
     list_batch.append(list_symbols)
     return list_batch
+
+
+def mesure_accuracy(b_category, axiom_links_pred):
+    # Convert b_category into
+    return 0
\ No newline at end of file
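The mesure_accuracy added above is a stub that always returns 0. Given how eval_batch calls it, a plausible completion is an elementwise comparison of predicted and gold links that skips padding; the sketch below rests on assumptions (padding index 0 and matching tensor shapes are guesses, not repository code).

    import torch

    def mesure_accuracy(b_category, axiom_links_pred, padding_idx=0):
        # Fraction of non-padding positions where the predicted link matches the gold one.
        mask = b_category != padding_idx
        correct = ((axiom_links_pred == b_category) & mask).sum().item()
        total = mask.sum().item()
        return correct / total if total > 0 else 0.0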
bash_GPU.sh (deleted, 100644 → 0; +0, −13) @ 8dc363bd
-#!/bin/sh
-#SBATCH --job-name=N-tensorboard
-#SBATCH --partition=RTX6000Node
-#SBATCH --gres=gpu:1
-#SBATCH --mem=32000
-#SBATCH --gres-flags=enforce-binding
-#SBATCH --error="error_rtx1.err"
-#SBATCH --output="out_rtx1.out"
-
-module purge
-module load singularity/3.0.3
-
-srun singularity exec /logiciels/containerCollections/CUDA11/pytorch-NGC-21-03-py3.sif python "train.py"
\ No newline at end of file
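For reference, a script like this would have been launched with sbatch bash_GPU.sh: the #SBATCH directives request the job name, partition, one GPU, and 32 GB of memory, and the Singularity container supplies the CUDA 11 / PyTorch environment in which train.py runs.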
weighting.py (deleted, 100644 → 0; +0, −51) @ 8dc363bd
-from Configuration import Configuration
-from SuperTagger.Symbol.SymbolTokenizer import SymbolTokenizer
-from SuperTagger.utils import read_csv_pgbar
-from SuperTagger.Symbol.symbol_map import symbol_map
-
-from collections import Counter
-import numpy as np
-import statistics
-
-file_path = 'Datasets/m2_dataset.csv'
-
-max_symbols_in_sentence = int(Configuration.modelDecoderConfig['max_symbols_in_sentence'])
-max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])
-
-df = read_csv_pgbar(file_path)
-
-all_symbols = [item for sublist in list(df['sub_tree']) for item in sublist]
-
-counter = Counter(all_symbols)
-print(counter)
-
-number_total_symbols = len(all_symbols)
-print(number_total_symbols)
-
-most_common_symbol, max_number_in_one_symbol = counter.most_common(1)[0]
-print(most_common_symbol)
-print(max_number_in_one_symbol)
-
-middle_common_symbol, middle_number_in_one_symbol = counter.most_common(6)[5]
-print(middle_common_symbol)
-print(middle_number_in_one_symbol)
-
-mean = statistics.mean(counter.values())
-print(mean)
-
-
-def get_weight(count_symbol_x, count_symbol_threashold):
-    x = count_symbol_threashold / count_symbol_x
-    return 1 + np.log(x + 1) ** 2
-
-
-dic_symbols_weights = {}
-for (key, value) in counter.items():
-    dic_symbols_weights[key] = np.round(get_weight(value, mean), 4)
-print(dic_symbols_weights)
-
-list_ordered = []
-for symbol in symbol_map.keys():
-    if symbol != '[START]' and symbol != '[PAD]':
-        list_ordered.append(dic_symbols_weights[symbol])
-print(list_ordered)
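The deleted get_weight(count, mean) assigns each symbol the weight 1 + log(mean/count + 1)^2, i.e. one plus the squared natural log: a symbol occurring exactly at the mean count gets 1 + ln(2)^2, roughly 1.48, and rarer symbols get progressively more. A list like list_ordered is typically consumed as the weight argument of a weighted cross-entropy loss; a minimal sketch with illustrative values follows (the repository does not show this use, so treat it as an assumption).

    import torch

    # One weight per non-special symbol, in symbol_map order (values illustrative).
    list_ordered = [1.48, 2.93, 1.02, 5.61]

    loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor(list_ordered))

    logits = torch.randn(8, len(list_ordered))           # (batch, num_symbols)
    targets = torch.randint(0, len(list_ordered), (8,))  # gold symbol indices
    loss = loss_fn(logits, targets)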