Commit fd7bfe7f authored by Alice Pain

anchors

parent b0a2c22a
@@ -65,8 +65,8 @@
             "lr": 0.001
         },
         "num_serialized_models_to_keep": 3,
-        "num_epochs": 10,
+        "num_epochs": 4,
         "grad_norm": 5.0,
-        "cuda_device": -1
+        "cuda_device": 0
     }
 }
import os
import sys


def to_mono(dico, keep):
    # build a monolingual word list from one side (column `keep`) of a bilingual dictionary
    output_dir = os.path.join('dictionaries', 'monolingual')
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    input_file = os.path.join('dictionaries', dico)
    # the output file is prefixed with the two-letter language code taken from the
    # dictionary file name: characters 0-2 for the source side (keep=0), 3-5 for the target side (keep=1)
    beg = 3 if keep else 0
    end = 5 if keep else 2
    output_file = os.path.join(output_dir, (dico[beg:end] + "_" + dico))
    words = set()
    with open(input_file, "r") as rf:
        for line in rf:
            split = line.strip().split(" ")
            if len(split) != 2:
                print('format error')
                sys.exit()
            words.add(split[keep])
    with open(output_file, "w") as wf:
        for word in words:
            wf.write(word + '\n')
        wf.write('<UNK>\n')


def main():
    if len(sys.argv) < 3:
        print('usage: bil2mono.py <0/1> <dict.txt>')
        sys.exit()
    keep = sys.argv[1]
    if keep != '0' and keep != '1':
        print('usage: bil2mono.py <0/1> <dict.txt>')
        sys.exit()
    keep = int(keep)
    dicos = sys.argv[2:]
    for dico in dicos:
        to_mono(dico, keep)


if __name__ == '__main__':
    main()
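A hedged usage sketch, not part of the commit: assuming a bilingual dictionary named en-fr.txt under dictionaries/ whose lines look like "house maison", running

python bil2mono.py 0 en-fr.txt

would write dictionaries/monolingual/en_en-fr.txt containing the unique English words plus a final <UNK> entry; passing 1 instead of 0 keeps the French column and writes fr_en-fr.txt. The file name and its contents are only illustrations of the two-letter-code naming convention that the slicing in to_mono relies on.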
import numpy as np
import sys
import os
from transformers import BertTokenizer, BertModel
from transformers.tokenization_utils_base import BatchEncoding
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
#from torchmetrics import F1Score
#from torch.autograd import Variable
from tqdm import tqdm
from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, toks_to_ids, make_labels, make_tok_types, make_tok_masks, collate_batch
from allennlp.data import vocabulary
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import pylab
bert = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(bert)
def anchors(corpus, sset, fmt, vocab_file):
    # the anchor of each vocabulary word is the mean of its contextual BERT embeddings over the corpus
    vocab = vocabulary.Vocabulary()
    vocab.set_from_file(vocab_file, oov_token='<UNK>')
    oov_id = vocab.get_token_index(vocab._oov_token)
    anchors = torch.zeros((vocab.get_vocab_size(), 768))
    vocab_size = vocab.get_vocab_size()
    num_occurrences = [0] * vocab_size
    print(f'Loaded vocabulary of size {vocab_size} from file {vocab_file}')
    if sset == 'all':
        ssets = ['train', 'test', 'dev']
        data = {}
        for s in ssets:
            data[s] = generate_sentence_list(corpus, s, fmt)
        data = data['train'] + data['test'] + data['dev']
    else:
        data = generate_sentence_list(corpus, sset, fmt)
    batch_size = 64
    dataloader = DataLoader(data, batch_size=batch_size, collate_fn=collate_batch)
    bert_embeddings = BertModel.from_pretrained(bert)
    # words for which to collect contextual embeddings in order to plot point clouds
    clouds = ['the', 'who']
    cloud_embeddings = {}
    for cloud in clouds:
        cloud_embeddings[cloud] = []
    for sentence_batch in tqdm(dataloader):
        bert_output = bert_embeddings(**sentence_batch.getBatchEncoding()).last_hidden_state
        for i, sentence in enumerate(sentence_batch.tokens):
            bert_sentence_output = bert_output[i]
            for j, token in enumerate(sentence):
                bert_token_output = bert_sentence_output[j]
                if token in clouds:
                    cloud_embeddings[token].append(bert_token_output)
                w_id = vocab.get_token_index(token)
                if w_id != oov_id:
                    # incremental mean over the n+1 occurrences of this word seen so far
                    n = num_occurrences[w_id]
                    anchors[w_id, :] = anchors[w_id, :] * (n / (n + 1)) + bert_token_output[:] / (n + 1)
                    num_occurrences[w_id] += 1
    print("done computing anchors.")
    pca = PCA(n_components=2)
    anchors = anchors.detach().numpy()
    save_embs(anchors, vocab, corpus, sset, fmt)
    #plot_anchors(anchors, pca, vocab, oov_id)
    #plot_clouds(cloud_embeddings, anchors, pca, vocab)
def save_embs(anchors, vocab, corpus, sset, fmt):
    # words not in the corpus are not included (their anchor rows stayed all-zero)
    output_dir = 'saved_anchors'
    output_file = os.path.join(output_dir, f'anchors_{corpus}_{sset}.{fmt}')
    dico_anchors = {}
    for i, anchor in enumerate(anchors):
        if len(np.nonzero(anchor)[0]) > 0:
            token = vocab.get_token_from_index(i)
            dico_anchors[token] = anchor
    nb_embs = len(list(dico_anchors.keys()))
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    with open((output_file + '.txt'), "w") as wf:
        # word2vec-style text format: "<count> <dim>" header, then one "token v1 ... v768" line per anchor
        wf.write(str(nb_embs) + " " + '768\n')
        for token, anchor in dico_anchors.items():
            wf.write(token + ' ' + ' '.join([str(v) for v in anchor]) + '\n')
    #np.savez_compressed(output_file, anchors = np.array([dico_anchors]))
    print(f"anchors saved at {output_file}")
def plot_anchors(anchors, pca, vocab, oov_id):
    reduced = pca.fit_transform(anchors)
    words = ['le', 'la', 'les', 'à', 'qui', 'du', 'de', 'juillet', 'août', 'juin', 'et', 'ou', 'pour']
    for word in words:
        index = vocab.get_token_index(word)
        if index != oov_id:  # and np.sum(anchors[index] != 0):
            plt.plot(reduced[index][0], reduced[index][1], "+", label=word)
    plt.title('A few word anchors (PCA)')
    plt.legend()
    plt.savefig('anchors.png')  # save before show(), which may leave an empty current figure
    plt.show()
    plt.clf()
def plot_clouds(cloud_embeddings, anchors, pca, vocab):
    if len(cloud_embeddings.keys()) != 2:
        print('to draw point clouds please enter exactly two tokens')
        sys.exit()
    tok0, tok1 = cloud_embeddings.keys()
    colors = ['b', 'm']
    embs0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok0]])
    embs1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok1]])
    nb0 = embs0.shape[0]
    nb1 = embs1.shape[0]
    if nb0 == 0 or nb1 == 0:
        print('one or more tokens have no occurrences in the corpus. unable to draw point cloud.')
        sys.exit()
    anchor0 = np.array([anchors[vocab.get_token_index(tok0)]])
    anchor1 = np.array([anchors[vocab.get_token_index(tok1)]])
    full_embs = np.concatenate((embs0, embs1, anchor0, anchor1), axis=0)
    full_embs_reduced = pca.fit_transform(full_embs)
    # contextual embeddings of each token as dots, the two anchors as crosses
    for emb in full_embs_reduced[:nb0]:
        plt.plot(emb[0], emb[1], '.', color=colors[0])
    for emb in full_embs_reduced[nb0:(nb0 + nb1)]:
        plt.plot(emb[0], emb[1], '.', color=colors[1])
    plt.plot(full_embs_reduced[-2][0], full_embs_reduced[-2][1], '+', color=colors[0], label=tok0)
    plt.plot(full_embs_reduced[-1][0], full_embs_reduced[-1][1], '+', color=colors[1], label=tok1)
    plt.title('Contextual embeddings and corresponding anchors (PCA)')
    plt.legend()
    plt.savefig('clouds.png')  # save before show()
    plt.show()
    plt.clf()
def main():
    if len(sys.argv) != 5:
        print("usage: gen_anchors.py <corpus> <train/test/dev/all> <conllu/tok> <vocab_file>")
        sys.exit()
    corpus = sys.argv[1]
    sset = sys.argv[2]
    fmt = sys.argv[3]
    vocab_file = sys.argv[4]
    vocab_file = os.path.join('dictionaries', 'monolingual', vocab_file)
    if not os.path.isfile(vocab_file):
        print(f'no vocab file at {vocab_file}.')
        sys.exit()
    anchors(corpus, sset, fmt, vocab_file)


if __name__ == '__main__':
    main()
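The anchor update in anchors() is an incremental mean: after a word's (n+1)-th occurrence, its anchor equals the average of all contextual embeddings seen for it so far, and save_embs writes the result in a word2vec-style text file (count and dimension on the first line, then one token and its 768 values per line). A minimal standalone sketch, not part of the commit, checking the incremental form against a plain average; all names and values here are illustrative:

import numpy as np

# five hypothetical contextual embeddings of one word (dimension reduced for the example)
embs = [np.random.rand(8) for _ in range(5)]

anchor = np.zeros(8)
for n, e in enumerate(embs):
    # same running-mean update as in anchors(): n is the number of occurrences seen before this one
    anchor = anchor * (n / (n + 1)) + e / (n + 1)

assert np.allclose(anchor, np.mean(embs, axis=0))

Usage, per main(): python gen_anchors.py <corpus> <train/test/dev/all> <conllu/tok> <vocab_file>, where <vocab_file> is resolved under dictionaries/monolingual/, i.e. the output directory of bil2mono.py.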
import numpy as np
import sys
import os
from transformers import BertTokenizer, BertModel
from transformers.tokenization_utils_base import BatchEncoding
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
#from torchmetrics import F1Score
#from torch.autograd import Variable
from tqdm import tqdm
from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, toks_to_ids, make_labels, make_tok_types, make_tok_masks, collate_batch
from allennlp.data import vocabulary
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import pylab
def main():
    # expects the compressed .npz format written by np.savez_compressed in save_embs
    # (that call is currently commented out in favour of the .txt output)
    anchors = np.load('saved_anchors/anchors_fra.sdrt.annodis_test.conllu.npz', allow_pickle=True)
    anchors = anchors[anchors.files[0]]
    print(anchors[0].shape)


if __name__ == '__main__':
    main()
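Since save_embs currently writes the anchors only as a word2vec-style .txt file, a hedged sketch, not part of the commit, for reading that output instead; the function name and the example path are illustrative:

import numpy as np

def load_txt_anchors(path):
    # parse the "<count> 768" header line, then one "token v1 ... v768" line per anchor
    anchors = {}
    with open(path) as rf:
        rf.readline()  # header: number of anchors and embedding dimension
        for line in rf:
            parts = line.rstrip('\n').split(' ')
            anchors[parts[0]] = np.array([float(v) for v in parts[1:]])
    return anchors

# e.g. embs = load_txt_anchors('saved_anchors/anchors_fra.sdrt.annodis_test.conllu.txt')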
@@ -43,8 +43,9 @@ class LSTM(nn.Module):
 class SentenceBatch():
-    def __init__(self, sentence_ids, tok_ids, tok_types, tok_masks, labels):
+    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels):
         self.sentence_ids = sentence_ids
+        self.tokens = tokens
         self.tok_ids = pad_sequence(tok_ids, batch_first=True)
         self.tok_types = pad_sequence(tok_types, batch_first=True)
         self.tok_masks = pad_sequence(tok_masks, batch_first=True)
@@ -57,6 +58,7 @@ class SentenceBatch():
 def generate_sentence_list(corpus, sset, fmt):
     #move that part to parse_corpus.py
     parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz")
+    #print("PATH", parsed_data)
     if not os.path.isfile(parsed_data):
         print("you must parse the corpus before training it")
         sys.exit()
@@ -92,6 +94,9 @@ def generate_sentence_list(corpus, sset, fmt):
         indexed_sentences[i] = (i, sentence)
     return indexed_sentences
 
+def add_cls_sep(sentence):
+    return ['[CLS]'] + list(sentence) + ['[SEP]']
+
 def toks_to_ids(sentence):
     #print("sentence=", sentence)
@@ -115,12 +120,13 @@ def collate_batch(batch):
     sentence_ids, token_batch, label_batch = [i for i, (_, _) in batch], [j for _, (j, _) in batch], [k for _, (_, k) in batch]
     #mappings = [make_mapping(sentence) for sentence in token_batch]
     labels = [make_labels(sentence) for sentence in label_batch]
+    tokens = [add_cls_sep(sentence) for sentence in token_batch]
     tok_ids = [toks_to_ids(sentence) for sentence in token_batch]
     lengths = [len(toks) for toks in tok_ids]
     tok_types = [make_tok_types(l) for l in lengths]
     tok_masks = [make_tok_masks(l) for l in lengths]
-    return SentenceBatch(sentence_ids, tok_ids, tok_types, tok_masks, labels)
+    return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels)
 
 def train(corpus, fmt):
     print(f'starting training of {corpus} in format {fmt}...')
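A hedged illustration, not part of the diff: add_cls_sep(['Le', 'chat', 'dort']) returns ['[CLS]', 'Le', 'chat', 'dort', '[SEP]'], and collate_batch now stores these token lists in the new SentenceBatch.tokens field, apparently so that gen_anchors.py can walk sentence_batch.tokens position by position alongside the BERT output it indexes. The example sentence is hypothetical.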