diff --git a/trytorch/gen_anchors.py b/trytorch/gen_anchors.py index 7204d1bab840f4d1dbf8b7c822b43882cc7d9277..393fc755eb8342775dd160588ab086ff9a5591f5 100644 --- a/trytorch/gen_anchors.py +++ b/trytorch/gen_anchors.py @@ -1,23 +1,17 @@ import numpy as np import sys import os -from transformers import BertTokenizer, BertModel -from transformers.tokenization_utils_base import BatchEncoding +from transformers import BertModel import torch -from torch import nn -from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader -#from torchmetrics import F1Score -#from torch.autograd import Variable from tqdm import tqdm -from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, toks_to_ids, make_labels, make_tok_types, make_tok_masks, collate_batch +from train_model_baseline import generate_sentence_list, collate_batch from allennlp.data import vocabulary from sklearn.decomposition import PCA import matplotlib.pyplot as plt -import pylab +import argparse bert = 'bert-base-multilingual-cased' -tokenizer = BertTokenizer.from_pretrained(bert) def anchors(corpus, sset, fmt, vocab_file): vocab = vocabulary.Vocabulary() @@ -42,7 +36,7 @@ def anchors(corpus, sset, fmt, vocab_file): bert_embeddings = BertModel.from_pretrained(bert) #words for which to collect contextual embeddings in order to plot point clouds - clouds = ['the', 'who'] + clouds = ['le', 'qui'] cloud_embeddings = {} for cloud in clouds: cloud_embeddings[cloud] = [] @@ -68,9 +62,10 @@ def anchors(corpus, sset, fmt, vocab_file): pca = PCA(n_components=2) anchors = anchors.detach().numpy() - save_embs(anchors, vocab, corpus, sset, fmt) + #save_embs(anchors, vocab, corpus, sset, fmt) #plot_anchors(anchors, pca, vocab, oov_id) #plot_clouds(cloud_embeddings, anchors, pca, vocab) + map_anchors(cloud_embeddings) #, anchors, vocab) def save_embs(anchors, vocab, corpus, sset, fmt): #words not in the corpus are not included @@ -119,7 +114,7 @@ def plot_clouds(cloud_embeddings, 
anchors, pca, vocab): embs1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok1]]) nb0 = embs0.shape[0] nb1 = embs1.shape[0] - if nb0 == 0 or nb1 == 1: + if nb0 == 0 or nb1 == 0: print('one or more tokens have no occurrences in the corpus. unable to draw point cloud.') sys.exit() @@ -141,19 +136,41 @@ def plot_clouds(cloud_embeddings, anchors, pca, vocab): plt.savefig('clouds.png') plt.clf() +def map_anchors(cloud_embeddings): #, anchors, vocab): + tok = list(cloud_embeddings.keys())[0] + embs = cloud_embeddings[tok] + n = 10 + if len(embs) > n: + embs = embs[:n] + else: + n = len(embs) + + fig, axs = plt.subplots(nrows=1, ncols=n, subplot_kw={'xticks': [], 'yticks': []}) + + for i, ax in enumerate(axs.flat): + emb = embs[i].detach().numpy().reshape(768,1) + ax.imshow(emb, cmap='Blues', aspect='auto') + ax.set_title(f'({i+1})') + + fig.suptitle(f"A few contextual embeddings for the word '{tok}'") + plt.tight_layout() + plt.savefig('heatmap_anchors.png') + plt.show() + plt.clf() + def main(): - if len(sys.argv) != 5: - print("usage: gen_anchors.py <corpus> <train/test/dev/all> <conllu/tok> <vocab_file>") - sys.exit() - corpus = sys.argv[1] - sset = sys.argv[2] - fmt = sys.argv[3] - vocab_file = sys.argv[4] - vocab_file = os.path.join('dictionaries', 'monolingual', vocab_file) + parser = argparse.ArgumentParser(description='Compute anchors for a given dataset and a given word list') + parser.add_argument('--corpus', required=True, help='corpus for which to compute anchors') + parser.add_argument('--set', default='train', help='portion of the corpus to test on (train/test/dev/all)') + parser.add_argument('--format', default='conllu', help='tok or conllu') + parser.add_argument('--voc', required=True, help='words for which to compute anchors (a text file with one word per line and including <UNK> token)') + params = parser.parse_args() + + vocab_file = os.path.join('dictionaries', 'monolingual', params.voc) if not os.path.isfile(vocab_file): print(f'no vocab
file at {vocab_file}.') sys.exit() - anchors(corpus, sset, fmt, vocab_file) + anchors(params.corpus, params.set, params.format, vocab_file) if __name__ == '__main__': main() diff --git a/trytorch/plot_mapping.py b/trytorch/plot_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..ada204cbd527515703e96bb7197e25d144754852 --- /dev/null +++ b/trytorch/plot_mapping.py @@ -0,0 +1,114 @@ +import numpy as np +import sys +import os +from transformers import BertModel +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from train_model_baseline import generate_sentence_list, collate_batch +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt +import argparse + +def mapping(src_corpus, tgt_corpus, mapping, sset, fmt): + + data = {} + mapping = torch.load(mapping) + mapp = torch.from_numpy(mapping) #TEST + print("MAPPING", mapping.shape) + + if sset == 'all': + ssets = ['train', 'test', 'dev'] + for corpus in [src_corpus, tgt_corpus]: + for s in ssets: + data.setdefault(corpus, {})[s] = generate_sentence_list(corpus, s, fmt) + data[corpus] = data[corpus]['train'] + data[corpus]['test'] + data[corpus]['dev'] + else: + for corpus in [src_corpus, tgt_corpus]: + data[corpus] = generate_sentence_list(corpus, sset, fmt) + + data = data[src_corpus] + data[tgt_corpus] + batch_size = 64 + + dataloader = DataLoader(data, batch_size=batch_size, collate_fn=collate_batch) + bert = 'bert-base-multilingual-cased' + bert_embeddings = BertModel.from_pretrained(bert) + + #words for which to collect contextual embeddings in order to plot point clouds + clouds = ['the', 'who', 'le', 'qui'] + clouds_fr = clouds[2:] + cloud_embeddings = {} + aligned_embeddings = {} + for cloud in clouds: + cloud_embeddings[cloud] = [] + for cloud in clouds: + aligned_embeddings[cloud] = [] + + #write this as function + for sentence_batch in tqdm(dataloader): + bert_output = bert_embeddings(**sentence_batch.getBatchEncoding()).last_hidden_state + + for i,
sentence in enumerate(sentence_batch.tokens): + bert_sentence_output = bert_output[i] + + for j, token in enumerate(sentence): + bert_token_output = bert_sentence_output[j] + if token in clouds: + cloud_embeddings[token].append(bert_token_output) + if token in clouds_fr: + aligned_emb = np.matmul(bert_token_output.detach().numpy(), mapping.transpose()) + aligned_embeddings[token].append(torch.from_numpy(aligned_emb)) + + pca = PCA(n_components=2) + plot_clouds(cloud_embeddings, pca, 'before') + for cloud in clouds[:2]: + aligned_embeddings[cloud] = cloud_embeddings[cloud] #add unchanged target vectors + plot_clouds(aligned_embeddings, pca, 'After') + +def plot_clouds(cloud_embeddings, pca, text): + tok_en0, tok_en1, tok_fr0, tok_fr1 = cloud_embeddings.keys() + colors = ['b', 'c', 'm', 'r'] + embs_en0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_en0]]) + embs_en1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_en1]]) + embs_fr0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_fr0]]) + embs_fr1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_fr1]]) + + n_en0 = embs_en0.shape[0] + n_en1 = embs_en1.shape[0] + n_fr0 = embs_fr0.shape[0] + n_fr1 = embs_fr1.shape[0] + + full_embs = np.concatenate((embs_en0, embs_en1, embs_fr0, embs_fr1), axis=0) + full_embs_reduced = pca.fit_transform(full_embs) + + for emb in full_embs_reduced[:n_en0]: + plt.plot(emb[0], emb[1], '.', color=colors[0]) + for emb in full_embs_reduced[n_en0:(n_en0+n_en1)]: + plt.plot(emb[0], emb[1], '.', color=colors[1]) + for emb in full_embs_reduced[(n_en0+n_en1):(n_en0+n_en1+n_fr0)]: + plt.plot(emb[0], emb[1], '.', color=colors[2]) + for emb in full_embs_reduced[(n_en0+n_en1+n_fr0):]: + plt.plot(emb[0], emb[1], '.', color=colors[3]) + + plt.title(f'{text} alignment') + plt.legend() + plt.savefig(f'{text}.png') + plt.show() + plt.clf() + +def main(): + parser = argparse.ArgumentParser(description='Compute anchors for a given dataset and a given word list')
+ parser.add_argument('--src_corpus', required=True, help='corpus to align') + parser.add_argument('--tgt_corpus', required=True, help='corpus on which to align') + parser.add_argument('--mapping', required=True, help='path to .pth mapping file') + parser.add_argument('--set', default='train', help='portion of the corpus to test on (train/test/dev/all)') + parser.add_argument('--format', default='conllu', help='tok or conllu') + params = parser.parse_args() + + if not os.path.isfile(params.mapping): + print(f'no file at {params.mapping}.') + sys.exit() + mapping(params.src_corpus, params.tgt_corpus, params.mapping, params.set, params.format) + +if __name__ == '__main__': + main() diff --git a/trytorch/test_model.py b/trytorch/test_model.py index df60cfcbe008aa6ca28bf679a80a44e1a211b73c..4eefc23f53518740573ea9e73d2e758f56c7d20a 100644 --- a/trytorch/test_model.py +++ b/trytorch/test_model.py @@ -93,7 +93,7 @@ def main(): parser.add_argument('--type', default='baseline', help="baseline or rich model") parser.add_argument('--corpus', help='corpus to test on') parser.add_argument('--set', default='test', help='portion of the corpus to test on') - parser.add_argument('--errors', default=0, help='number of prediction errors to display on standard output') + parser.add_argument('--errors', type=int, default=0, help='number of prediction errors to display on standard output') params = parser.parse_args() diff --git a/trytorch/train_model_baseline.py b/trytorch/train_model_baseline.py index b6038b51e56698c134c87b1aa2a6fddaccad3066..c772b818c5f0b950ad6c729a7d919224e407356a 100644 --- a/trytorch/train_model_baseline.py +++ b/trytorch/train_model_baseline.py @@ -48,8 +48,15 @@ class SentenceBatch(): self.tok_types = pad_sequence(tok_types, batch_first=True) self.tok_masks = pad_sequence(tok_masks, batch_first=True) self.labels = pad_sequence(labels, batch_first=True) - self.uposes = pad_sequence(upos, batch_first=True) - self.dheads = pad_sequence(dheads, batch_first=True) + 
if uposes is not None: + self.uposes = pad_sequence(uposes, batch_first=True) + else: self.uposes = None + if deprels is not None: + self.deprels = pad_sequence(deprels, batch_first=True) + else: self.deprels = None + if dheads is not None: + self.dheads = pad_sequence(dheads, batch_first=True) + else: self.dheads = None self.labels = pad_sequence(labels, batch_first=True) def getBatchEncoding(self):