From 85e492e2dee7654dc35baee0872ef0ffe932fbd1 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Thu, 5 May 2022 13:53:04 +0200
Subject: [PATCH] starting train

---
 SuperTagger/Decoder/RNNDecoderLayer.py        | 180 --------------
 .../RNNDecoderLayer.cpython-38.pyc            | Bin 5324 -> 0 bytes
 SuperTagger/Encoder/EncoderInput.py           |  18 --
 SuperTagger/Encoder/EncoderLayer.py           |  67 -----
 .../__pycache__/EncoderInput.cpython-38.pyc   | Bin 1182 -> 0 bytes
 .../__pycache__/EncoderLayer.cpython-38.pyc   | Bin 2675 -> 0 bytes
 SuperTagger/Linker/AtomTokenizer.py           |  23 +-
 SuperTagger/Linker/Linker.py                  |  41 ++--
 SuperTagger/Linker/Sinkhorn.py                |   1 -
 .../__pycache__/Sinkhorn.cpython-38.pyc       | Bin 687 -> 687 bytes
 SuperTagger/Linker/utils.py                   |   1 -
 SuperTagger/Symbol/SymbolEmbedding.py         |  12 -
 SuperTagger/Symbol/SymbolTokenizer.py         |  53 ----
 .../SymbolEmbedding.cpython-38.pyc            | Bin 858 -> 0 bytes
 .../SymbolTokenizer.cpython-38.pyc            | Bin 2862 -> 0 bytes
 .../__pycache__/symbol_map.cpython-38.pyc     | Bin 563 -> 0 bytes
 SuperTagger/Symbol/symbol_map.py              |  28 ---
 SuperTagger/eval.py                           |  12 +-
 SuperTagger/utils.py                          |  20 ++
 test.py                                       |  59 +++--
 train.py                                      | 232 ++++--------------
 21 files changed, 150 insertions(+), 597 deletions(-)
 delete mode 100644 SuperTagger/Decoder/RNNDecoderLayer.py
 delete mode 100644 SuperTagger/Decoder/__pycache__/RNNDecoderLayer.cpython-38.pyc
 delete mode 100644 SuperTagger/Encoder/EncoderInput.py
 delete mode 100644 SuperTagger/Encoder/EncoderLayer.py
 delete mode 100644 SuperTagger/Encoder/__pycache__/EncoderInput.cpython-38.pyc
 delete mode 100644 SuperTagger/Encoder/__pycache__/EncoderLayer.cpython-38.pyc
 delete mode 100644 SuperTagger/Symbol/SymbolEmbedding.py
 delete mode 100644 SuperTagger/Symbol/SymbolTokenizer.py
 delete mode 100644 SuperTagger/Symbol/__pycache__/SymbolEmbedding.cpython-38.pyc
 delete mode 100644 SuperTagger/Symbol/__pycache__/SymbolTokenizer.cpython-38.pyc
 delete mode 100644 SuperTagger/Symbol/__pycache__/symbol_map.cpython-38.pyc
 delete mode 100644 SuperTagger/Symbol/symbol_map.py

diff --git a/SuperTagger/Decoder/RNNDecoderLayer.py b/SuperTagger/Decoder/RNNDecoderLayer.py
deleted file mode 100644
index 93e96a6..0000000
--- a/SuperTagger/Decoder/RNNDecoderLayer.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import random
-
-import torch
-import torch.nn.functional as F
-from torch.nn import (Module, Dropout, Linear, LSTM)
-
-from Configuration import Configuration
-from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding
-
-
-class RNNDecoderLayer(Module):
-    def __init__(self, symbols_map):
-        super(RNNDecoderLayer, self).__init__()
-
-        # init params
-        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
-        self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
-        dropout = float(Configuration.modelDecoderConfig['dropout'])
-        self.num_rnn_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
-        self.teacher_forcing = float(Configuration.modelDecoderConfig['teacher_forcing'])
-        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        self.symbols_vocab_size = int(Configuration.datasetConfig['symbols_vocab_size'])
-
-        self.bidirectional = False
-        self.use_attention = True
-        self.symbols_map = symbols_map
-        self.symbols_padding_id = self.symbols_map["[PAD]"]
-        self.symbols_sep_id = self.symbols_map["[SEP]"]
-        self.symbols_start_id = self.symbols_map["[START]"]
-        self.symbols_sos_id = self.symbols_map["[SOS]"]
-
-        # Different layers
-        # Symbols Embedding
-        self.symbols_embedder = SymbolEmbedding(self.dim_decoder, self.symbols_vocab_size,
-                                                padding_idx=self.symbols_padding_id)
-        # For hidden_state
-        self.dropout = Dropout(dropout)
-        # rnn Layer
-        if self.use_attention:
-            self.rnn = LSTM(input_size=self.dim_encoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers,
-                            dropout=dropout,
-                            bidirectional=self.bidirectional, batch_first=True)
-        else:
-            self.rnn = LSTM(input_size=self.dim_decoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers,
-                            dropout=dropout,
-                            bidirectional=self.bidirectional, batch_first=True)
-
-        # Projection on vocab_size
-        if self.bidirectional:
-            self.proj = Linear(self.dim_encoder * 2, self.symbols_vocab_size)
-        else:
-            self.proj = Linear(self.dim_encoder, self.symbols_vocab_size)
-
-        self.attn = Linear(self.dim_decoder + self.dim_encoder, self.max_len_sentence)
-        self.attn_combine = Linear(self.dim_decoder + self.dim_encoder, self.dim_encoder)
-
-    def sos_mask(self, y):
-        return torch.eq(y, self.symbols_sos_id)
-
-    def forward(self, symbols_tokenized_batch, last_hidden_state, pooler_output):
-        r"""Training the translation from encoded sentences to symbols
-
-        Args:
-            symbols_tokenized_batch: [batch_size, max_len_sentence] the true symbols for each sentence.
-            last_hidden_state: [batch_size, max_len_sentence, dim_encoder]  Sequence of hidden-states at the output of the last layer of the model.
-            pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task
-        """
-        batch_size, sequence_length, hidden_size = last_hidden_state.shape
-
-        # y_hat[batch_size, max_len_sentence, vocab_size] init with probability pad =1
-        y_hat = torch.zeros(batch_size, self.max_len_sentence, self.symbols_vocab_size,
-                            dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
-        y_hat[:, :, self.symbols_padding_id] = 1
-
-        decoded_i = torch.ones(batch_size, 1, dtype=torch.long,
-                               device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id
-
-        sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu")
-
-        # hidden_state goes through multiple linear layers
-        hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1)
-
-        c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size,
-                              dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
-
-        use_teacher_forcing = True if random.random() < self.teacher_forcing else False
-
-        # for each symbol
-        for i in range(self.max_len_sentence):
-            # teacher-forcing training : we pass the target value in the embedding, not a created vector
-            symbols_embedding = self.symbols_embedder(decoded_i)
-            symbols_embedding = self.dropout(symbols_embedding)
-
-            output = symbols_embedding
-            if self.use_attention:
-                attn_weights = F.softmax(
-                    self.attn(torch.cat((symbols_embedding, hidden_state[0].unsqueeze(1)), 2)), dim=2)
-                attn_applied = torch.bmm(attn_weights, last_hidden_state)
-
-                output = torch.cat((symbols_embedding, attn_applied), 2)
-                output = self.attn_combine(output)
-                output = F.relu(output)
-
-            # rnn layer
-            output, (hidden_state, c_state) = self.rnn(output, (hidden_state, c_state))
-
-            # Projection of the output of the rnn omitting the last probability (which is pad) so we dont predict PAD
-            proj = self.proj(output)[:, :, :-2]
-
-            if use_teacher_forcing:
-                decoded_i = symbols_tokenized_batch[:, i].unsqueeze(1)
-            else:
-                decoded_i = torch.argmax(F.softmax(proj, dim=2), dim=2)
-
-            # Calculate sos and pad
-            sos_mask_i = self.sos_mask(torch.argmax(F.softmax(proj, dim=2), dim=2)[:, -1])
-            y_hat[~sos_mask, i, self.symbols_padding_id] = 0
-            y_hat[~sos_mask, i, :-2] = proj[~sos_mask, -1, :]
-            sos_mask = sos_mask_i | sos_mask
-
-            # Stop if every sentence says padding or if we are full
-            if not torch.any(~sos_mask):
-                break
-
-        return y_hat
-
-    def predict_rnn(self, last_hidden_state, pooler_output):
-        r"""Predicts the symbols from the output of the encoder.
-
-        Args:
-            last_hidden_state: [batch_size, max_len_sentence, dim_encoder] the output of the encoder
-            pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task
-        """
-        batch_size, sequence_length, hidden_size = last_hidden_state.shape
-
-        # contains the predictions
-        y_hat = torch.zeros(batch_size, self.max_len_sentence, self.symbols_vocab_size,
-                            dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
-        y_hat[:, :, self.symbols_padding_id] = 1
-        # input of the embedder, a created vector that replace the true value
-        decoded_i = torch.ones(batch_size, 1, dtype=torch.long,
-                               device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id
-
-        sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu")
-
-        hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1)
-
-        c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size,
-                              dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu")
-
-        for i in range(self.max_len_sentence):
-            symbols_embedding = self.symbols_embedder(decoded_i)
-            symbols_embedding = self.dropout(symbols_embedding)
-
-            output = symbols_embedding
-            if self.use_attention:
-                attn_weights = F.softmax(
-                    self.attn(torch.cat((symbols_embedding, hidden_state[0].unsqueeze(1)), 2)), dim=2)
-                attn_applied = torch.bmm(attn_weights, last_hidden_state)
-
-                output = torch.cat((symbols_embedding, attn_applied), 2)
-                output = self.attn_combine(output)
-                output = F.relu(output)
-
-            output, (hidden_state, c_state) = self.rnn(output, (hidden_state, c_state))
-
-            proj_softmax = F.softmax(self.proj(output)[:, :, :-2], dim=2)
-            decoded_i = torch.argmax(proj_softmax, dim=2)
-
-            # Set sos and pad
-            sos_mask_i = self.sos_mask(decoded_i[:, -1])
-            y_hat[~sos_mask, i, self.symbols_padding_id] = 0
-            y_hat[~sos_mask, i, :-2] = proj_softmax[~sos_mask, -1, :]
-            sos_mask = sos_mask_i | sos_mask
-
-            # Stop if every sentence says padding or if we are full
-            if not torch.any(~sos_mask):
-                break
-
-        return y_hat
diff --git a/SuperTagger/Decoder/__pycache__/RNNDecoderLayer.cpython-38.pyc b/SuperTagger/Decoder/__pycache__/RNNDecoderLayer.cpython-38.pyc
deleted file mode 100644
index cd9f43c2b2c435f545d90c5f77a9e9e4ce5480f2..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 5324
zcmWIL<>g{vU|?AIAUi2bnStRkh=Yt-7#J8F7#J9e^B5QyQW#Pga~Pr^G-DJan9r2U
z%*4RRkjoOq$_SES%3;f8k7Cc|h~miQjN;7YisH)Uj^fVciQ)mvGw1N;@<#E3*(^DH
zx%^T5j12A!DXb}MEet7asT|GBQ35H9!3>)0FF`Kx(`38F?oyOrkY8GIi_JGbr8Fn?
z78iu}$t*6p#paWlmzr2~i^V56#P=4ndr+t*<1Jq2{JgZx^wOfllFa-(O~zaN!Iim5
z`8lq+NvSC*nR)5SAhVG%Gn7+o!oa|g$^i0R6jM7x8e<A$3R4S56mtr53QG$^6br<q
zQLHKKDI6^fQEcrDEDTZX!3>(5xA=qn{9IC#^HWlbd=e{Di;|h3hJq-N1sn_v3?RCA
z5+egc2}22E4MQ{20;Yuwj0`2rB`jI2DU6Z~&5Sh+@oZ2&6O_*m<ugP198f+Bl+Ov}
zvqJe?CEQ^1*g*VRCa5esh*!f9&jXd^fbw}$IBOW<`BJ!Q7~=VBK=upNFk}g&a7!}O
zFvJUj**svj5SYyiW($Mad|<XniD(T&7JCgt7FP{JmS75hFLMn;yjY2N4HMWu5;Y82
z{1X_9#7ZP<7_!))n5TxRh9QeDg)2p%m#u~&UMf!xEDtgVqy|GCq<R8lQBa9A$b=L@
z5eATr%}g~6@iKV=5K)8(m=88f9HfpDtWFk9lnpE@m%<Rtpef{ci#sJVH$F8F9ElK4
z3M58viRUI(#215;QE_}`UVL$CUP)?Ra_TJskVH-@l86vYSy_H^Vp4o@W>xAf_7rF$
z=F2P1jW5c}i_Zb2k>XqYC8>$Y8L36_Y57ITnR)4MAyur=0gf)QRjkp$t^u)C?9ss?
zjzJ+{zJGA6CgUycf<$mCj?YY~&}6&Cm6=yiS^}|#J0mkC1>_cxom?QRAl4KKFfcIO
z;!Vm-$t+4u2IY#xoLk&Ui6zMy@oAYw#U(F6iB(hn7He^7L26MEs9d<k5g(tKmst`Y
ze~URYucQc+d~XTm=BK3QK$AHnn-y_0FfbJHfCyd?!3QGvLH4nx<>V)p6bXU^gg}Hi
zNR+p<I5j@8q$D*D<c?e1u+Yj)EVu;^I)oRa_+Sc(Qw!oVQ*H^s`6Y=(B_I)miu~gE
z%#<QdkTq~+so-3nT2ut`A~+z5*gy)Hi}Lb{SV1h7f};GaTP%qsC3&}aKxBM!er{4`
zUaBVZEtcZcoU|eZkn2HVxRUV}XMB8ePGWI!eEdp=Utand`MIh3$tkG?`K3iA`c9=q
zsfnffDXFOi=|zc|Iq^Y$e)_?nFbzpePfsn<hX$QKT4AAAP+26+z`y_s#$pdpiNVOn
zEWikZd>|GR8zTtvFoGcq7bDAm4rVSU4n{skCPo$prhi;4Fct?Z2csC10FxMFl?1v?
zuvkSYpFm*)D(;;@#XTqnOBfa~g35v{riDzw44RC7n#{LYOY)17Gj1`Z78WTpFfeE`
z6={NkkFl~y2NXVFLKh^$QJh~KpPN{mZ3Qw8WHtj9hrkpigDeJl8H7a`7#KjQ5A2XK
zCI*HYrYwdS=316o)>^g_#w?~9mKugE<{E}9mJ-G+))K}nwgv198H#jL81tA?m{OQ~
zS!>x-SZdf)SW{SAnIL2hLl#FGvkgNDV-{x(OA1IQ7hGo=Tqnq6ZV`rB4v0M^j9ENz
zRXuQ3%(a{~%nNu^*cLJ_-~$zGDIk3U;tV1TDJ&^03z=&{W`l}0euPX7OAW~N2>0HB
zo5xzqT_U(Za3QFW63S*M`j*0;$CScR%Tod>Vwsy6T^J@X#%9&>)-Ws(so_}2Sj$(#
zut2ni56t4L<*#AL5}UwS)KVgzB~in$K(d5!fm97&4Oa@AB*OyU6qbdIX-p}c3z-%&
zrf{Zk^)l58lrS!kt`VqVTF6u@2vsczQ3En<flLkmLdIJD67dDH5LOMt0=XLg8m=1d
zG^T}2F-*1mF-)~wwcIrf3*<p6AfXH=xm%gim?ar%g=%;vFc#gaVT8!k@TM`rTvW@E
zCsD$fr2uxb5X4;&vX-NUVSyrcIfywaJRoxxGS&*Gu+^~FaPl+M@WNuaMhL`D;en_G
zxdNslg*652E*pj#;TldShFZ=N#sx|s)e9LXFcz~gGE879ESkVr$WjBU&v_FKm>3vB
ziV`#PGV{_EN-|OvN{SNmigUni3Wc<y{9FY{MUkQas|$)1O7az;r4tvI0u(qFr59Vl
zc_2ZklK7JR?9{x>s??Nta7k>X5DlimMY)aws9u4oM5=INp>8Qng;}MLmS3a*szu-q
z(L=E=C$YFB9#%1wB$lLNu?V>W(@{XG1!ENyf>R4iK~<eXewqTr7~Nu!F~tgrC1A(o
zmzETimVl%|!345S0bK3CM8I_hs+$V(^K(**;vqWG+=F7fPhxQi)HGzfVMc>1F9mQw
z!i0)p&eZ@H)0t_R$&fe$tI||ROe;w(Qb;Q;D#=JKQYa|OPfjf^hQx7Eerb9J*h^sd
z7b}z&r>1~|7v!<T(u&NS%*3Kfg@U5g5_rl;EY5}p_bry>(v-wo%*h3%E17Pwrj%3`
zq~2mnNiEAvPJPM9z`)>F1R_J+iX<5r7*;agVou4-efjVI|Nk%l|NsAAQ~4HaaYkZ6
zYLOnO7-g+WEy^!00@tdb3bjZMRLmNIq<Aum;}gpgGjkG?a#D*FK|(C~d8x&>SaS07
z(u<V90!jJ#Ikz}V^NI^gQ&X!_Z?P4n7No)|WN=-~R+N~RlAl`ys-17K7A5ATrxvMz
zECJPZMe-n17~O8M7w4yy<R(@Wse=TVlM_pBF(>8b7HNREEJdj~rA4YBO==)MTVhc<
zNUb4=!<?8`sVM@fCd6Th85|FYWPeK#lwwd4=q+B9^l^&|kuPrX!{RqSCp9m<B%??a
zWD#p+d`4o)Elx;tEhRp)2y8zOB+bNw)9WqvWC$Zl7}NwoYk1va%)BKCb0Vb8keQc$
ziw%-Qq3we5)XemZl42+?v7jI)Gc~0M>~AhubrzqAULCTh<rkGF7NyvNYQr#4Eyuyb
z%*V*WAn>1!NrH)oQQ$udBg=mlW(8&iCNYROvlz1g6Bi>F6AL2;BMYM(1Ji#tCLSgM
zMlMD!W(j5mCN@S1CJAOOMj=KOMh<2^MlPl*arCMlYHu>A*8#E?gh5dauVA%66)YpT
zf(2Kx;EE4axq>jLnnYAs$Q3MG3QG+eq=<!(u;Le7!7>x9VA&z|;H+RdYMAg=u$*B1
z3|Yt(EOQOWtso3(5Fl5uTqU4h46G_O0asi?HS7x+Yk6uI76{kyfLT1CDmF`G0%K8L
ziD;Hs4Z{NQ62=7*H9R$(C{-ytxGH5&;pk<m<tt%aAX&p#!?=*CmLIB`KaDAceIXOb
zv;|T%ybBp?c}qkWNJCgP3=3pxcxyOoxS&-jZwyl{XDt`hMp=+5NSMM&PDphM>T*IV
z?<~0*0Vt0hTm{1#J6zC;Q80zAhOLH!pP_~u7Be*hAbtuL#8hy_2vw27ngX`LhM`8V
zhQkTm{8^v?QoWF|_#L=fR4aT}!xqe-$(`84$iNU#l$w&6Tv7}!D4{t9QXzs$Nl1Qy
z6~>S%vPchKb%>Gu(dq+GMTn<<!f7eIR0h?#C<O|d1w**PDFW37MaH1g7c_KHWDc)m
zz%@*f1xUsUL|B7b?21J;AU3FWDFT&7*y@%dJCG7kl~QC6Vu410z!gf7BZ%t+B0#P4
zB5=PBRE-q5g2YhkjUqRYm^+97)fz>hK24D)hzqJUio8H9c%@O~4dQ}I>>?kKGljr4
zMv)&#)E`6yfQUd40V?f_f<UZb5CPT`0%C=N2vFf#6b@oVfQU#C5yim35XA%P;>E-2
zv1pJuxVr)-z<n0(0!Xq3b<P_=1vSV{26h|;H4~zs1{cZLifJYZCM`xjW-(|H&BMsS
zR3(jGL?hX%$pPwe=OyN*#>d~{ijU6)4|m1K-{OgnFDy;WfyjWniQw9;C<f$>SWxxH
z3+<C8r{?6u$0OB0S|BAL|AD*opq_d$s2|V4$fL)>%;5|c)f6mp0VxBeid(|qKAv7)
zo?cpM9&Dt5DKGC97lIFt+gltkm5`1#sE#QD_2j^PYH%03h#h1s$i65gq&~17WHb&+
p!3X8Q-FC2pklc@m9}XKxc-Vm&P{p8rG6xd}qY$G46AvQ~GXR!;l0^Uj

diff --git a/SuperTagger/Encoder/EncoderInput.py b/SuperTagger/Encoder/EncoderInput.py
deleted file mode 100644
index e19da7d..0000000
--- a/SuperTagger/Encoder/EncoderInput.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import torch
-
-
-class EncoderInput():
-
-    def __init__(self, tokenizer):
-        """@params tokenizer (PretrainedTokenizer): Tokenizer that tokenizes text """
-        self.tokenizer = tokenizer
-
-    def fit_transform(self, sents):
-        return self.tokenizer(sents, padding=True,)
-
-    def fit_transform_tensors(self, sents):
-        temp = self.tokenizer(sents, padding=True, return_tensors='pt', )
-        return temp['input_ids'], temp['attention_mask']
-
-    def convert_ids_to_tokens(self, inputs_ids, skip_special_tokens=False):
-        return self.tokenizer.batch_decode(inputs_ids, skip_special_tokens=skip_special_tokens)
diff --git a/SuperTagger/Encoder/EncoderLayer.py b/SuperTagger/Encoder/EncoderLayer.py
deleted file mode 100644
index c954584..0000000
--- a/SuperTagger/Encoder/EncoderLayer.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import sys
-
-import torch
-from torch import nn
-
-from Configuration import Configuration
-
-
-class EncoderLayer(nn.Module):
-    """Encoder class, imput of supertagger"""
-
-    def __init__(self, model):
-        super(EncoderLayer, self).__init__()
-        self.name = "Encoder"
-
-        self.bert = model
-
-        self.hidden_size = self.bert.config.hidden_size
-
-    def forward(self, batch):
-        r"""
-        Args :
-            batch: list[str,mask], list of sentences (NOTE: untokenized, continuous sentences)
-        Returns :
-                last_hidden_state: [batch_size, max_len_sentence, dim_encoder]  Sequence of hidden-states at the output of the last layer of the model.
-                pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task
-        """
-        b_input_ids = batch[0]
-        b_input_mask = batch[1]
-
-        outputs = self.bert(
-            input_ids=b_input_ids, attention_mask=b_input_mask)
-
-        return outputs[0], outputs[1]
-
-    @staticmethod
-    def load(model_path: str):
-        r""" Load the model from a file.
-        Args :
-            model_path (str): path to model
-        Returns :
-            model (nn.Module): model with saved parameters
-        """
-        params = torch.load(
-            model_path, map_location=lambda storage, loc: storage)
-        args = params['args']
-        model = EncoderLayer(**args)
-        model.load_state_dict(params['state_dict'])
-
-        return model
-
-    def save(self, path: str):
-        r""" Save the model to a file.
-        Args :
-            path (str): path to the model
-        """
-        print('save model parameters to [%s]' % path, file=sys.stderr)
-
-        params = {
-            'args': dict(bert_config=self.bert.config, dropout_rate=self.dropout_rate),
-            'state_dict': self.state_dict()
-        }
-
-        torch.save(params, path)
-
-    def to_dict(self):
-        return {}
diff --git a/SuperTagger/Encoder/__pycache__/EncoderInput.cpython-38.pyc b/SuperTagger/Encoder/__pycache__/EncoderInput.cpython-38.pyc
deleted file mode 100644
index 03717155496997dc3ef4713b269d7115fda65fd7..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 1182
zcmWIL<>g{vU|?9;=bKc)#K7<v#6iZ)3=9ko3=9m#QVa|XDGVu$ISf$@?hGkRDa<Vl
zDa_4GQH&{!!3>%#FG0Hel0hUghU?d3U|>jPh+<4(h+;}%Okrwah+<A*PGM<bh+;`$
zO<`+ch+<9QPT@%5Y+;OIOW_J;(B!$r<C>S8pORYSnO9I+lFWo;Fo+Gp&LFpeyimhf
z!w}B^b_-LLyF)=@QDSbfLP>shYF=hlYLS9QKv8N*QDSCZYDx%9P}52Q&Q>VNNGyS?
zFIFf?tte6OTgiBf6K;kk(=C?b)SR>;7Ep*W6tOWdF#K}Y&&bbB)lW`IEyyn|D$#c;
zElN!+)lW%HEl4j)%*=@o^7GRVE-gqc3Q0^)Pc71idI?5@y{1=Cd5a@HJ~J<~Bt9Nw
zY%$1h3`|v`P}LAsFclaf0df#HB={H@7-|@@7-|?nVcg3U%%I8a7oy2{i@hK*B_%U2
zy$BQ`noLEk3=9mnSc_BhN{Ye$fDr5;QQox7lK7IM#Ju9P{Gwb?*g{;xSS5wx8Wd$H
zP6F8hvH<KP1r#TxF!wUoGSx7put+l0Fr~0cGNgeV$>tYwi>aVQlc|Uk<OIH=)RNMo
zy!evTyyE<#;#-`Vpac+~nNoa<FR`Q~HLoNyKQBHvu{ayzA7-$Bia>6?#Zr=*TLAG6
zC?t!(eiB7-EX-V9P~riF9Rs5PW0ewa-@!DvC4;;MONJmeNH^G<0#I*eG1f5FFr_ep
z!=A}clkt{tadu`wd~rc)a%N&qJUBHKYchco(Jh{&#FFHU_>@#oGS`H7BZ>>`+hUM+
zi-Z^$APEiPBZy~`^YhA5i%LKW;!E<OMhb#F15P&RUcsu|Pm|G2lO2@m@)C1X<Ku5}
z#mDF7r<CS^*gWy^g{6r(5SbzlkQ=x`1jvpe5fBR;v<L#^t|E}-pztULY2aYuVBuiq
iU;&G2GTmY=$uCOI0P`STLDhVV!v<oz9V5&i9Lxat0Ui+m

diff --git a/SuperTagger/Encoder/__pycache__/EncoderLayer.cpython-38.pyc b/SuperTagger/Encoder/__pycache__/EncoderLayer.cpython-38.pyc
deleted file mode 100644
index f685148ba6d4ea28a6f741836205060a3d60123d..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 2675
zcmWIL<>g{vU|=xO$V~dm&cN^(#6iX^3=9ko3=9m#b_@&*DGVu$ISf${nlXwog&~D0
zhcTBaiir^<#+<{P%N)hb$l%V9!ji(;!jQt6%9O>@%pAp<!Whh;$@UUtvY#g7EvCFY
zO~zZi&iQ$1ndzlPi6xo&dC4G2WDK$z#4iqDU|>jP09hBsl)@Or+|H23n8K98+`<vX
z0<kEHHI*%uJ%w!!V+wl;M+-|cV-!a!dkW_q#uTm;?iQ9P&J>;$-WG-^u670%hA8e}
z22H+OJg#}k`6;PIK8cm7MODgBo<ed?VsWvKLS}A3X^BF9nnH1DL26M+VtRUNQ8E+A
zJx~l{GczzSfN-%E0|P?|LkVLILo?F?riBcQ3@MB?4Drl0Ad;nqA&Vu8HH)o=A)YOT
zA(%mv$?q0B)ILr2TdZI!ia`Fo#StH$nU`4-AAgG_FEKau7E4lUQOPZ~WN=8`;?BrS
zNlDF%FV3t=)nvNGQk<HTc8fJPKP5G1CF3p5_;|46<KtH{{BqaN$j?pHPfkfK$S*A_
z(RV5>N=+=)Pf1NJNH0pv%!v>3^V1Io`7H$OH+`skU^F;X^a?7A*cliYct9zD1LSc=
zK1MbsHl`|3sB(xhn4?hwh?RkX0purVkY+sw28J5O6h=vgTBaJt6edZATIL#tEQT7U
z8s-${UZz@>8Wxaj4GTyvjVYKxlO>UXiGhJj0SX+8(u)<W;9QVsQesJRhLu82W^qY$
zaY>O*ZenqEtPYqDj=9vllGME9)M5n<KmQO{D}~a$lKkw{yv(Z96di@+{JfIPywd#A
zVufOuLQS~wL8&FBMS18Jf-KEREG~(M#Y0JANvf4XG}s~FSkY0)O{|E|0mTv2P#uMo
z%-r}?NLq_kPzX*fECsm=<U)usy2T)4iWL$|6iPBu74l0<ph*wJ16ik#14^PG4ImM4
zbnC%g3UU(&7v$&Xq!z_PbfdWm#RWcz#U)S^k?n^WpO#rvT%rIDD40+&%*h(y^p%;G
znGDXd5LKEAiD@ONMG9%9MI{-jMG6H)`N^rp#hH2O3MCmu`K9R@U~hr_P^?f|oSLGL
zmR|%GODwI(%*jkFs#GW_N-Zf$%*+GpN-WNXd#{L*fq~&Ah|pxZ#hIB`P+Ah7nNoa<
zFR`Q~H4jv9#Dk(;ld*`Cfq_AjwFs0}Z?S@7B#JvJ9<C#b2gU)Zzr_v-sNy041_lNQ
zA;`eMaEm=Hzo<O1C`AO6YWYA(o`ZvtgNcQaiGk@q8*`NiO2UUK)?~cJ1&)^Zg2a-H
zWRPn?Izbqe13(y@3(Ua9L>6NWLkeRGQwsw)<uleWrm%qWMF~?HLk&|OV=Z$H^8)4?
z#uQdbhJ}ob3^mN4!n10vf=_;83OL-rb}6J4<>x9SDx_uRq#}|7mTUr64R)D=MsZ1z
zrj-K7l?o;K3Sb#{P-9Q?U^NOFd3k!i`6;D2shU<0Z3^X?B^e6EiDjuN3I&NpiMgpI
zsYS(b!%&I>P+18IlweSxf<mK)A(&w$qaQfjHJNU)7nkH0C8np|V$R7=1`9$6aEx%+
z<Rs=Mr6k(vGB7ZF28AiSu!fXAIr)hxdNw)v$%#3|c6u<?nvA!2auW;ULE@lfa*HLg
zD82X=7dW%Ur(`CVXtLa5Ey*uR&bY+_G6GTn^Mjc2$cl<2K<R=F?8o9Fc~ElU1rewT
zr$~u`fk6f2Kae;B2Pna?{byt1VB})tW2zEFNj4xmG#PKPfZ_oqpg@%gD15-F1ytph
zfGQG3F@_SxET$BuUM5C{5>UCxlEqrXki`Z{KQ#;s*cURSu*_kqWh&uVz*)mo!?+Ms
zA+v&0O^iZtVp%Fuswv4QCe46chMrpBg%3QEt7Jj(4^0o?00JihP|}K4Esh0w5tKSL
znWDHs^;kT#8skYR$}h+-Er~BmEJ@X5f`q980|SGfCf6<2f}+g4l3UEhmBqK%ic3;b
zi;8ZsfT9bW!HYo6l_GUe_=7YQDT3mL1>_1%W^mRoQU%F@A_9_xK(SE-N=HRHpmq)@
zf<Q(vunI8?FtYt)V?~P^kWnZZ7?f^77#t;_xXEM)W+;*b*#a_KlM!Mqhz+q3WZ5nD
zl6-I)a{^fiO8yLtXhjE9fhG^A&Ii@?@$t8~;^T9{nJhm37EgS9VQFFxRE9l1J|#an
zJ|0{h7m0&>%>zoLnaR1SB^miCx0s7dioj_Q?9U=lUb)2!Z4e}<=H$f3LtF=TC#YP5
z1PaJQpn|g)B*ekU!o<VDA;82az|Y4A&JUVwMVcV5g2K57<c}gy61~L+DGR_#kjw=e
a4)#BX4a6gMpr9)TH61vZI2d`Dc$fk7oUlOv

diff --git a/SuperTagger/Linker/AtomTokenizer.py b/SuperTagger/Linker/AtomTokenizer.py
index e400d4e..6df55ed 100644
--- a/SuperTagger/Linker/AtomTokenizer.py
+++ b/SuperTagger/Linker/AtomTokenizer.py
@@ -1,5 +1,7 @@
 import torch
 
+from SuperTagger.utils import pad_sequence
+
 
 class AtomTokenizer(object):
     def __init__(self, atom_map, max_atoms_in_sentence):
@@ -28,24 +30,3 @@ class AtomTokenizer(object):
 
     def convert_ids_to_atoms(self, ids):
         return [self.inverse_atom_map[int(i)] for i in ids]
-
-
-def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
-    max_size = sequences[0].size()
-    trailing_dims = max_size[1:]
-    if batch_first:
-        out_dims = (len(sequences), max_len) + trailing_dims
-    else:
-        out_dims = (max_len, len(sequences)) + trailing_dims
-
-    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
-    for i, tensor in enumerate(sequences):
-        length = tensor.size(0)
-        # use index notation to prevent duplicate references to the tensor
-        if batch_first:
-            out_tensor[i, :length, ...] = tensor
-        else:
-            out_tensor[:length, i, ...] = tensor
-
-    return out_tensor
-
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index 6b5c6f1..6a39ae7 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -2,6 +2,7 @@ from itertools import chain
 
 import torch
 from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
+from torch.nn import Module
 
 from Configuration import Configuration
 from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
@@ -10,11 +11,12 @@ from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
 from SuperTagger.Linker.AttentionLayer import FFN, AttentionLayer
+from SuperTagger.utils import pad_sequence
 
 
-
-class Linker:
+class Linker(Module):
     def __init__(self):
+        super(Linker, self).__init__()
 
         self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
@@ -71,20 +73,25 @@ class Linker:
         atoms_polarity = find_pos_neg_idexes(category_batch)
 
         link_weights = []
-        for sentence_idx in range(len(atoms_polarity)):
-            for atom_type in self.atom_map.keys():
-                pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
-                                         x and atoms_batch[sentence_idx][i] == atom_type]
-                neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
-                                         not x and atoms_batch[sentence_idx][i] == atom_type]
-
-                pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
-                neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
-
-                pos_encoding = self.pos_transformation(pos_encoding)
-                neg_encoding = self.neg_transformation(neg_encoding)
-
-                weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
-                link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
+        for atom_type in self.atom_map.keys():
+            pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
+                                      x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
+            neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
+                                      not x and atoms_batch[s_idx][i] == atom_type] for s_idx in
+                                     range(len(atoms_polarity))]
+
+            # TODO: perform this selection directly with a list of index lists
+            pos_encoding = pad_sequence(
+                [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+                 for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0)
+            neg_encoding = pad_sequence(
+                [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+                 for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0)
+
+            # pos_encoding = self.pos_transformation(pos_encoding)
+            # neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
 
         return link_weights
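
For reference, a minimal self-contained sketch of the matching step in the hunk above: padded positive and negative encodings of one atom type are compared with a batched matrix product and the result is normalised with a log-domain Sinkhorn loop. log_sinkhorn_normalize below is only a stand-in for sinkhorn_fn_no_exp (the patch shows just the (weights, iters=...) call shape), and every shape and value is illustrative.

import torch

def log_sinkhorn_normalize(log_alpha, iters=3):
    # Alternate row and column normalisation in log space.
    for _ in range(iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)
    return log_alpha

batch_size, max_atoms, dim = 2, 5, 16
pos_encoding = torch.randn(batch_size, max_atoms, dim)   # padded positive atoms of one type
neg_encoding = torch.randn(batch_size, max_atoms, dim)   # padded negative atoms of one type

weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))   # [2, 5, 5] affinities
link_weights = log_sinkhorn_normalize(weights, iters=3)           # doubly normalised link scores
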
diff --git a/SuperTagger/Linker/Sinkhorn.py b/SuperTagger/Linker/Sinkhorn.py
index 912abb4..9cf9b45 100644
--- a/SuperTagger/Linker/Sinkhorn.py
+++ b/SuperTagger/Linker/Sinkhorn.py
@@ -1,4 +1,3 @@
-
 from torch import logsumexp
 
 
diff --git a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
index afbe4b3ab22416e9fb3fa5b7e422587b47fe3c95..26d18c0959a2b1114a5564ff8b1ec4304f81acf3 100644
GIT binary patch
delta 56
zcmZ3_x}KFcl$V!_fq{X+bw*K=%0^y(Mn;y&vW$(49FylWE@tGJ?93F)&&a^QP|U`_
Mz`(=I!NS1;07MT7ng9R*

delta 56
zcmZ3_x}KFcl$V!_fq{Wx#@~V@)s4LTjEt<4Wf>b8IVaC&T+GNj*_kPppNWBip_q+<
Mfq{pagN1_y0Aj}qRR910

diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index ddb8cb5..f2e72e1 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -92,4 +92,3 @@ def find_pos_neg_idexes(batch_symbols):
             list_symbols.append(cut_category_in_symbols(category))
         list_batch.append(list_symbols)
     return list_batch
-
diff --git a/SuperTagger/Symbol/SymbolEmbedding.py b/SuperTagger/Symbol/SymbolEmbedding.py
deleted file mode 100644
index b982ef0..0000000
--- a/SuperTagger/Symbol/SymbolEmbedding.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import torch
-from torch.nn import Module, Embedding
-
-
-class SymbolEmbedding(Module):
-    def __init__(self, dim_decoder, atom_vocab_size, padding_idx):
-        super(SymbolEmbedding, self).__init__()
-        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_decoder, padding_idx=padding_idx,
-                             scale_grad_by_freq=True)
-
-    def forward(self, x):
-        return self.emb(x)
diff --git a/SuperTagger/Symbol/SymbolTokenizer.py b/SuperTagger/Symbol/SymbolTokenizer.py
deleted file mode 100644
index cded840..0000000
--- a/SuperTagger/Symbol/SymbolTokenizer.py
+++ /dev/null
@@ -1,53 +0,0 @@
-
-import torch
-
-
-class SymbolTokenizer(object):
-    def __init__(self, symbol_map, max_symbols_in_sentence, max_len_sentence):
-        self.symbol_map = symbol_map
-        self.max_symbols_in_sentence = max_symbols_in_sentence
-        self.max_len_sentence = max_len_sentence
-        self.inverse_symbol_map = {v: k for k, v in self.symbol_map.items()}
-        self.sep_token = '[SEP]'
-        self.pad_token = '[PAD]'
-        self.sos_token = '[SOS]'
-        self.sep_token_id = self.symbol_map[self.sep_token]
-        self.pad_token_id = self.symbol_map[self.pad_token]
-        self.sos_token_id = self.symbol_map[self.sos_token]
-
-    def __len__(self):
-        return len(self.symbol_map)
-
-    def convert_symbols_to_ids(self, symbol):
-        return self.symbol_map[str(symbol)]
-
-    def convert_sents_to_ids(self, sentences):
-        return torch.as_tensor([self.convert_symbols_to_ids(symbol) for symbol in sentences])
-
-    def convert_batchs_to_ids(self, batchs_sentences):
-        return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences],
-                                            max_len=self.max_symbols_in_sentence, padding_value=self.pad_token_id))
-
-    def convert_ids_to_symbols(self, ids):
-        return [self.inverse_symbol_map[int(i)] for i in ids]
-
-
-def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
-    max_size = sequences[0].size()
-    trailing_dims = max_size[1:]
-    if batch_first:
-        out_dims = (len(sequences), max_len) + trailing_dims
-    else:
-        out_dims = (max_len, len(sequences)) + trailing_dims
-
-    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
-    for i, tensor in enumerate(sequences):
-        length = tensor.size(0)
-        # use index notation to prevent duplicate references to the tensor
-        if batch_first:
-            out_tensor[i, :length, ...] = tensor
-        else:
-            out_tensor[:length, i, ...] = tensor
-
-    return out_tensor
-
diff --git a/SuperTagger/Symbol/__pycache__/SymbolEmbedding.cpython-38.pyc b/SuperTagger/Symbol/__pycache__/SymbolEmbedding.cpython-38.pyc
deleted file mode 100644
index 030ce696540244b363763b0b592d2a84a8c19fd9..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 858
zcmWIL<>g{vU|^UZ>zlNck%8ech=Yt-7#J8F7#J9ebr={JQW#Pga~PsPG*b>^E>jd!
zE^`z!BZE6b3Udle3quM^DpNCa6iW(YFoP!ROOQE!noPIYeDhOEb5d_{y5=UOrle%%
zr6+@=kTElqQ>?<kz>vxi#hAhn#njG_#+bsG!qmbM#SF0^ilv=_g&~R+Wc)4u;L6;j
z{2YX#Ah$xzVF0n&7#J8p27?VPVJKm&VQ6Mrz_gHok)edShN*_Jh8bjhFG~$WJWC2g
zFoPzuUx+5lExx?c-1t<OTZ?b;!rAdDnYp*P3lhPeh|f%^xFuAaoS2gupI(%h5}#BV
zpH`GwsL6VZwYan(wWtW>h+7=-@tJv<CGqh^Ah+CNPR&iyWVyvsoSKt%i#sJVH$Ejb
zIX@+}D2hL^BtJL4EI&ChDZV(fDz%86fq`Kq<1Nnk_~e|#;^O%Dl?=Z;^)vEwQ}vTm
zQVa4+i%RsJN{dnxOZ8JyQw!3I5;Jq+gZ%vTgF&{3B&MgQ7U@F*2};4kO0S@@2o&+4
zAS&hn1sWqCBL^eX|0)TH282qO^OKoC%Agp;26@LBlr%sF)i5kzs9{_Pig-ppO~xW_
z1_lOArXn5&28LUV6-5FJ3=9xL5G2B$mS0q!Sd;=%RSYtVfw4*)-D0R3O*T*v<|XE)
z#>d~{ijU6)C#LxLTRidcg{6r(5E-y<iiAOi^FqTbIW;FIJ|1E#*nU2cBS7|pBZd{^
zbWjX&F!Hd1ML-c(l3$dZaf<^?>E-1WfgA^pIIs#N!@wqj47<f)196faBf>-hW&k<w
B(5(Of

diff --git a/SuperTagger/Symbol/__pycache__/SymbolTokenizer.cpython-38.pyc b/SuperTagger/Symbol/__pycache__/SymbolTokenizer.cpython-38.pyc
deleted file mode 100644
index 7c631aa22d6bb2cee2970ecccb0a089a347c9f95..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 2862
zcmWIL<>g{vU|`r0>zj0voq^#oh=Yt-7#J8F7#J9e)fgBUQW#Pga~Pr++!<1sQkYv9
zQkYX2o0+4SQaDoBQrKG<qnJ||gBdhAUxG~VO9qk1m<eW<4Fdy1Dnk@w3PTiA3S$aW
z3qurh3Udle3qur33Tp~m3qurZ3VRAi3qurJ3TFye3qurp3U>-m3quq~3U4rjCf_ao
z;L6;j{G5>d?9{x>s??%nkV9dXf!NFp3=Ga7hi5P_Fw`*CFvK&|Fx4=`GuANIFvLTx
ztzpPwSiroH!G)n2q=F>{B*&V<Qo|6>mcm-Y5YG-$QNxhMD#=j85YGYSv4MG<P#!y&
z$Cbhm%%I8Um&^!r1p@;E$PpmN1cMwR#=yXk$xy=(%NN5`%UH`)!&JjKnW>N^m|-PD
z5lE#b^DQPlgIkQ*w;0R70$}2or+!9$ZmNEAN@_uVX;F#3Q)y9ZVyS*gYHC4xQDSCJ
ze2|}?esF0)YEejHdU|S+J|t+M6g-6V3My}L*`#D9mn7%s7TAHJy_k)Gfq{*Ije&`w
zN(Z6~p*%i5GcU6wK3>lzCqFqcr`S#pp}C4RI@mQJwu&`6z|jRn1^WlbYI5J=Dh9_&
zd~RaFE%DsMig*aWI6gBkzBo0nBsDKN^_BogA}19|gf+7yHMjVdP-b3PYEf}2R2SG3
z&f?U9_!3Yo-{LGtOo1??IE(X(A<SDma0T(1DYtmwYT`3fqIlrS;xkh;S#Gfur{<&;
zaWXJ46oI^7#0_GBLLMB@MeGa=47WI7{sx6ju?Q$K8QH*43{0{y@-gx;R!QIrE|mBO
zr3Fx`2PZ_3l_d-{3|WlLjKK^m8T~YwZZYSi=7AlwlCek#i)+~9<3S-650WbenZ&?Y
zC5~w|SWPlm9^_yK1_lrt<TJ3%e4s4Aki}5KSi{)ND9I4aPz3UrCKK4|Tg=5JMVd@t
z`){#9Vhj=>AOj$wB9@#FiklK>)PMp%J~O3Q6qK$(RxvPEsi8XyyA}*rvO@Dou>=DH
zLkYtI#&(7@MsTUXQNswzm<vH=3NvcP0@(@DC=7A{C}X9AGgd8Q2g3q}8pef;6Tuk^
z><vxEA}Iz222G|SP*yCG1O+ol3s@53JuaJ^%wkBs07(~v6oPYw1x`<;=9R!g0V#*T
zjnQPf#afbIl$>#kGqE_nBsH%%9~Q<@oUp7=3~>V3G2p}}0uM7J<3Zs9i+dGf+=CIs
zAhSVn2`<DzL0rOw9>^UGS<D>_Su82cz0BYMXQ3#-<q!dWixuSE;v%pwL9PG?d66^&
z1A`1G_TXV{iOX+Ei6zMy1Ol5KTBzLOEl5mB$;?ZSFH6iRP1R&A0$EZd3(DHy!l5{|
zuoP4bK}rBYh&w^qD@p)r56HP76&SH23X2`6!jgPY9sqfym;;pU7{wS_7+C&SsS@Q8
zj93DN4G0&5T!S7z!3>)CvH&Q7gR_7TC<`ERI&wB(D&hr&0&`|w2{bR<V$8%AEea^n
zVui~$nJM5nffN|U*fIe)dBaK;=FAj~5QP^jsAkE7901DV_=*)&Eq<DuprSi3F*h|n
z{uWn!d~SY9X%2|Z6CYn#nwSHTDFPK3MdF|$02E|JpmbWK4C1PT2vG2Y9StEs`KAcu
z4p323%)`LIz`@AD!@&%OEMS?CmlGHn7*JXtpjrh~d4P-FdElZqg;A1W0mDLuT9z7?
z6s8o0RwhY?TGkrIBE1sE8c^GVIgP1?Rh*%gEl;3^86pM~sbK)=&SJ`9u3_tASin-l
zw2-luy@VB{&xIk@CWfh&qn5LVa{)sN<3h$-t{TRowi?bF_8N{9h8nI`rWEEhW*dgW
z7?2o5q=o}7#vaU|$>LW8s;6Fp2t7@<TP($(1{J7izr~W0Sdw^)IWM*R7He8&PEPzS
z&eXip+|;7Pl2lFZTb!`WRD6pYoG0SbGK-2!iWEU%2}(t`I6&DFWZo^_5>Q(Olrd8>
zbBk|r<d>GjgIQc47NmeIQUxVYHVEeyTTW_TdPxQ(D!?%Vjt?aU1_n@!f$IwfMjl2U
zMixdPCIKcfMgc}HCKg5!CK0A0kSa~)5U|gRv_ZL?1zZx`V#`m;N=+^SM;=%elKF5;
UZgJQ^(vKY|_QBNw2eW`200G;je*gdg

diff --git a/SuperTagger/Symbol/__pycache__/symbol_map.cpython-38.pyc b/SuperTagger/Symbol/__pycache__/symbol_map.cpython-38.pyc
deleted file mode 100644
index 1e1195f2ba885473f1d89b05684939ab6a2385b1..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 563
zcmWIL<>g{vU|`rJ8j^I7k%8ech=Yu!85kHG7#J9e?HCvsQW#PgQ<zeiQ&>`1Q`l12
zQ#evMQ@B#NQ+QH%Q}|N&Qv^~3Q-o54Q$$ikQ^ZolQzTL(Q>5ldMKPpE2Qz5Oyabv2
zl97Rd;UyD@U<MH^Ac7S{uz?765WxW=I6(x+buYO=EFKWS3nKVH1V4xn01<*9LI^|%
zg9s52AqpbIK!iAmkN^>q3=9mKQc*0)Iq^lm7-McR=G|f}h+;2}FDNKVExyH^l9?FA
zS{z@VQ5eOXlUnkNQ6I$3%u9=6D~>NnEG~{>DJY0fjAAP&h%ZPiiefD&h)+ocn^$>@
zIVr#57IR5O$t}iWh?$8+B~i@9@rAdT@(OM-6%<slMhCkF#DXaQ;Mgkm=-?2?pb!v0
zz|kevPm}Q$S8-)-QhrW+Zeqboh9Wfv1_<%XML#1yH&s75CAA>Gw5UYiskA6Hu~a`L
zHMJnUC^0i9KFH5cKe)6YwJ0PpJw3HZKNxI|KEgP?g34PQHo5sJr8%i~pr9`{V_;xl
LVk8+pXZa5R2VIM`

diff --git a/SuperTagger/Symbol/symbol_map.py b/SuperTagger/Symbol/symbol_map.py
deleted file mode 100644
index c16b8fd..0000000
--- a/SuperTagger/Symbol/symbol_map.py
+++ /dev/null
@@ -1,28 +0,0 @@
-symbol_map = \
-    {'cl_r': 0,
-     '\\': 1,
-     'n': 2,
-     'p': 3,
-     's_ppres': 4,
-     'dia': 5,
-     's_whq': 6,
-     'let': 7,
-     '/': 8,
-     's_inf': 9,
-     's_pass': 10,
-     'pp_a': 11,
-     'pp_par': 12,
-     'pp_de': 13,
-     'cl_y': 14,
-     'box': 15,
-     'txt': 16,
-     's': 17,
-     's_ppart': 18,
-     's_q': 19,
-     'np': 20,
-     'pp': 21,
-     '[SEP]': 22,
-     '[SOS]': 23,
-     '[START]': 24,
-     '[PAD]': 25
-     }
diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py
index 3017c7f..372e68c 100644
--- a/SuperTagger/eval.py
+++ b/SuperTagger/eval.py
@@ -1,8 +1,7 @@
 import torch
 from torch import Tensor
 from torch.nn import Module
-from torch.nn.functional import cross_entropy
-
+from torch.nn.functional import nll_loss, cross_entropy
 
 # Another function from Kokos to calculate the accuracy of our predictions vs labels
 def measure_supertagging_accuracy(pred, truth, ignore_idx=0):
@@ -42,3 +41,12 @@ class NormCrossEntropy(Module):
     def forward(self, predictions, truths):
         return cross_entropy(predictions.flatten(0, -2), truths.flatten(), weight=self.weights,
                              reduction='sum', ignore_index=self.ignore_index) / count_sep(truths.flatten(), self.sep_id)
+
+
+class SinkhornLoss(Module):
+    def __init__(self):
+        super(SinkhornLoss, self).__init__()
+
+    def forward(self, predictions, truths):
+        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
+                   for link, perm in zip(predictions, truths))
\ No newline at end of file
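
A usage sketch for the new SinkhornLoss, under the shape assumptions implied by the nll_loss call above: predictions holds one [batch, n, n] tensor of log-normalised link scores per atom type, truths holds one matching [batch, n] tensor of gold indices (for each positive atom, the position of its linked negative atom). Values are made up; the import assumes the script runs from the repository root.

import torch
from SuperTagger.eval import SinkhornLoss

batch_size, n = 2, 4
# Log-normalised link scores for two atom types (e.g. the sinkhorn output).
predictions = [torch.log_softmax(torch.randn(batch_size, n, n, requires_grad=True), dim=2)
               for _ in range(2)]
# Gold axiom links: one permutation of the n negative atoms per sentence.
truths = [torch.stack([torch.randperm(n) for _ in range(batch_size)])
          for _ in range(2)]

loss = SinkhornLoss()(predictions, truths)   # scalar, summed over atom types
loss.backward()                              # differentiable w.r.t. the link scores
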
diff --git a/SuperTagger/utils.py b/SuperTagger/utils.py
index 8712cca..cfacf25 100644
--- a/SuperTagger/utils.py
+++ b/SuperTagger/utils.py
@@ -5,6 +5,26 @@ import torch
 from tqdm import tqdm
 
 
+def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    max_size = sequences[0].size()
+    trailing_dims = max_size[1:]
+    if batch_first:
+        out_dims = (len(sequences), max_len) + trailing_dims
+    else:
+        out_dims = (max_len, len(sequences)) + trailing_dims
+
+    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    for i, tensor in enumerate(sequences):
+        length = tensor.size(0)
+        # use index notation to prevent duplicate references to the tensor
+        if batch_first:
+            out_tensor[i, :length, ...] = tensor
+        else:
+            out_tensor[:length, i, ...] = tensor
+
+    return out_tensor
+
+
 def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
     print("\n" + "#" * 20)
     print("Loading csv...")
diff --git a/test.py b/test.py
index f208027..9e14d08 100644
--- a/test.py
+++ b/test.py
@@ -1,27 +1,54 @@
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
 import torch
 
-atoms_batch = [["np", "v", "np", "v","np", "v", "np", "v"],
-               ["np", "np", "v", "v","np", "np", "v", "v"]]
 
-atoms_polarity = [[False, True, True, False,False, True, True, False],
-                  [True, False, True, False,True, False, True, False]]
+def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    max_size = sequences[0].size()
+    trailing_dims = max_size[1:]
+    if batch_first:
+        out_dims = (len(sequences), max_len) + trailing_dims
+    else:
+        out_dims = (max_len, len(sequences)) + trailing_dims
 
-atoms_encoding = torch.randn((2, 8, 24))
+    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    for i, tensor in enumerate(sequences):
+        length = tensor.size(0)
+        # use index notation to prevent duplicate references to the tensor
+        if batch_first:
+            out_tensor[i, :length, ...] = tensor
+        else:
+            out_tensor[:length, i, ...] = tensor
+
+    return out_tensor
 
-matches = []
-for sentence_idx in range(len(atoms_polarity)):
 
-    for atom_type in ["np", "v"]:
-        pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
-                                 x and atoms_batch[sentence_idx][i] == atom_type]
-        neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
-                                 not x and atoms_batch[sentence_idx][i] == atom_type]
+atoms_batch = [["np", "v", "np", "v", "np", "v", "np", "v"],
+               ["np", "np", "v", "v"]]
 
-        pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
-        neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
+atoms_polarity = [[False, True, True, False, False, True, True, False],
+                  [True, False, True, False]]
 
-        weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
-        matches.append(sinkhorn(weights, iters=3))
+atoms_encoding = torch.randn((2, 8, 24))
+
+matches = []
+for atom_type in ["np", "v"]:
+    pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
+                              x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
+    neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if
+                              not x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))]
+
+    # TODO: perform this selection directly with a list of index lists
+    pos_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+            for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3, padding_value=0)
+    neg_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence))
+            for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3, padding_value=0)
+
+    print(neg_encoding.shape)
+
+    weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+    print(weights.shape)
+    print("sinkhorn")
+    print(sinkhorn(weights, iters=3).shape)
+    matches.append(sinkhorn(weights, iters=3))
 
 print(matches)
diff --git a/train.py b/train.py
index 58ebe45..25154db 100644
--- a/train.py
+++ b/train.py
@@ -1,134 +1,51 @@
 import os
+import pickle
 import time
 from datetime import datetime
 
 import numpy as np
 import torch
 import torch.nn.functional as F
-import transformers
 from torch.optim import SGD, Adam, AdamW
 from torch.utils.data import Dataset, TensorDataset, random_split
-from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup)
-from transformers import (CamembertModel)
+from transformers import get_cosine_schedule_with_warmup
 
 from Configuration import Configuration
-from SuperTagger.Encoder.EncoderInput import EncoderInput
-from SuperTagger.EncoderDecoder import EncoderDecoder
-from SuperTagger.Symbol.SymbolTokenizer import SymbolTokenizer
-from SuperTagger.Symbol.symbol_map import symbol_map
-from SuperTagger.eval import NormCrossEntropy
+from SuperTagger.Linker.Linker import Linker
+from SuperTagger.Linker.atom_map import atom_map
+from SuperTagger.eval import NormCrossEntropy, SinkhornLoss
 from SuperTagger.utils import format_time, read_csv_pgbar, checkpoint_save, checkpoint_load
 
 from torch.utils.tensorboard import SummaryWriter
 
-transformers.TOKENIZERS_PARALLELISM = True
 torch.cuda.empty_cache()
 
 # region ParamsModel
 
-max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence'])
-symbol_vocab_size = int(Configuration.modelDecoderConfig['symbols_vocab_size'])
-num_gru_layers = int(Configuration.modelDecoderConfig['num_rnn_layers'])
+max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
+atom_vocab_size = int(Configuration.datasetConfig['atoms_vocab_size'])
 
 # endregion ParamsModel
 
 # region ParamsTraining
 
-file_path = 'Datasets/m2_dataset.csv'
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
 nb_sentences = batch_size * 40
 epochs = int(Configuration.modelTrainingConfig['epoch'])
 seed_val = int(Configuration.modelTrainingConfig['seed_val'])
 learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
-loss_scaled_by_freq = True
 
 # endregion ParamsTraining
 
-# region OutputTraining
-
-outpout_path = str(Configuration.modelTrainingConfig['output_path'])
-
-training_dir = os.path.join(outpout_path, 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
-logs_dir = os.path.join(training_dir, 'logs')
-
-checkpoint_dir = training_dir
-writer = SummaryWriter(log_dir=logs_dir)
-
-use_checkpoint_SAVE = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_SAVE'))
-
-# endregion OutputTraining
-
-# region InputTraining
-
-input_path = str(Configuration.modelTrainingConfig['input_path'])
-model_to_load = str(Configuration.modelTrainingConfig['model_to_load'])
-model_to_load_path = os.path.join(input_path, model_to_load)
-use_checkpoint_LOAD = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_LOAD'))
-
-# endregion InputTraining
-
-# region Print config
-
-print("##" * 15 + "\nConfiguration : \n")
-
-print("ParamsModel\n")
-
-print("\tsymbol_vocab_size :", symbol_vocab_size)
-print("\tbidirectional : ", False)
-print("\tnum_gru_layers : ", num_gru_layers)
-
-print("\n ParamsTraining\n")
-
-print("\tDataset :", file_path)
-print("\tb_sentences :", nb_sentences)
-print("\tbatch_size :", batch_size)
-print("\tepochs :", epochs)
-print("\tseed_val :", seed_val)
-
-print("\n Output\n")
-print("\tuse checkpoint save :", use_checkpoint_SAVE)
-print("\tcheckpoint_dir :", checkpoint_dir)
-print("\tlogs_dir :", logs_dir)
-
-print("\n Input\n")
-print("\tModel to load :", model_to_load_path)
-print("\tLoad checkpoint :", use_checkpoint_LOAD)
-
-print("\nLoss and optimizer : ")
-
-print("\tlearning_rate :", learning_rate)
-print("\twith loss scaled by freq :", loss_scaled_by_freq)
-
-print("\n Device\n")
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print("\t", device)
-
-print()
-print("##" * 15)
-
-# endregion Print config
-
-# region Model
+# region Data loader
 
 file_path = 'Datasets/m2_dataset.csv'
-BASE_TOKENIZER = AutoTokenizer.from_pretrained(
-    'camembert-base',
-    do_lower_case=True)
-BASE_MODEL = CamembertModel.from_pretrained("camembert-base")
-symbols_tokenizer = SymbolTokenizer(symbol_map, max_len_sentence, max_len_sentence)
-sents_tokenizer = EncoderInput(BASE_TOKENIZER)
-model = EncoderDecoder(BASE_TOKENIZER, BASE_MODEL, symbol_map)
-model = model.to("cuda" if torch.cuda.is_available() else "cpu")
+file_path_axiom_links = 'Datasets/axiom_links.csv'
 
-# endregion Model
-
-# region Data loader
 df = read_csv_pgbar(file_path, nb_sentences)
+df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 
-symbols_tokenized = symbols_tokenizer.convert_batchs_to_ids(df['sub_tree'])
-sents_tokenized, sents_mask = sents_tokenizer.fit_transform_tensors(df['Sentences'].tolist())
-
-dataset = TensorDataset(sents_tokenized, sents_mask, symbols_tokenized)
+dataset = TensorDataset(df, df, df_axiom_links)
 
 # Calculate the number of samples to include in each set.
 train_size = int(0.9 * len(dataset))
@@ -137,46 +54,34 @@ val_size = len(dataset) - train_size
 # Divide the dataset by randomly selecting samples.
 train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
 
-print('{:>5,} training samples'.format(train_size))
-print('{:>5,} validation samples'.format(val_size))
-
 training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
 
 # endregion Data loader
 
+
+# region Models
+
+supertagger_path = ""
+with open(supertagger_path, 'rb') as supertagger_file:
+    supertagger = pickle.load(supertagger_file)
+linker = Linker()
+
+# endregion Models
+
 # region Fit tunning
 
 # Optimizer
-optimizer_encoder = AdamW(model.encoder.parameters(),
-                          weight_decay=1e-5,
-                          lr=5e-5)
-optimizer_decoder = AdamW(model.decoder.parameters(),
-                          weight_decay=1e-5,
-                          lr=learning_rate)
-
-# Total number of training steps is [number of batches] x [number of epochs].
-# (Note that this is not the same as the number of training samples).
-total_steps = len(training_dataloader) * epochs
+optimizer_linker = AdamW(linker.parameters(),
+                         weight_decay=1e-5,
+                         lr=learning_rate)
 
 # Create the learning rate scheduler.
-scheduler_encoder = get_cosine_schedule_with_warmup(optimizer_encoder,
-                                                    num_warmup_steps=0,
-                                                    num_training_steps=5)
-scheduler_decoder = get_cosine_schedule_with_warmup(optimizer_decoder,
-                                                    num_warmup_steps=0,
-                                                    num_training_steps=total_steps)
+scheduler_linker = get_cosine_schedule_with_warmup(optimizer_linker,
+                                                   num_warmup_steps=0,
+                                                   num_training_steps=100)
 
 # Loss
-if loss_scaled_by_freq:
-    weights = torch.as_tensor(
-        [6.9952, 1.0763, 1.0317, 43.274, 16.5276, 11.8821, 28.2416, 2.7548, 1.0728, 3.1847, 8.4521, 6.77, 11.1887,
-         6.6692, 23.1277, 11.8821, 4.4338, 1.2303, 5.0238, 8.4376, 1.0656, 4.6886, 1.028, 4.273, 4.273, 0],
-        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-    cross_entropy_loss = NormCrossEntropy(symbols_tokenizer.pad_token_id, symbols_tokenizer.sep_token_id,
-                                          weights=weights)
-else:
-    cross_entropy_loss = NormCrossEntropy(symbols_tokenizer.pad_token_id, symbols_tokenizer.sep_token_id)
+cross_entropy_loss = SinkhornLoss()
 
 np.random.seed(seed_val)
 torch.manual_seed(seed_val)
@@ -192,10 +97,6 @@ total_t0 = time.time()
 
 validate = True
 
-if use_checkpoint_LOAD:
-    model, optimizer_decoder, last_epoch, loss = checkpoint_load(model, optimizer_decoder, model_to_load_path)
-    epochs = epochs - last_epoch
-
 
 def run_epochs(epochs):
     # For each epoch...
@@ -216,60 +117,38 @@ def run_epochs(epochs):
         # Reset the total loss for this epoch.
         total_train_loss = 0
 
-        model.train()
+        linker.train()
 
         # For each batch of training data...
         for step, batch in enumerate(training_dataloader):
+            # Unpack this training batch from our dataloader.
+            batch_categories = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_sentences = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+            batch_axiom_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+
+            optimizer_linker.zero_grad()
+
+            # Get the category predictions to feed the linker, along with the sentence embeddings and mask
+            category_logits_pred, sents_embedding, sents_mask = supertagger(batch_categories, batch_sentences)
+
+            # Convert the category logits into predicted category ids with softmax and argmax
+            category_batch = torch.argmax(torch.nn.functional.softmax(category_logits_pred, dim=2), dim=2)
 
-            # if epoch_i == 0 and step == 0:
-            #     writer.add_graph(model, input_to_model=batch[0], verbose=False)
-
-            # Progress update every 40 batches.
-            if step % 40 == 0 and not step == 0:
-                # Calculate elapsed time in minutes.
-                elapsed = format_time(time.time() - t0)
-                # Report progress.
-                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(training_dataloader), elapsed))
-
-                # Unpack this training batch from our dataloader.
-            b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-            b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-            b_symbols_tokenized = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-
-            optimizer_encoder.zero_grad()
-            optimizer_decoder.zero_grad()
-
-            logits_predictions = model(b_sents_tokenized, b_sents_mask, b_symbols_tokenized)
-
-            predict_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in
-                            torch.argmax(F.softmax(logits_predictions, dim=2), dim=2)[0]]
-            true_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in b_symbols_tokenized[0]]
-            l = len([i for i in true_trad if i != '[PAD]'])
-            if step % 40 == 0 and not step == 0:
-                writer.add_text("Sample", "\ntrain true (" + str(l) + ") : " + str(
-                    [token for token in true_trad if token != '[PAD]']) + "\ntrain predict (" + str(
-                    len([i for i in predict_trad if i != '[PAD]'])) + ") : " + str(
-                    [token for token in predict_trad[:l] if token != '[PAD]']))
-
-            loss = cross_entropy_loss(logits_predictions, b_symbols_tokenized)
+            # Run the linker on the category predictions
+            logits_predictions = linker(category_batch, sents_embedding, sents_mask)
+
+            linker_loss = cross_entropy_loss(logits_predictions, batch_axiom_links)
             # Perform a backward pass to calculate the gradients.
-            total_train_loss += float(loss)
-            loss.backward()
+            total_train_loss += float(linker_loss)
+            linker_loss.backward()
 
             # This is to help prevent the "exploding gradients" problem.
             # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
 
             # Update parameters and take a step using the computed gradient.
-            optimizer_encoder.step()
-            optimizer_decoder.step()
-
-            scheduler_encoder.step()
-            scheduler_decoder.step()
-
-        # checkpoint
+            optimizer_linker.step()
 
-        if use_checkpoint_SAVE:
-            checkpoint_save(model, optimizer_decoder, epoch_i, checkpoint_dir, loss)
+            scheduler_linker.step()
 
         avg_train_loss = total_train_loss / len(training_dataloader)
 
@@ -277,27 +156,18 @@ def run_epochs(epochs):
         training_time = format_time(time.time() - t0)
 
         if validate:
-            model.eval()
+            linker.eval()
             with torch.no_grad():
                 print("Start eval")
-                accuracy_sents, accuracy_symbol, v_loss = model.eval_epoch(validation_dataloader, cross_entropy_loss)
+                accuracy_sents, accuracy_atom, v_loss = linker.eval_epoch(validation_dataloader, cross_entropy_loss)
                 print("")
                 print("  Average accuracy sents on epoch: {0:.2f}".format(accuracy_sents))
-                print("  Average accuracy symbol on epoch: {0:.2f}".format(accuracy_symbol))
-                writer.add_scalar('Accuracy/sents', accuracy_sents, epoch_i + 1)
-                writer.add_scalar('Accuracy/symbol', accuracy_symbol, epoch_i + 1)
+                print("  Average accuracy atom on epoch: {0:.2f}".format(accuracy_atom))
 
         print("")
         print("  Average training loss: {0:.2f}".format(avg_train_loss))
         print("  Training epcoh took: {:}".format(training_time))
 
-        # writer.add_scalar('Loss/train', total_train_loss, epoch_i+1)
-
-        writer.add_scalars('Training vs. Validation Loss',
-                           {'Training': avg_train_loss, 'Validation': v_loss},
-                           epoch_i + 1)
-        writer.flush()
-
 
 run_epochs(epochs)
 # endregion Train
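
Note that linker.eval_epoch is called in the validation branch above but is not defined anywhere in this patch. As a placeholder illustration only, a hypothetical accuracy helper over the same list-of-[batch, n, n] predictions that SinkhornLoss consumes; the real eval_epoch would also have to run the supertagger and report the validation loss.

import torch

def measure_linking_accuracy(predictions, truths):
    # predictions: list of [batch, n, n] link scores; truths: list of [batch, n] gold indices.
    correct, total = 0, 0
    for link, perm in zip(predictions, truths):
        pred_perm = torch.argmax(link, dim=2)   # best negative atom for each positive atom
        correct += int((pred_perm == perm).sum())
        total += perm.numel()
    return correct / max(total, 1)
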
-- 
GitLab