From c7baf521e43501d7b042717461444f6c426c2a2e Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Tue, 3 May 2022 17:34:31 +0200
Subject: [PATCH] Linker: embed atoms and match polarities with Sinkhorn

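Add an AtomTokenizer, an AtomEmbedding and the atom_map vocabulary, plus
category-parsing helpers (get_atoms_from_category, get_atoms_batch,
find_pos_neg_idexes) in Linker/utils.py. Linker.forward now embeds the
atoms of each sentence and, for every atom type, matches positive
against negative occurrences with Sinkhorn; the encoder on top of the
atom embeddings is still a TODO.

A rough sketch of how the new helpers fit together, on a hypothetical
input (the values in the comments are what the code in this patch
produces for it):

    from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes

    categories_batch = [["np", "dr(0,s,np)"]]           # one list of categories per sentence
    atoms = get_atoms_batch(categories_batch)            # [['np', 's', 'np']]
    polarities = find_pos_neg_idexes(categories_batch)   # [[[True], [False, True]]]
    # Linker.forward(categories_batch) then embeds these atoms and returns one
    # Sinkhorn-normalized link-weight matrix per sentence and per atom type.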
---
 SuperTagger/Linker/AtomEmbedding.py           |  12 +++
 SuperTagger/Linker/AtomTokenizer.py           |  51 +++++++++++
 SuperTagger/Linker/Linker.py                  |  83 ++++++++++--------
 SuperTagger/Linker/atom_map.py                |  28 ++++++
 SuperTagger/Linker/utils.py                   |  81 ++++++++++-------
 test.py                                       |  30 +++++--
 6 files changed, 211 insertions(+), 74 deletions(-)
 create mode 100644 SuperTagger/Linker/AtomEmbedding.py
 create mode 100644 SuperTagger/Linker/AtomTokenizer.py
 create mode 100644 SuperTagger/Linker/atom_map.py

diff --git a/SuperTagger/Linker/AtomEmbedding.py b/SuperTagger/Linker/AtomEmbedding.py
new file mode 100644
index 0000000..e7be599
--- /dev/null
+++ b/SuperTagger/Linker/AtomEmbedding.py
@@ -0,0 +1,12 @@
+import torch
+from torch.nn import Module, Embedding
+
+
+class AtomEmbedding(Module):
+    def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
+        super(AtomEmbedding, self).__init__()
+        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
+                             scale_grad_by_freq=True)
+
+    def forward(self, x):
+        return self.emb(x)
diff --git a/SuperTagger/Linker/AtomTokenizer.py b/SuperTagger/Linker/AtomTokenizer.py
new file mode 100644
index 0000000..e400d4e
--- /dev/null
+++ b/SuperTagger/Linker/AtomTokenizer.py
@@ -0,0 +1,51 @@
+import torch
+
+
+class AtomTokenizer(object):
+    def __init__(self, atom_map, max_atoms_in_sentence):
+        self.atom_map = atom_map
+        self.max_atoms_in_sentence = max_atoms_in_sentence
+        self.inverse_atom_map = {v: k for k, v in self.atom_map.items()}
+        self.sep_token = '[SEP]'
+        self.pad_token = '[PAD]'
+        self.sos_token = '[SOS]'
+        self.sep_token_id = self.atom_map[self.sep_token]
+        self.pad_token_id = self.atom_map[self.pad_token]
+        self.sos_token_id = self.atom_map[self.sos_token]
+
+    def __len__(self):
+        return len(self.atom_map)
+
+    def convert_atoms_to_ids(self, atom):
+        return self.atom_map[str(atom)]
+
+    def convert_sents_to_ids(self, sentences):
+        return torch.as_tensor([self.convert_atoms_to_ids(atom) for atom in sentences])
+
+    def convert_batchs_to_ids(self, batchs_sentences):
+        return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences],
+                                            max_len=self.max_atoms_in_sentence, padding_value=self.pad_token_id))
+
+    def convert_ids_to_atoms(self, ids):
+        return [self.inverse_atom_map[int(i)] for i in ids]
+
+
+def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    max_size = sequences[0].size()
+    trailing_dims = max_size[1:]
+    if batch_first:
+        out_dims = (len(sequences), max_len) + trailing_dims
+    else:
+        out_dims = (max_len, len(sequences)) + trailing_dims
+
+    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    for i, tensor in enumerate(sequences):
+        length = tensor.size(0)
+        # use index notation to prevent duplicate references to the tensor
+        if batch_first:
+            out_tensor[i, :length, ...] = tensor
+        else:
+            out_tensor[:length, i, ...] = tensor
+
+    return out_tensor
+
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index 6568230..745b7d9 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -4,9 +4,11 @@ import torch
 from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
 
 from Configuration import Configuration
-
+from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
+from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, make_sinkhorn_inputs
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
 
 
 def FFN(d_model, d_ff, dropout_rate=0.1, d_out=None):
@@ -24,56 +26,67 @@ class Linker:
 
         self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
         self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
+        self.dim_linker = int(Configuration.modelDecoderConfig['dim_linker'])
+        self.max_atoms_in_sentence = int(Configuration.modelDecoderConfig['max_atoms_in_sentence'])
+        self.atom_vocab_size = int(Configuration.modelDecoderConfig['atom_vocab_size'])
 
         self.dropout = Dropout(0.1)
 
+        self.atom_map = atom_map
+        self.padding_id = self.atom_map['[PAD]']
+        self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        self.atom_embedding = AtomEmbedding(self.dim_linker, self.atom_vocab_size, self.padding_id)
+
+        # TODO: define an encoding
+        self.linker_encoder = FFN(self.dim_linker, self.dim_linker, 0.1)
+
         self.pos_transformation = Sequential(
-            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
-            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+            FFN(self.dim_decoder, self.dim_decoder, 0.1),
+            LayerNorm(self.dim_decoder, eps=1e-12)
         )
         self.neg_transformation = Sequential(
-            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
-            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+            FFN(self.dim_decoder, self.dim_decoder, 0.1),
+            LayerNorm(self.dim_decoder, eps=1e-12)
         )
 
-    def forward(self, symbols_batch, symbols_decoding):
+    def forward(self, category_batch):
         '''
         Parameters :
         symbols_decoding : batch of size (batch_size, sequence_length) = output of decoder
+        Returns :
+        link_weights : list of Sinkhorn-normalized link weight matrices, one per sentence and per atom type
         '''
 
-        # some sequential for linker with output of decoder and initial ato
-
-        # decompose into batch_size, max symbols in sentence
-        decompose_decoding = find_pos_neg_idexes(symbols_batch)
-
-        # get  tensors of shape (batch_size, max_symbols_in_sentence/2)
-        pos_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if x], decompose_decoding))
-        neg_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if not x], decompose_decoding))
-
-        _positives = make_sinkhorn_inputs(symbols_decoding, pos_idxes_batch)
-        _negatives = make_sinkhorn_inputs(symbols_decoding, neg_idxes_batch)
+        # atoms embedding
+        atoms_batch = get_atoms_batch(category_batch)
+        atoms_batch = self.atom_tokenizer.convert_batchs_to_ids(atoms_batch)
+        atoms_embedding = self.atom_embedding(atoms_batch)
 
-        positives = [tensor for tensor in chain.from_iterable(_positives) if min(tensor.size()) != 0]
-        negatives = [tensor for tensor in chain.from_iterable(_negatives) if min(tensor.size()) != 0]
+        # MHA or LSTM over the BERT output
+        #
+        # TODO
+        # atoms_encoding = self.linker_encoder(atoms_embedding)
+        #
+        atoms_encoding = atoms_embedding
 
-        distinct_shapes = {tensor.size()[0] for tensor in positives}
-        distinct_shapes = sorted(distinct_shapes)
+        # find atoms polarity : list (not a tensor) of shape (batch_size, max_atoms_in_sentence)
+        atoms_polarity = find_pos_neg_idexes(category_batch)
 
-        # going to match the pos and neg together
-        matches = []
+        link_weights = []
+        for sentence_idx in range(len(atoms_polarity)):
+            for atom_type in self.atom_map.keys():
+                pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                         x and atoms_batch[sentence_idx][i] == atom_type]
+                neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                         not x and atoms_batch[sentence_idx][i] == atom_type]
 
-        all_shape_positives = [self.pos_transformation(self.dropout(torch.stack([tensor for tensor in positives
-                                                                                 if tensor.size()[0] == shape])))
-                               for shape in distinct_shapes]
+                pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
+                neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
 
-        all_shape_negatives = [self.neg_transformation(self.dropout(torch.stack([tensor for tensor in negatives
-                                                                                 if tensor.size()[0] == shape])))
-                               for shape in distinct_shapes]
+                pos_encoding = self.pos_transformation(pos_encoding)
+                neg_encoding = self.neg_transformation(neg_encoding)
 
-        for this_shape_positives, this_shape_negatives in zip(all_shape_positives, all_shape_negatives):
-            weights = torch.bmm(this_shape_positives,
-                                this_shape_negatives.transpose(2, 1))
-            matches.append(sinkhorn(weights, iters=3))
+                weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
+                link_weights.append(sinkhorn(weights, iters=3))
 
-        return matches
+        return link_weights
diff --git a/SuperTagger/Linker/atom_map.py b/SuperTagger/Linker/atom_map.py
new file mode 100644
index 0000000..893fd00
--- /dev/null
+++ b/SuperTagger/Linker/atom_map.py
@@ -0,0 +1,28 @@
+atom_map = \
+    {'cl_r': 0,
+     '\\': 1,
+     'n': 2,
+     'p': 3,
+     's_ppres': 4,
+     'dia': 5,
+     's_whq': 6,
+     'let': 7,
+     '/': 8,
+     's_inf': 9,
+     's_pass': 10,
+     'pp_a': 11,
+     'pp_par': 12,
+     'pp_de': 13,
+     'cl_y': 14,
+     'box': 15,
+     'txt': 16,
+     's': 17,
+     's_ppart': 18,
+     's_q': 19,
+     'np': 20,
+     'pp': 21,
+     '[SEP]': 22,
+     '[SOS]': 23,
+     '[START]': 24,
+     '[PAD]': 25
+     }
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index 49e702c..ddb8cb5 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -1,7 +1,30 @@
 import re
 
+from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.atom_map import atom_map
 
-atoms_list = ['r', 'np']
+
+def get_atoms_from_category(category, category_to_atoms):
+    if category in atom_map.keys():
+        return category_to_atoms + [category]
+    else:
+        category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
+        left_side, right_side = category_cut.group(1), category_cut.group(2)
+
+        category_to_atoms += get_atoms_from_category(left_side, [])
+        category_to_atoms += get_atoms_from_category(right_side, [])
+
+        return category_to_atoms
+
+
+def get_atoms_batch(category_batch):
+    batch = []
+    for sentence in category_batch:
+        category_to_atoms = []
+        for category in sentence:
+            category_to_atoms = get_atoms_from_category(category, category_to_atoms)
+        batch.append(category_to_atoms)
+    return batch
 
 
 def cut_category_in_symbols(category):
@@ -11,10 +34,10 @@ def cut_category_in_symbols(category):
     Returns :
     Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
     '''
-    category_to_weights = []
+    category_to_polarity = []
 
-    if category in atoms_list:
-        category_to_weights.append(True)
+    if category in atom_map.keys():
+        category_to_polarity.append(True)
 
     else:
         # dr = /
@@ -23,16 +46,16 @@ def cut_category_in_symbols(category):
             left_side, right_side = category_cut.group(1), category_cut.group(2)
 
             # for the left side
-            if left_side in atoms_list:
-                category_to_weights.append(False)
+            if left_side in atom_map.keys():
+                category_to_polarity.append(False)
             else:
-                category_to_weights += cut_category_in_symbols(left_side)
+                category_to_polarity += cut_category_in_symbols(left_side)
 
             # for the right side
-            if right_side in atoms_list:
-                category_to_weights.append(True)
+            if right_side in atom_map.keys():
+                category_to_polarity.append(True)
             else:
-                category_to_weights += cut_category_in_symbols(right_side)
+                category_to_polarity += cut_category_in_symbols(right_side)
 
         # dl = \
         elif category.startswith("dl"):
@@ -40,21 +63,18 @@ def cut_category_in_symbols(category):
             left_side, right_side = category_cut.group(1), category_cut.group(2)
 
             # for the left side
-            if left_side in atoms_list:
-                category_to_weights.append(True)
+            if left_side in atom_map.keys():
+                category_to_polarity.append(True)
             else:
-                category_to_weights += cut_category_in_symbols(left_side)
+                category_to_polarity += cut_category_in_symbols(left_side)
 
             # for the right side
-            if right_side in atoms_list:
-                category_to_weights.append(False)
+            if right_side in atom_map.keys():
+                category_to_polarity.append(False)
             else:
-                category_to_weights += cut_category_in_symbols(right_side)
-
-    return category_to_weights
-
+                category_to_polarity += cut_category_in_symbols(right_side)
 
-print( cut_category_in_symbols('dr(1,dr(1,r,np),np)'))
+    return category_to_polarity
 
 
 def find_pos_neg_idexes(batch_symbols):
@@ -65,18 +85,11 @@ def find_pos_neg_idexes(batch_symbols):
     Returns :
     (batch_size, max_symbols_in_sentence) boolean tensor indiating pos and ne indexes
     '''
-    return None
-
-
-def make_sinkhorn_inputs(bsd_tensor, positional_ids):
-    """
-    :param bsd_tensor:
-        Tensor of shape (batch size, sequence length, feature dimensionality).
-    :param positional_ids:
-        A List (batch_size, max_atoms_in_sentence) .
-        Each positional_ids[b][a] indexes the location of atoms of type a in sentence b.
-    :return:
-    """
+    list_batch = []
+    for sentence in batch_symbols:
+        list_symbols = []
+        for category in sentence:
+            list_symbols.append(cut_category_in_symbols(category))
+        list_batch.append(list_symbols)
+    return list_batch
 
-    return [[bsd_tensor.select(0, index=i).index_select(0, index=atom) for atom in sentence]
-            for i, sentence in enumerate(positional_ids)]
\ No newline at end of file
diff --git a/test.py b/test.py
index d6882f3..f208027 100644
--- a/test.py
+++ b/test.py
@@ -1,7 +1,27 @@
-l = [[False, True, True, False],
-        [True, False, True, False]]
+from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+import torch
 
-print(l)
-print([i for i, x in enumerate(l) if x])
+atoms_batch = [["np", "v", "np", "v", "np", "v", "np", "v"],
+               ["np", "np", "v", "v", "np", "np", "v", "v"]]
 
-print(list(map(lambda sub_list : [i for i, x in enumerate(sub_list) if x], l)))
\ No newline at end of file
+atoms_polarity = [[False, True, True, False, False, True, True, False],
+                  [True, False, True, False, True, False, True, False]]
+
+atoms_encoding = torch.randn((2, 8, 24))
+
+matches = []
+for sentence_idx in range(len(atoms_polarity)):
+
+    for atom_type in ["np", "v"]:
+        pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                 x and atoms_batch[sentence_idx][i] == atom_type]
+        neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                 not x and atoms_batch[sentence_idx][i] == atom_type]
+
+        pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
+        neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
+
+        weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
+        matches.append(sinkhorn(weights, iters=3))
+
+print(matches)
-- 
GitLab