From c7baf521e43501d7b042717461444f6c426c2a2e Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Tue, 3 May 2022 17:34:31 +0200
Subject: [PATCH] progress on linker

---
 SuperTagger/Linker/AtomEmbedding.py           | 12 +++
 SuperTagger/Linker/AtomTokenizer.py           | 51 +++++++++++
 SuperTagger/Linker/Linker.py                  | 83 ++++++++++--------
 .../__pycache__/AtomTokenizer.cpython-38.pyc  | Bin 0 -> 2801 bytes
 .../__pycache__/Sinkhorn.cpython-38.pyc       | Bin 0 -> 687 bytes
 .../__pycache__/atom_map.cpython-38.pyc       | Bin 0 -> 571 bytes
 SuperTagger/Linker/atom_map.py                | 28 ++++++
 SuperTagger/Linker/utils.py                   | 81 ++++++++++-------
 test.py                                       | 30 +++++--
 9 files changed, 211 insertions(+), 74 deletions(-)
 create mode 100644 SuperTagger/Linker/AtomEmbedding.py
 create mode 100644 SuperTagger/Linker/AtomTokenizer.py
 create mode 100644 SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc
 create mode 100644 SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
 create mode 100644 SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc
 create mode 100644 SuperTagger/Linker/atom_map.py

diff --git a/SuperTagger/Linker/AtomEmbedding.py b/SuperTagger/Linker/AtomEmbedding.py
new file mode 100644
index 0000000..e7be599
--- /dev/null
+++ b/SuperTagger/Linker/AtomEmbedding.py
@@ -0,0 +1,12 @@
+import torch
+from torch.nn import Module, Embedding
+
+
+class AtomEmbedding(Module):
+    def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
+        super(AtomEmbedding, self).__init__()
+        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker,
+                             padding_idx=padding_idx, scale_grad_by_freq=True)
+
+    def forward(self, x):
+        return self.emb(x)
diff --git a/SuperTagger/Linker/AtomTokenizer.py b/SuperTagger/Linker/AtomTokenizer.py
new file mode 100644
index 0000000..e400d4e
--- /dev/null
+++ b/SuperTagger/Linker/AtomTokenizer.py
@@ -0,0 +1,51 @@
+import torch
+
+
+class AtomTokenizer(object):
+    def __init__(self, atom_map, max_atoms_in_sentence):
+        self.atom_map = atom_map
+        self.max_atoms_in_sentence = max_atoms_in_sentence
+        self.inverse_atom_map = {v: k for k, v in self.atom_map.items()}
+        self.sep_token = '[SEP]'
+        self.pad_token = '[PAD]'
+        self.sos_token = '[SOS]'
+        self.sep_token_id = self.atom_map[self.sep_token]
+        self.pad_token_id = self.atom_map[self.pad_token]
+        self.sos_token_id = self.atom_map[self.sos_token]
+
+    def __len__(self):
+        return len(self.atom_map)
+
+    def convert_atoms_to_ids(self, atom):
+        return self.atom_map[str(atom)]
+
+    def convert_sents_to_ids(self, sentences):
+        return torch.as_tensor([self.convert_atoms_to_ids(atom) for atom in sentences])
+
+    def convert_batchs_to_ids(self, batchs_sentences):
+        return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences],
+                                            max_len=self.max_atoms_in_sentence, padding_value=self.pad_token_id))
+
+    def convert_ids_to_atoms(self, ids):
+        return [self.inverse_atom_map[int(i)] for i in ids]
+
+
+def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    max_size = sequences[0].size()
+    trailing_dims = max_size[1:]
+    if batch_first:
+        out_dims = (len(sequences), max_len) + trailing_dims
+    else:
+        out_dims = (max_len, len(sequences)) + trailing_dims
+
+    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+    for i, tensor in enumerate(sequences):
+        length = tensor.size(0)
+        # use index notation to prevent duplicate references to the tensor
+        if batch_first:
+            out_tensor[i, :length, ...] = tensor
+        else:
+            out_tensor[:length, i, ...] = tensor
+
+    return out_tensor
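Note on AtomTokenizer: convert_batchs_to_ids pads every sentence to the fixed max_atoms_in_sentence length through the local pad_sequence helper, unlike torch.nn.utils.rnn.pad_sequence, which pads to the longest element of the batch. A minimal usage sketch follows; the two-sentence batch and the max length of 10 are invented for illustration:

    # Sketch: round-trip a tiny batch through the tokenizer (values are illustrative).
    from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
    from SuperTagger.Linker.atom_map import atom_map

    tokenizer = AtomTokenizer(atom_map, max_atoms_in_sentence=10)
    batch = [['np', 's'], ['np', 'pp_de', 's']]
    ids = tokenizer.convert_batchs_to_ids(batch)
    print(ids.shape)                                # torch.Size([2, 10])
    print(tokenizer.convert_ids_to_atoms(ids[0]))   # ['np', 's', '[PAD]', ..., '[PAD]']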
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index 6568230..745b7d9 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -4,9 +4,11 @@ import torch
 from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU
 
 from Configuration import Configuration
-
+from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
+from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.atom_map import atom_map
 from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, make_sinkhorn_inputs
+from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch
 
 
 def FFN(d_model, d_ff, dropout_rate=0.1, d_out=None):
@@ -24,56 +26,67 @@ class Linker:
         self.dim_encoder = int(Configuration.modelDecoderConfig['dim_encoder'])
         self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
+        self.dim_linker = int(Configuration.modelDecoderConfig['dim_linker'])
+        self.max_atoms_in_sentence = int(Configuration.modelDecoderConfig['max_atoms_in_sentence'])
+        self.atom_vocab_size = int(Configuration.modelDecoderConfig['atom_vocab_size'])
         self.dropout = Dropout(0.1)
 
+        self.atom_map = atom_map
+        self.padding_id = self.atom_map['[PAD]']
+        self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        self.atom_embedding = AtomEmbedding(self.dim_linker, self.atom_vocab_size, self.padding_id)
+
+        # TODO: define an encoding
+        self.linker_encoder = FFN(self.dim_linker, self.dim_linker, 0.1)
+
         self.pos_transformation = Sequential(
-            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
-            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+            FFN(self.dim_decoder, self.dim_decoder, 0.1),
+            LayerNorm(self.dim_decoder, eps=1e-12)
         )
         self.neg_transformation = Sequential(
-            FFN(self.dim_decoder * 2, self.dim_decoder, 0.1, self.dim_decoder // 2),
-            LayerNorm(self.dim_decoder // 2, eps=1e-12)
+            FFN(self.dim_decoder, self.dim_decoder, 0.1),
+            LayerNorm(self.dim_decoder, eps=1e-12)
        )
 
-    def forward(self, symbols_batch, symbols_decoding):
+    def forward(self, category_batch):
         '''
         Parameters :
         symbols_decoding : batch of size (batch_size, sequence_length) = output of decoder
+        Returns :
+        link_weights : (batch_size, atom_vocab_size, ...)
         '''
-        # some sequential for linker with output of decoder and initial ato
-
-        # decompose into batch_size, max symbols in sentence
-        decompose_decoding = find_pos_neg_idexes(symbols_batch)
-
-        # get tensors of shape (batch_size, max_symbols_in_sentence/2)
-        pos_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if x], decompose_decoding))
-        neg_idxes_batch = list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if not x], decompose_decoding))
-
-        _positives = make_sinkhorn_inputs(symbols_decoding, pos_idxes_batch)
-        _negatives = make_sinkhorn_inputs(symbols_decoding, neg_idxes_batch)
+        # atoms embedding
+        atoms_batch = get_atoms_batch(category_batch)
+        atoms_batch = self.atom_tokenizer.convert_batchs_to_ids(atoms_batch)
+        atoms_embedding = self.atom_embedding(atoms_batch)
 
-        positives = [tensor for tensor in chain.from_iterable(_positives) if min(tensor.size()) != 0]
-        negatives = [tensor for tensor in chain.from_iterable(_negatives) if min(tensor.size()) != 0]
+        # MHA or LSTM over the BERT output
+        #
+        # TODO
+        # atoms_encoding = self.linker_encoder(atoms_embedding)
+        #
+        atoms_encoding = atoms_embedding
 
-        distinct_shapes = {tensor.size()[0] for tensor in positives}
-        distinct_shapes = sorted(distinct_shapes)
+        # find atoms polarity : list (not tensor) of shape (batch_size, max_atoms_in_sentence)
+        atoms_polarity = find_pos_neg_idexes(category_batch)
 
-        # going to match the pos and neg together
-        matches = []
+        link_weights = []
+        for sentence_idx in range(len(atoms_polarity)):
+            for atom_type in self.atom_map.keys():
+                pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                         x and atoms_batch[sentence_idx][i] == atom_type]
+                neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                         not x and atoms_batch[sentence_idx][i] == atom_type]
 
-        all_shape_positives = [self.pos_transformation(self.dropout(torch.stack([tensor for tensor in positives
-                                                                                 if tensor.size()[0] == shape])))
-                               for shape in distinct_shapes]
+                pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
+                neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
 
-        all_shape_negatives = [self.neg_transformation(self.dropout(torch.stack([tensor for tensor in negatives
-                                                                                 if tensor.size()[0] == shape])))
-                               for shape in distinct_shapes]
+                pos_encoding = self.pos_transformation(pos_encoding)
+                neg_encoding = self.neg_transformation(neg_encoding)
 
-        for this_shape_positives, this_shape_negatives in zip(all_shape_positives, all_shape_negatives):
-            weights = torch.bmm(this_shape_positives,
-                                this_shape_negatives.transpose(2, 1))
-            matches.append(sinkhorn(weights, iters=3))
+                weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
+                link_weights.append(sinkhorn(weights, iters=3))
 
-        return matches
+        return link_weights
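Note on the TODO in forward: the linker currently passes the raw atom embeddings through unchanged, and the commented-out call to self.linker_encoder marks where an encoder (MHA or LSTM over the BERT output, per the comment) should go. A shape-preserving sketch of what a self-attention variant could look like, assuming dim_linker is divisible by the head count; the hyperparameters are invented and this is not the committed design:

    # Sketch only: a stand-in encoder for atoms_embedding; values are illustrative.
    import torch
    from torch.nn import MultiheadAttention

    dim_linker = 24                  # illustrative value
    mha = MultiheadAttention(embed_dim=dim_linker, num_heads=4, batch_first=True)

    atoms_embedding = torch.randn(2, 10, dim_linker)  # (batch, max_atoms_in_sentence, dim_linker)
    atoms_encoding, _ = mha(atoms_embedding, atoms_embedding, atoms_embedding)
    print(atoms_encoding.shape)      # torch.Size([2, 10, 24])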
diff --git a/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb55c03f688748485e4452c56dda80e12c73c904
Binary files /dev/null and b/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc differ
diff --git a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afbe4b3ab22416e9fb3fa5b7e422587b47fe3c95
Binary files /dev/null and b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc differ
diff --git a/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc b/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95490adcd0f1a39a9d2181a1fe19b54b9ac7e4be
Binary files /dev/null and b/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc differ
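Note: the three __pycache__ entries above are compiled Python caches, so sinkhorn_fn_no_exp itself is not readable in this patch. For orientation only, a generic log-domain Sinkhorn normalization (the usual technique behind such a function, not necessarily this repository's implementation) alternates row and column normalizations in log space:

    # Generic log-domain Sinkhorn sketch; sinkhorn_fn_no_exp may differ in detail.
    import torch

    def sinkhorn_log_domain(weights, iters=3):
        # weights: (batch, n, n) raw matching scores between positive and negative atoms
        log_alpha = weights
        for _ in range(iters):
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)  # rows
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)  # columns
        return log_alpha  # log of an approximately doubly stochastic matrix

    scores = torch.randn(1, 4, 4)
    print(sinkhorn_log_domain(scores).exp().sum(dim=1))  # columns sum to 1 after the last step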
diff --git a/SuperTagger/Linker/atom_map.py b/SuperTagger/Linker/atom_map.py
new file mode 100644
index 0000000..893fd00
--- /dev/null
+++ b/SuperTagger/Linker/atom_map.py
@@ -0,0 +1,28 @@
+atom_map = \
+    {'cl_r': 0,
+     '\\': 1,
+     'n': 2,
+     'p': 3,
+     's_ppres': 4,
+     'dia': 5,
+     's_whq': 6,
+     'let': 7,
+     '/': 8,
+     's_inf': 9,
+     's_pass': 10,
+     'pp_a': 11,
+     'pp_par': 12,
+     'pp_de': 13,
+     'cl_y': 14,
+     'box': 15,
+     'txt': 16,
+     's': 17,
+     's_ppart': 18,
+     's_q': 19,
+     'np': 20,
+     'pp': 21,
+     '[SEP]': 22,
+     '[SOS]': 23,
+     '[START]': 24,
+     '[PAD]': 25
+     }
diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index 49e702c..ddb8cb5 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -1,7 +1,30 @@
 import re
 
+from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
+from SuperTagger.Linker.atom_map import atom_map
 
-atoms_list = ['r', 'np']
+
+def get_atoms_from_category(category, category_to_atoms):
+    if category in atom_map.keys():
+        return [category]
+    else:
+        category_cut = re.search(r'\w*\(\d+,(.+),(.+)\)', category)
+        left_side, right_side = category_cut.group(1), category_cut.group(2)
+
+        category_to_atoms += get_atoms_from_category(left_side, [])
+        category_to_atoms += get_atoms_from_category(right_side, [])
+
+        return category_to_atoms
+
+
+def get_atoms_batch(category_batch):
+    batch = []
+    for sentence in category_batch:
+        category_to_atoms = []
+        for category in sentence:
+            category_to_atoms = get_atoms_from_category(category, category_to_atoms)
+        batch.append(category_to_atoms)
+    return batch
 
 
 def cut_category_in_symbols(category):
@@ -11,10 +34,10 @@ def cut_category_in_symbols(category):
     Returns :
     Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
     '''
-    category_to_weights = []
+    category_to_polarity = []
 
-    if category in atoms_list:
-        category_to_weights.append(True)
+    if category in atom_map.keys():
+        category_to_polarity.append(True)
     else:
 
         # dr = /
@@ -23,16 +46,16 @@ def cut_category_in_symbols(category):
         left_side, right_side = category_cut.group(1), category_cut.group(2)
 
         # for the left side
-        if left_side in atoms_list:
-            category_to_weights.append(False)
+        if left_side in atom_map.keys():
+            category_to_polarity.append(False)
         else:
-            category_to_weights += cut_category_in_symbols(left_side)
+            category_to_polarity += cut_category_in_symbols(left_side)
 
         # for the right side
-        if right_side in atoms_list:
-            category_to_weights.append(True)
+        if right_side in atom_map.keys():
+            category_to_polarity.append(True)
         else:
-            category_to_weights += cut_category_in_symbols(right_side)
+            category_to_polarity += cut_category_in_symbols(right_side)
 
         # dl = \
     elif category.startswith("dl"):
@@ -40,21 +63,18 @@ def cut_category_in_symbols(category):
         left_side, right_side = category_cut.group(1), category_cut.group(2)
 
         # for the left side
-        if left_side in atoms_list:
-            category_to_weights.append(True)
+        if left_side in atom_map.keys():
+            category_to_polarity.append(True)
         else:
-            category_to_weights += cut_category_in_symbols(left_side)
+            category_to_polarity += cut_category_in_symbols(left_side)
 
         # for the right side
-        if right_side in atoms_list:
-            category_to_weights.append(False)
+        if right_side in atom_map.keys():
+            category_to_polarity.append(False)
         else:
-            category_to_weights += cut_category_in_symbols(right_side)
-
-    return category_to_weights
-
+            category_to_polarity += cut_category_in_symbols(right_side)
 
-print( cut_category_in_symbols('dr(1,dr(1,r,np),np)'))
+    return category_to_polarity
 
 
 def find_pos_neg_idexes(batch_symbols):
@@ -65,18 +85,11 @@ def find_pos_neg_idexes(batch_symbols):
     Returns :
     (batch_size, max_symbols_in_sentence) boolean tensor indicating pos and neg indexes
     '''
-    return None
-
-
-def make_sinkhorn_inputs(bsd_tensor, positional_ids):
-    """
-    :param bsd_tensor:
-        Tensor of shape (batch size, sequence length, feature dimensionality).
-    :param positional_ids:
-        A List (batch_size, max_atoms_in_sentence) .
-        Each positional_ids[b][a] indexes the location of atoms of type a in sentence b.
-    :return:
-    """
+    list_batch = []
+    for sentence in batch_symbols:
+        list_symbols = []
+        for category in sentence:
+            list_symbols.append(cut_category_in_symbols(category))
+        list_batch.append(list_symbols)
+    return list_batch
 
-    return [[bsd_tensor.select(0, index=i).index_select(0, index=atom) for atom in sentence]
-            for i, sentence in enumerate(positional_ids)]
\ No newline at end of file
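Note on the utils.py helpers: get_atoms_from_category flattens a category string into its atoms, and cut_category_in_symbols assigns one polarity flag per atom in the same order. A worked sketch on a nested category, using the dr(...) format seen in the debug call removed above; the category string itself is illustrative:

    # Worked example of the two helpers on a nested 'dr' category.
    from SuperTagger.Linker.utils import get_atoms_from_category, cut_category_in_symbols

    print(get_atoms_from_category('dr(0,dr(0,s,np),np)', []))
    # ['s', 'np', 'np']
    print(cut_category_in_symbols('dr(0,dr(0,s,np),np)'))
    # [False, True, True] : one polarity per atom, in the same order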
diff --git a/test.py b/test.py
index d6882f3..f208027 100644
--- a/test.py
+++ b/test.py
@@ -1,7 +1,27 @@
-l = [[False, True, True, False],
-     [True, False, True, False]]
+from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+import torch
 
-print(l)
-print([i for i, x in enumerate(l) if x])
+atoms_batch = [["np", "v", "np", "v", "np", "v", "np", "v"],
+               ["np", "np", "v", "v", "np", "np", "v", "v"]]
 
-print(list(map(lambda sub_list: [i for i, x in enumerate(sub_list) if x], l)))
\ No newline at end of file
+atoms_polarity = [[False, True, True, False, False, True, True, False],
+                  [True, False, True, False, True, False, True, False]]
+
+atoms_encoding = torch.randn((2, 8, 24))
+
+matches = []
+for sentence_idx in range(len(atoms_polarity)):
+
+    for atom_type in ["np", "v"]:
+        pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                 x and atoms_batch[sentence_idx][i] == atom_type]
+        neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if
+                                 not x and atoms_batch[sentence_idx][i] == atom_type]
+
+        pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :]
+        neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :]
+
+        weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0))
+        matches.append(sinkhorn(weights, iters=3))
+
+print(matches)
-- 
GitLab