diff --git a/code/contextual_embeddings/configs/bert.jsonnet b/code/contextual_embeddings/configs/bert.jsonnet
index 4a0b9853d49ec75d85eaa9bee5dedb4e0195bff6..fec982ffc29bdcd98d693b026ac7390281ef2753 100644
--- a/code/contextual_embeddings/configs/bert.jsonnet
+++ b/code/contextual_embeddings/configs/bert.jsonnet
@@ -65,8 +65,8 @@
       "lr": 0.001
     },
     "num_serialized_models_to_keep": 3,
-    "num_epochs": 10,
+    "num_epochs": 4,
     "grad_norm": 5.0,
-    "cuda_device": -1
+    "cuda_device": 0
   }
 }
diff --git a/trytorch/bil2mono.py b/trytorch/bil2mono.py
new file mode 100644
index 0000000000000000000000000000000000000000..e64e509ee1ce5f7d10749c2a21bbb155fac59f39
--- /dev/null
+++ b/trytorch/bil2mono.py
@@ -0,0 +1,45 @@
+# Turn a bilingual dictionary (one "source target" pair per line) into a
+# monolingual word list for one side, written to dictionaries/monolingual/.
+import os
+import sys
+
+def to_mono(dico, keep):
+
+    output_dir = os.path.join('dictionaries', 'monolingual')
+    if not os.path.isdir(output_dir):
+        os.mkdir(output_dir)
+    input_file = os.path.join('dictionaries', dico)
+    # keep == 0 selects the source column, keep == 1 the target column; the slice
+    # picks the matching language code out of the file name (assuming names like 'en-fr.txt')
+    beg = 3 if keep else 0
+    end = 5 if keep else 2
+    output_file = os.path.join(output_dir, (dico[beg:end] + "_" + dico))
+
+    words = set()
+
+    with open(input_file, "r") as rf:
+        for line in rf:
+            split = line.strip().split(" ")
+            if len(split) != 2:
+                print('format error')
+                sys.exit()
+            words.add(split[keep])
+
+    with open(output_file, "w") as wf:
+        for word in words:
+            wf.write(word + '\n')
+        wf.write('<UNK>\n')
+
+def main():
+    if len(sys.argv) < 3:
+        print('usage: bil2mono.py <0/1> <dict.txt>')
+        sys.exit()
+
+    keep = sys.argv[1]
+    if keep != '0' and keep != '1':
+        print('usage: bil2mono.py <0/1> <dict.txt>')
+        sys.exit()
+    keep = int(keep)
+    dicos = sys.argv[2:]
+    for dico in dicos:
+        to_mono(dico, keep)
+
+if __name__ == '__main__':
+    main()
+
diff --git a/trytorch/gen_anchors.py b/trytorch/gen_anchors.py
new file mode 100644
index 0000000000000000000000000000000000000000..7204d1bab840f4d1dbf8b7c822b43882cc7d9277
--- /dev/null
+++ b/trytorch/gen_anchors.py
@@ -0,0 +1,159 @@
+import numpy as np
+import sys
+import os
+from transformers import BertTokenizer, BertModel
+from transformers.tokenization_utils_base import BatchEncoding
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+#from torchmetrics import F1Score
+#from torch.autograd import Variable
+from tqdm import tqdm
+from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, toks_to_ids, make_labels, make_tok_types, make_tok_masks, collate_batch
+from allennlp.data import vocabulary
+from sklearn.decomposition import PCA
+import matplotlib.pyplot as plt
+import pylab
+
+bert = 'bert-base-multilingual-cased'
+tokenizer = BertTokenizer.from_pretrained(bert)
+
+def anchors(corpus, sset, fmt, vocab_file):
+    vocab = vocabulary.Vocabulary()
+    vocab.set_from_file(vocab_file, oov_token='<UNK>')
+    oov_id = vocab.get_token_index(vocab._oov_token)
+    anchors = torch.zeros((vocab.get_vocab_size(), 768))
+    vocab_size = vocab.get_vocab_size()
+    num_occurrences = [0] * vocab_size
+    print(f'Loaded vocabulary of size {vocab_size} from file {vocab_file}')
+
+    if sset == 'all':
+        ssets = ['train', 'test', 'dev']
+        data = {}
+        for s in ssets:
+            data[s] = generate_sentence_list(corpus, s, fmt)
+        data = data['train'] + data['test'] + data['dev']
+    else:
+        data = generate_sentence_list(corpus, sset, fmt)
+    batch_size = 64
+
+    dataloader = DataLoader(data, batch_size=batch_size, collate_fn=collate_batch)
+    bert_embeddings = BertModel.from_pretrained(bert)
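+    # Assumption: the anchors are meant as inference-time statistics, so dropout
+    # is switched off before running the model over the corpus.
+    bert_embeddings.eval()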
+
+    #words for which to collect contextual embeddings in order to plot point clouds
+    clouds = ['the', 'who']
+    cloud_embeddings = {}
+    for cloud in clouds:
+        cloud_embeddings[cloud] = []
+
+    for sentence_batch in tqdm(dataloader):
+        # no gradients are needed here; this also keeps the loop from accumulating autograd graphs
+        with torch.no_grad():
+            bert_output = bert_embeddings(**sentence_batch.getBatchEncoding()).last_hidden_state
+
+        for i, sentence in enumerate(sentence_batch.tokens):
+            bert_sentence_output = bert_output[i]
+
+            for j, token in enumerate(sentence):
+                bert_token_output = bert_sentence_output[j]
+                if token in clouds:
+                    cloud_embeddings[token].append(bert_token_output)
+
+                w_id = vocab.get_token_index(token)
+                if w_id != oov_id:
+                    # incremental mean of all contextual embeddings seen so far for this word
+                    n = num_occurrences[w_id]
+                    anchors[w_id, :] = anchors[w_id, :] * (n / (n + 1)) + bert_token_output[:] / (n + 1)
+                    num_occurrences[w_id] += 1
+
+    print("done computing anchors.")
+
+    pca = PCA(n_components=2)
+    anchors = anchors.detach().numpy()
+    save_embs(anchors, vocab, corpus, sset, fmt)
+    #plot_anchors(anchors, pca, vocab, oov_id)
+    #plot_clouds(cloud_embeddings, anchors, pca, vocab)
+
+def save_embs(anchors, vocab, corpus, sset, fmt):
+    #words not in the corpus are not included
+    output_dir = 'saved_anchors'
+    output_file = os.path.join(output_dir, f'anchors_{corpus}_{sset}.{fmt}')
+    dico_anchors = {}
+    for i, anchor in enumerate(anchors):
+        if len(np.nonzero(anchor)[0]) > 0:
+            token = vocab.get_token_from_index(i)
+            dico_anchors[token] = anchor
+
+    nb_embs = len(list(dico_anchors.keys()))
+
+    if not os.path.isdir(output_dir):
+        os.mkdir(output_dir)
+
+    # text format: a "<number of anchors> 768" header line, then one
+    # "token v_1 ... v_768" line per anchor
+    with open((output_file + '.txt'), "w") as wf:
+        wf.write(str(nb_embs) + " " + '768\n')
+        for token, anchor in dico_anchors.items():
+            wf.write(token + ' ' + ' '.join([str(v) for v in anchor]) + '\n')
+
+    #np.savez_compressed(output_file, anchors = np.array([dico_anchors]))
+    print(f"anchors saved at {output_file}")
+
+def plot_anchors(anchors, pca, vocab, oov_id):
+    reduced = pca.fit_transform(anchors)
+    words = ['le','la','les','à','qui','du','de','juillet','août','juin','et','ou','pour']
+    for word in words:
+        index = vocab.get_token_index(word)
+        if index != oov_id: # and np.sum(anchors[index] != 0):
+            plt.plot(reduced[index][0], reduced[index][1], "+", label = word)
+
+    plt.title('A few word anchors (PCA)')
+    plt.legend()
+    # save before show(); depending on the backend, the figure may be cleared once the window is closed
+    plt.savefig('anchors.png')
+    plt.show()
+    plt.clf()
+
+def plot_clouds(cloud_embeddings, anchors, pca, vocab):
+    if len(cloud_embeddings.keys()) != 2:
+        print('to draw point clouds please enter exactly two tokens')
+        sys.exit()
+    tok0, tok1 = cloud_embeddings.keys()
+    colors = ['b', 'm']
+    embs0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok0]])
+    embs1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok1]])
+    nb0 = embs0.shape[0]
+    nb1 = embs1.shape[0]
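+    # the PCA below needs at least one contextual embedding per cloud, so bail
+    # out when either token never occurs in the corpus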
+    if nb0 == 0 or nb1 == 0:
+        print('one or more tokens have no occurrences in the corpus. unable to draw point cloud.')
+        sys.exit()
+
+    anchor0 = np.array([anchors[vocab.get_token_index(tok0)]])
+    anchor1 = np.array([anchors[vocab.get_token_index(tok1)]])
+    full_embs = np.concatenate((embs0, embs1, anchor0, anchor1), axis=0)
+    full_embs_reduced = pca.fit_transform(full_embs)
+
+    for emb in full_embs_reduced[:nb0]:
+        plt.plot(emb[0], emb[1], '.', color=colors[0])
+    for emb in full_embs_reduced[nb0:(nb0+nb1)]:
+        plt.plot(emb[0], emb[1], '.', color=colors[1])
+    plt.plot(full_embs_reduced[-2][0], full_embs_reduced[-2][1], '+', color=colors[0], label = tok0)
+    plt.plot(full_embs_reduced[-1][0], full_embs_reduced[-1][1], '+', color=colors[1], label = tok1)
+
+    plt.title('Contextual embeddings and corresponding anchors (PCA)')
+    plt.legend()
+    # save before show(); depending on the backend, the figure may be cleared once the window is closed
+    plt.savefig('clouds.png')
+    plt.show()
+    plt.clf()
+
+def main():
+    if len(sys.argv) != 5:
+        print("usage: gen_anchors.py <corpus> <train/test/dev/all> <conllu/tok> <vocab_file>")
+        sys.exit()
+    corpus = sys.argv[1]
+    sset = sys.argv[2]
+    fmt = sys.argv[3]
+    vocab_file = sys.argv[4]
+    vocab_file = os.path.join('dictionaries', 'monolingual', vocab_file)
+    if not os.path.isfile(vocab_file):
+        print(f'no vocab file at {vocab_file}.')
+        sys.exit()
+    anchors(corpus, sset, fmt, vocab_file)
+
+if __name__ == '__main__':
+    main()
diff --git a/trytorch/show_anchors.py b/trytorch/show_anchors.py
new file mode 100644
index 0000000000000000000000000000000000000000..986cd94440f449e73b68dcd70c51dda5c3f41c80
--- /dev/null
+++ b/trytorch/show_anchors.py
@@ -0,0 +1,27 @@
+import numpy as np
+import sys
+import os
+from transformers import BertTokenizer, BertModel
+from transformers.tokenization_utils_base import BatchEncoding
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+#from torchmetrics import F1Score
+#from torch.autograd import Variable
+from tqdm import tqdm
+from train_model_baseline import LSTM, SentenceBatch, generate_sentence_list, toks_to_ids, make_labels, make_tok_types, make_tok_masks, collate_batch
+from allennlp.data import vocabulary
+from sklearn.decomposition import PCA
+import matplotlib.pyplot as plt
+import pylab
+
+
+def main():
+    # quick sanity check of a saved anchor archive; expects the np.savez_compressed
+    # output of save_embs in gen_anchors.py (that call is currently commented out there)
+    anchors = np.load('saved_anchors/anchors_fra.sdrt.annodis_test.conllu.npz', allow_pickle = True)
+    anchors = anchors[anchors.files[0]]
+
+    print(anchors[0].shape)
+
+if __name__ == '__main__':
+    main()
diff --git a/trytorch/train_model_baseline.py b/trytorch/train_model_baseline.py
index 218b576062f281a15c8bb16955818aa6faa1e234..4acbd84cf3809e3c764d78118548bfeb392c53dd 100644
--- a/trytorch/train_model_baseline.py
+++ b/trytorch/train_model_baseline.py
@@ -43,8 +43,9 @@ class LSTM(nn.Module):
 
 class SentenceBatch():
 
-    def __init__(self, sentence_ids, tok_ids, tok_types, tok_masks, labels):
+    def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels):
         self.sentence_ids = sentence_ids
+        self.tokens = tokens
         self.tok_ids = pad_sequence(tok_ids, batch_first=True)
         self.tok_types = pad_sequence(tok_types, batch_first=True)
         self.tok_masks = pad_sequence(tok_masks, batch_first=True)
@@ -57,6 +58,7 @@
 def generate_sentence_list(corpus, sset, fmt):
     #move that part to parse_corpus.py
     parsed_data = os.path.join("parsed_data", f"parsed_{corpus}_{sset}.{fmt}.npz")
+    #print("PATH", parsed_data)
     if not os.path.isfile(parsed_data):
         print("you must parse the corpus before training it")
         sys.exit()
@@ -92,6 +94,9 @@
         indexed_sentences[i] = (i, sentence)
     return indexed_sentences
+
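+# wrap a raw token sequence with BERT's special tokens so that
+# SentenceBatch.tokens lines up position-for-position with the wordpiece ids
+# fed to the model (assumes toks_to_ids adds the same [CLS]/[SEP] markers)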
+def add_cls_sep(sentence):
+    return ['[CLS]'] + list(sentence) + ['[SEP]']
 
 def toks_to_ids(sentence):
     #print("sentence=", sentence)
 
@@ -115,12 +120,13 @@ def collate_batch(batch):
     sentence_ids, token_batch, label_batch = [i for i, (_, _) in batch], [j for _, (j, _) in batch], [k for _, (_, k) in batch]
     #mappings = [make_mapping(sentence) for sentence in token_batch]
     labels = [make_labels(sentence) for sentence in label_batch]
+    tokens = [add_cls_sep(sentence) for sentence in token_batch]
     tok_ids = [toks_to_ids(sentence) for sentence in token_batch]
     lengths = [len(toks) for toks in tok_ids]
     tok_types = [make_tok_types(l) for l in lengths]
     tok_masks = [make_tok_masks(l) for l in lengths]
 
-    return SentenceBatch(sentence_ids, tok_ids, tok_types, tok_masks, labels)
+    return SentenceBatch(sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels)
 
 def train(corpus, fmt):
     print(f'starting training of {corpus} in format {fmt}...')