Commit f6e16e8f authored by Alice Pain

minor

parent 5fbd0a05
import numpy as np
import sys
import os
from transformers import BertTokenizer, BertModel
from transformers.tokenization_utils_base import BatchEncoding
from transformers import BertModel
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
#from torchmetrics import F1Score
#from torch.autograd import Variable
from tqdm import tqdm
from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, toks_to_ids, make_labels, make_tok_types, make_tok_masks, collate_batch
from train_model_baseline import generate_sentence_list, collate_batch
from allennlp.data import vocabulary
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import pylab
import argparse
bert = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(bert)
def anchors(corpus, sset, fmt, vocab_file):
    vocab = vocabulary.Vocabulary()
@@ -42,7 +36,7 @@ def anchors(corpus, sset, fmt, vocab_file):
    bert_embeddings = BertModel.from_pretrained(bert)
    #words for which to collect contextual embeddings in order to plot point clouds
    clouds = ['the', 'who']
    clouds = ['le', 'qui']
    cloud_embeddings = {}
    for cloud in clouds:
        cloud_embeddings[cloud] = []
@@ -68,9 +62,10 @@ def anchors(corpus, sset, fmt, vocab_file):
    pca = PCA(n_components=2)
    anchors = anchors.detach().numpy()
    save_embs(anchors, vocab, corpus, sset, fmt)
    #save_embs(anchors, vocab, corpus, sset, fmt)
    #plot_anchors(anchors, pca, vocab, oov_id)
    #plot_clouds(cloud_embeddings, anchors, pca, vocab)
    map_anchors(cloud_embeddings) #, anchors, vocab)
def save_embs(anchors, vocab, corpus, sset, fmt):
    #words not in the corpus are not included
@@ -119,7 +114,7 @@ def plot_clouds(cloud_embeddings, anchors, pca, vocab):
    embs1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok1]])
    nb0 = embs0.shape[0]
    nb1 = embs1.shape[0]
    if nb0 == 0 or nb1 == 1:
    if nb0 == 0 or nb1 == 0:
        print('one or more tokens have no occurrences in the corpus. unable to draw point cloud.')
        sys.exit()
@@ -141,19 +136,41 @@ def plot_clouds(cloud_embeddings, anchors, pca, vocab):
    plt.savefig('clouds.png')
    plt.clf()
def map_anchors(cloud_embeddings): #, anchors, vocab):
    tok = list(cloud_embeddings.keys())[0]
    embs = cloud_embeddings[tok]
    n = 10
    if len(embs) > n:
        embs = embs[:n]
    else:
        n = len(embs)
    fig, axs = plt.subplots(nrows=1, ncols=n, subplot_kw={'xticks': [], 'yticks': []})
    for i, ax in enumerate(axs.flat):
        #display each 768-dimensional embedding as a 768x1 column heatmap
        emb = embs[i].detach().numpy().reshape(768, 1)
        ax.imshow(emb, cmap='Blues', aspect='auto')
        ax.set_title(f'({i+1})')
    fig.suptitle(f"A few contextual embeddings for the word '{tok}'")
    plt.tight_layout()
    plt.savefig('heatmap_anchors.png')
    plt.show()
    plt.clf()
def main():
    if len(sys.argv) != 5:
        print("usage: gen_anchors.py <corpus> <train/test/dev/all> <conllu/tok> <vocab_file>")
        sys.exit()
    corpus = sys.argv[1]
    sset = sys.argv[2]
    fmt = sys.argv[3]
    vocab_file = sys.argv[4]
    vocab_file = os.path.join('dictionaries', 'monolingual', vocab_file)
    parser = argparse.ArgumentParser(description='Compute anchors for a given dataset and a given word list')
    parser.add_argument('--corpus', required=True, help='corpus for which to compute anchors')
    parser.add_argument('--set', default='train', help='portion of the corpus to test on (train/test/dev/all)')
    parser.add_argument('--format', default='conllu', help='tok or conllu')
    parser.add_argument('--voc', required=True, help='words for which to compute anchors (a text file with one word per line and including <UNK> token)')
    params = parser.parse_args()
    vocab_file = os.path.join('dictionaries', 'monolingual', params.voc)
    if not os.path.isfile(vocab_file):
        print(f'no vocab file at {vocab_file}.')
        sys.exit()
    anchors(corpus, sset, fmt, vocab_file)
    anchors(params.corpus, params.set, params.format, vocab_file)
if __name__ == '__main__':
    main()
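# Example invocation of the new argparse interface (the corpus name and vocab file
# below are hypothetical placeholders):
#   python gen_anchors.py --corpus eng.pud --set train --format conllu --voc eng_vocab.txt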
import numpy as np
import sys
import os
from transformers import BertModel
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from train_model_baseline import generate_sentence_list, collate_batch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import argparse
def mapping(src_corpus, tgt_corpus, mapping, sset, fmt):
    data = {}
    mapping = torch.load(mapping)
    mapp = torch.from_numpy(mapping) #TEST
    print("MAPPING", mapping.shape)
    if sset == 'all':
        ssets = ['train', 'test', 'dev']
        for corpus in [src_corpus, tgt_corpus]:
            data[corpus] = {} #initialise the per-set dict before filling it
            for s in ssets:
                data[corpus][s] = generate_sentence_list(corpus, s, fmt)
            data[corpus] = data[corpus]['train'] + data[corpus]['test'] + data[corpus]['dev']
    else:
        for corpus in [src_corpus, tgt_corpus]:
            data[corpus] = generate_sentence_list(corpus, sset, fmt)
    data = data[src_corpus] + data[tgt_corpus]
    batch_size = 64
    dataloader = DataLoader(data, batch_size=batch_size, collate_fn=collate_batch)
    bert = 'bert-base-multilingual-cased'
    bert_embeddings = BertModel.from_pretrained(bert)
    #words for which to collect contextual embeddings in order to plot point clouds
    clouds = ['the', 'who', 'le', 'qui']
    clouds_fr = clouds[2:]
    cloud_embeddings = {}
    aligned_embeddings = {}
    for cloud in clouds:
        cloud_embeddings[cloud] = []
    for cloud in clouds_fr:
        aligned_embeddings[cloud] = []
    #write this as function
    for sentence_batch in tqdm(dataloader):
        bert_output = bert_embeddings(**sentence_batch.getBatchEncoding()).last_hidden_state
        for i, sentence in enumerate(sentence_batch.tokens):
            bert_sentence_output = bert_output[i]
            for j, token in enumerate(sentence):
                bert_token_output = bert_sentence_output[j]
                if token in clouds:
                    cloud_embeddings[token].append(bert_token_output)
                if token in clouds_fr:
                    aligned_emb = np.matmul(bert_token_output.detach().numpy(), mapping.transpose())
                    #store as a tensor so plot_clouds can call .detach().numpy() on every entry
                    aligned_embeddings[token].append(torch.from_numpy(aligned_emb))
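    # The '#write this as function' note above could be addressed with a small helper
    # along these lines (an illustrative sketch, not part of this commit; the name
    # collect_cloud_embeddings is hypothetical):
    #
    #   def collect_cloud_embeddings(dataloader, bert_embeddings, words):
    #       collected = {w: [] for w in words}
    #       for batch in tqdm(dataloader):
    #           output = bert_embeddings(**batch.getBatchEncoding()).last_hidden_state
    #           for i, sentence in enumerate(batch.tokens):
    #               for j, token in enumerate(sentence):
    #                   if token in collected:
    #                       collected[token].append(output[i][j])
    #       return collected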
    pca = PCA(n_components=2)
    plot_clouds(cloud_embeddings, pca, 'before')
    for cloud in clouds[:2]:
        aligned_embeddings[cloud] = cloud_embeddings[cloud] #add unchanged target vectors
    plot_clouds(aligned_embeddings, pca, 'after')
def plot_clouds(cloud_embeddings, pca, text):
    tok_en0, tok_en1, tok_fr0, tok_fr1 = cloud_embeddings.keys()
    colors = ['b', 'c', 'm', 'r']
    embs_en0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_en0]])
    embs_en1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_en1]])
    embs_fr0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_fr0]])
    embs_fr1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_fr1]])
    n_en0 = embs_en0.shape[0]
    n_en1 = embs_en1.shape[0]
    n_fr0 = embs_fr0.shape[0]
    n_fr1 = embs_fr1.shape[0]
    full_embs = np.concatenate((embs_en0, embs_en1, embs_fr0, embs_fr1), axis=0)
    full_embs_reduced = pca.fit_transform(full_embs)
    #one plot call per token so each cloud gets its own colour and legend entry
    plt.plot(full_embs_reduced[:n_en0, 0], full_embs_reduced[:n_en0, 1], '.', color=colors[0], label=tok_en0)
    plt.plot(full_embs_reduced[n_en0:(n_en0+n_en1), 0], full_embs_reduced[n_en0:(n_en0+n_en1), 1], '.', color=colors[1], label=tok_en1)
    plt.plot(full_embs_reduced[(n_en0+n_en1):(n_en0+n_en1+n_fr0), 0], full_embs_reduced[(n_en0+n_en1):(n_en0+n_en1+n_fr0), 1], '.', color=colors[2], label=tok_fr0)
    plt.plot(full_embs_reduced[(n_en0+n_en1+n_fr0):, 0], full_embs_reduced[(n_en0+n_en1+n_fr0):, 1], '.', color=colors[3], label=tok_fr1)
    plt.title(f'{text} alignment')
    plt.legend()
    plt.savefig(f'{text}.png')
    plt.show()
    plt.clf()
def main():
    parser = argparse.ArgumentParser(description='Align contextual embeddings with a learned mapping and plot point clouds before and after alignment')
    parser.add_argument('--src_corpus', required=True, help='corpus to align')
    parser.add_argument('--tgt_corpus', required=True, help='corpus on which to align')
    parser.add_argument('--mapping', required=True, help='path to .pth mapping file')
    parser.add_argument('--set', default='train', help='portion of the corpus to test on (train/test/dev/all)')
    parser.add_argument('--format', default='conllu', help='tok or conllu')
    params = parser.parse_args()
    if not os.path.isfile(params.mapping):
        print(f'no file at {params.mapping}.')
        sys.exit()
    mapping(params.src_corpus, params.tgt_corpus, params.mapping, params.set, params.format)
if __name__ == '__main__':
    main()
@@ -93,7 +93,7 @@ def main():
    parser.add_argument('--type', default='baseline', help="baseline or rich model")
    parser.add_argument('--corpus', help='corpus to test on')
    parser.add_argument('--set', default='test', help='portion of the corpus to test on')
    parser.add_argument('--errors', default=0, help='number of prediction errors to display on standard output')
    parser.add_argument('--errors', type=int, default=0, help='number of prediction errors to display on standard output')
    params = parser.parse_args()
......
@@ -48,8 +48,15 @@ class SentenceBatch():
        self.tok_types = pad_sequence(tok_types, batch_first=True)
        self.tok_masks = pad_sequence(tok_masks, batch_first=True)
        self.labels = pad_sequence(labels, batch_first=True)
        self.uposes = pad_sequence(upos, batch_first=True)
        self.dheads = pad_sequence(dheads, batch_first=True)
        if uposes is not None:
            self.uposes = pad_sequence(uposes, batch_first=True)
        else: self.uposes = None
        if deprels is not None:
            self.deprels = pad_sequence(deprels, batch_first=True)
        else: self.deprels = None
        if dheads is not None:
            self.dheads = pad_sequence(dheads, batch_first=True)
        else: self.dheads = None
        self.labels = pad_sequence(labels, batch_first=True)
    def getBatchEncoding(self):
......