Skip to content
Snippets Groups Projects
Commit 4eb7729e authored by Alice Pain's avatar Alice Pain
Browse files

allencustom

parent 78235faf
No related branches found
No related tags found
No related merge requests found
from typing import Dict, List, Sequence, Iterable
import itertools
import logging
import os
from overrides import overrides
from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.dataset_utils import to_bioul
from allennlp.data.fields import TextField, SequenceLabelField, Field, MetadataField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def _is_divider(line: str) -> bool:
empty_line = line.strip() == ''
if empty_line:
return True
else:
first_token = line.split()[0]
if first_token == "-DOCSTART-": # pylint: disable=simplifiable-if-statement
return True
else:
return False
@DatasetReader.register("custom_conll_reader")
class CustomConllDatasetReader(DatasetReader):
    """
    Reads instances from a pretokenised file where each line is in the following format:

    WORD POS-TAG CHUNK-TAG NER-TAG

    with a blank line indicating the end of each sentence
    and '-DOCSTART- -X- -X- O' indicating the end of each article,
    and converts it into a ``Dataset`` suitable for sequence tagging.

    Each ``Instance`` contains the words in the ``"tokens"`` ``TextField``.
    The values corresponding to the ``tag_label``
    values will get loaded into the ``"tags"`` ``SequenceLabelField``.
    And if you specify any ``feature_labels`` (you probably shouldn't),
    the corresponding values will get loaded into their own ``SequenceLabelField`` s.

    Each ``Instance`` also carries a ``"metadata"`` field holding the sentence words
    (``"words"``) and a two-letter language code (``"lang"``) taken from the first two
    characters of the data file name (see :func:`get_lang`).

    This dataset reader ignores the "article" divisions and simply treats
    each sentence as an independent ``Instance``. (Technically the reader splits sentences
    on any combination of blank lines and "DOCSTART" tags; in particular, it does the right
    thing on well formed inputs.)

    Parameters
    ----------
    token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
        We use this to define the input representation for the text.  See :class:`TokenIndexer`.
    tag_label: ``str``, optional (default=``ner``)
        Specify `ner`, `pos`, or `chunk` to have that tag loaded into the instance field `tags`.
    feature_labels: ``Sequence[str]``, optional (default=``()``)
        These labels will be loaded as features into the corresponding instance fields:
        ``pos`` -> ``pos_tags``, ``chunk`` -> ``chunk_tags``, ``ner`` -> ``ner_tags``
        Each will have its own namespace: ``pos_tags``, ``chunk_tags``, ``ner_tags``.
        If you want to use one of the tags as a `feature` in your model, it should be
        specified here.
    lazy : ``bool``, optional (default=``False``)
        Forwarded unchanged to the ``DatasetReader`` constructor.
    coding_scheme: ``str``, optional (default=``IOB1``)
        Specifies the coding scheme for ``ner_labels`` and ``chunk_labels``.
        Valid options are ``IOB1`` and ``BIOUL``.  The ``IOB1`` default maintains
        the original IOB1 scheme in the CoNLL 2003 NER data.
        In the IOB1 scheme, I is a token inside a span, O is a token outside
        a span and B is the beginning of span immediately following another
        span of the same type.
    label_namespace: ``str``, optional (default=``labels``)
        Specifies the namespace for the chosen ``tag_label``.
    """
    _VALID_LABELS = {'ner', 'pos', 'chunk'}
    # Language codes accepted by ``get_lang``.
    _SUPPORTED_LANGUAGES = ('en', 'de', 'it', 'fr', 'pt', 'sv')

    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 tag_label: str = "ner",
                 feature_labels: Sequence[str] = (),
                 lazy: bool = False,
                 coding_scheme: str = "IOB1",
                 label_namespace: str = "labels") -> None:
        super().__init__(lazy)
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
        if tag_label is not None and tag_label not in self._VALID_LABELS:
            raise ConfigurationError("unknown tag label type: {}".format(tag_label))
        for label in feature_labels:
            if label not in self._VALID_LABELS:
                raise ConfigurationError("unknown feature label type: {}".format(label))
        if coding_scheme not in ("IOB1", "BIOUL"):
            raise ConfigurationError("unknown coding_scheme: {}".format(coding_scheme))

        self.tag_label = tag_label
        self.feature_labels = set(feature_labels)
        self.coding_scheme = coding_scheme
        self.label_namespace = label_namespace
        self._original_coding_scheme = "IOB1"

    @overrides
    def _read(self, file_path: str) -> Iterable[Instance]:
        # Remember the path exactly as the caller gave it: the language code is read
        # from its file name, and a cached download is not guaranteed to keep that name.
        source_path = file_path
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)

        # NOTE(review): the file is opened with the platform default encoding;
        # confirm the encoding of the corpora before pinning one here.
        with open(file_path, "r") as data_file:
            logger.info("Reading instances from lines in file at: %s", file_path)

            # Group into alternative divider / sentence chunks.
            for is_divider, lines in itertools.groupby(data_file, _is_divider):
                # Ignore the divider chunks, so that `lines` corresponds to the words
                # of a single sentence.
                if not is_divider:
                    fields = [line.strip().split() for line in lines]
                    # unzipping trick returns tuples, but our Fields need lists
                    fields = [list(field) for field in zip(*fields)]
                    tokens_, pos_tags, chunk_tags, ner_tags = fields
                    # TextField requires ``Token`` objects
                    tokens = [Token(token) for token in tokens_]

                    yield self.text_to_instance(tokens, pos_tags, chunk_tags, ner_tags,
                                                source_path)

    def get_lang(self, file_path):
        """
        Return the two-letter language code encoded in the name of ``file_path``.

        The code is the first two characters of the file name; ``'po'`` is mapped
        to ``'pt'``.

        Raises
        ------
        ConfigurationError
            If ``file_path`` is ``None`` or the resulting code is not one of the
            supported languages.
        """
        if file_path is None:
            raise ConfigurationError("A file_path is required to determine the language "
                                     "of an instance.")
        _, file_name = os.path.split(file_path)
        lang = file_name[:2]
        if lang == 'po':
            lang = 'pt'
        if lang not in self._SUPPORTED_LANGUAGES:
            raise ConfigurationError(f"Language {lang} not supported by ELMo")
        return lang

    def text_to_instance(self,  # type: ignore
                         tokens: List[Token],
                         pos_tags: List[str] = None,
                         chunk_tags: List[str] = None,
                         ner_tags: List[str] = None,
                         file_path: str = None) -> Instance:
        """
        We take `pre-tokenized` input here, because we don't have a tokenizer in this class.

        ``file_path`` is only used to derive the ``"lang"`` metadata entry; see
        :func:`get_lang` for the errors it can raise.
        """
        # pylint: disable=arguments-differ
        sequence = TextField(tokens, self._token_indexers)
        instance_fields: Dict[str, Field] = {'tokens': sequence}
        instance_fields["metadata"] = MetadataField({"words": [x.text for x in tokens],
                                                     "lang": self.get_lang(file_path)})

        # Recode the labels if necessary.
        if self.coding_scheme == "BIOUL":
            coded_chunks = to_bioul(chunk_tags,
                                    encoding=self._original_coding_scheme) if chunk_tags is not None else None
            coded_ner = to_bioul(ner_tags,
                                 encoding=self._original_coding_scheme) if ner_tags is not None else None
        else:
            # the default IOB1
            coded_chunks = chunk_tags
            coded_ner = ner_tags

        # Add "feature labels" to instance
        if 'pos' in self.feature_labels:
            if pos_tags is None:
                raise ConfigurationError("Dataset reader was specified to use pos_tags as "
                                         "features. Pass them to text_to_instance.")
            instance_fields['pos_tags'] = SequenceLabelField(pos_tags, sequence, "pos_tags")
        if 'chunk' in self.feature_labels:
            if coded_chunks is None:
                raise ConfigurationError("Dataset reader was specified to use chunk tags as "
                                         "features. Pass them to text_to_instance.")
            instance_fields['chunk_tags'] = SequenceLabelField(coded_chunks, sequence, "chunk_tags")
        if 'ner' in self.feature_labels:
            if coded_ner is None:
                raise ConfigurationError("Dataset reader was specified to use NER tags as "
                                         " features. Pass them to text_to_instance.")
            instance_fields['ner_tags'] = SequenceLabelField(coded_ner, sequence, "ner_tags")

        # Add "tag label" to instance
        if self.tag_label == 'ner' and coded_ner is not None:
            instance_fields['tags'] = SequenceLabelField(coded_ner, sequence,
                                                         self.label_namespace)
        elif self.tag_label == 'pos' and pos_tags is not None:
            instance_fields['tags'] = SequenceLabelField(pos_tags, sequence,
                                                         self.label_namespace)
        elif self.tag_label == 'chunk' and coded_chunks is not None:
            instance_fields['tags'] = SequenceLabelField(coded_chunks, sequence,
                                                         self.label_namespace)

        return Instance(instance_fields)
from typing import Dict, Optional, List, Any
import numpy
from overrides import overrides
import torch
from torch.nn.modules.linear import Linear
import torch.nn.functional as F
from allennlp.common.checks import check_dimensions_match, ConfigurationError
from allennlp.data import Vocabulary
from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure
@Model.register("custom_simple_tagger")
class CustomSimpleTagger(Model):
    """
    This ``SimpleTagger`` simply encodes a sequence of text with a stacked ``Seq2SeqEncoder``, then
    predicts a tag for each token in the sequence.

    Unlike a plain tagger, the text field embedder is called with a ``lang`` keyword
    taken from the batch metadata, so ``forward`` requires ``metadata``.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
        It is called as ``text_field_embedder(tokens, lang=...)``.
    encoder : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and predicting output tags.
    calculate_span_f1 : ``bool``, optional (default=``None``)
        Calculate span-level F1 metrics during training. If this is ``True``, then
        ``label_encoding`` is required. If ``None`` and
        label_encoding is specified, this is set to ``True``.
        If ``None`` and label_encoding is not specified, it defaults
        to ``False``.
    label_encoding : ``str``, optional (default=``None``)
        Label encoding to use when calculating span f1.
        Valid options are "BIO", "BIOUL", "IOB1", "BMES".
        Required if ``calculate_span_f1`` is true.
    label_namespace : ``str``, optional (default=``labels``)
        This is needed to compute the SpanBasedF1Measure metric, if desired.
        Unless you did something unusual, the default value should be what you want.
    verbose_metrics : ``bool``, optional (default = False)
        If true, metrics will be returned per label class in addition
        to the overall statistics.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 calculate_span_f1: bool = None,
                 label_encoding: Optional[str] = None,
                 label_namespace: str = "labels",
                 verbose_metrics: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self.label_namespace = label_namespace
        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size(label_namespace)
        self.encoder = encoder
        self._verbose_metrics = verbose_metrics
        self.tag_projection_layer = TimeDistributed(Linear(self.encoder.get_output_dim(),
                                                           self.num_classes))

        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        # We keep calculate_span_f1 as a constructor argument for API consistency with
        # the CrfTagger, even it is redundant in this class
        # (label_encoding serves the same purpose).
        if calculate_span_f1 and not label_encoding:
            raise ConfigurationError("calculate_span_f1 is True, but "
                                     "no label_encoding was specified.")
        self.metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3)
        }

        if calculate_span_f1 or label_encoding:
            self._f1_metric = SpanBasedF1Measure(vocab,
                                                 tag_namespace=label_namespace,
                                                 label_encoding=label_encoding)
        else:
            self._f1_metric = None

        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``
            One dictionary per instance with the original words of the sentence under a
            ``'words'`` key and its language code under a ``'lang'`` key. Although the
            signature keeps a ``None`` default, the embedder needs the language, so a
            missing or empty ``metadata`` raises ``ConfigurationError``.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        words : List[List[str]]
            The ``'words'`` entry of every metadata dictionary.
        """
        if not metadata:
            raise ConfigurationError("metadata with a 'lang' entry is required to embed "
                                     "the tokens.")
        # Only the first instance is consulted, so the whole batch is embedded with
        # that instance's language.
        embedded_text_input = self.text_field_embedder(tokens, lang=metadata[0]['lang'])
        mask = get_text_field_mask(tokens)
        encoded_text = self.encoder(embedded_text_input, mask)

        logits = self.tag_projection_layer(encoded_text)
        # Softmax over the last axis normalises each token's class scores independently.
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "class_probabilities": class_probabilities}

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(logits, tags, mask)
            for metric in self.metrics.values():
                metric(logits, tags, mask.float())
            if self._f1_metric is not None:
                self._f1_metric(logits, tags, mask.float())
            output_dict["loss"] = loss

        output_dict["words"] = [x["words"] for x in metadata]
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple position-wise argmax over each token, converts indices to string labels, and
        adds a ``"tags"`` key to the dictionary with the result.
        """
        all_predictions = output_dict['class_probabilities']
        all_predictions = all_predictions.cpu().data.numpy()
        if all_predictions.ndim == 3:
            predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        for predictions in predictions_list:
            argmax_indices = numpy.argmax(predictions, axis=-1)
            # Look labels up in the namespace the model was built with, not a fixed one.
            tags = [self.vocab.get_token_from_index(x, namespace=self.label_namespace)
                    for x in argmax_indices]
            all_tags.append(tags)
        output_dict['tags'] = all_tags
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the accuracy metrics and, when span F1 is enabled, its statistics:
        every entry if ``verbose_metrics`` is set, otherwise only those whose name
        contains ``"overall"``.
        """
        metrics_to_return = {metric_name: metric.get_metric(reset) for
                             metric_name, metric in self.metrics.items()}

        if self._f1_metric is not None:
            f1_dict = self._f1_metric.get_metric(reset=reset)
            if self._verbose_metrics:
                metrics_to_return.update(f1_dict)
            else:
                metrics_to_return.update({
                        x: y for x, y in f1_dict.items() if
                        "overall" in x})
        return metrics_to_return
import os import os
import sys import sys
def to_mono(dico, keep): def write_mono(output_file, words):
with open(output_file, "w") as wf:
output_dir = os.path.join('dictionaries', 'monolingual') for word in words:
if not os.path.isdir(output_dir): wf.write(word + '\n')
os.mkdir(output_dir) wf.write('<UNK>\n')
input_file = os.path.join('dictionaries', dico)
beg = 3 if keep else 0
end = 5 if keep else 2
output_file = os.path.join(output_dir, (dico[beg:end] + "_" + dico))
words = set()
with open(input_file, "r") as rf: def to_mono(dico):
"""From a bilingual dictionary file sr-tg.txt, output two monolingual dictionary files sr_sr-tg.txt and tg_sr-tg.txt in the same directory"""
output_dir, dico_file = os.path.split(dico)
beg_src = 0
end_src = 2
beg_tgt = 3
end_tgt = 5
output_src = os.path.join(output_dir, (dico_file[beg_src:end_src] + "_" + dico_file))
output_tgt = os.path.join(output_dir, (dico_file[beg_tgt:end_tgt] + "_" + dico_file))
if os.path.isfile(output_src) and os.path.isfile(output_tgt):
print(f'Monolingual dictionaries {output_src} and {output_tgt} already exist.')
return output_src, output_tgt
words_src = set()
words_tgt = set()
with open(dico, "r") as rf:
for line in rf: for line in rf:
split = line.strip().split(" ") split = line.strip().split(" ")
if len(split) != 2: if len(split) != 2:
print('format error') print(f'Format error in {dico} file')
sys.exit() sys.exit()
words.add(split[keep]) words_src.add(split[0])
words_tgt.add(split[1])
write_mono(output_src, words_src)
write_mono(output_tgt, words_tgt)
print(f'Wrote monolingual dictionaries at {output_src} and {output_tgt}.')
with open(output_file, "w") as wf: return output_src, output_tgt
for word in words:
wf.write(word + '\n')
wf.write('<UNK>\n')
def main(): def main():
if len(sys.argv) < 3: if len(sys.argv) != 2:
print('usage: bil2mono.py <0/1> <dict.txt>') print('usage: bil2mono.py dict.txt')
sys.exit() sys.exit()
keep = sys.argv[1] dico = sys.argv[1]
if keep != '0' and keep != '1': to_mono(dico)
print('usage: bil2mono.py <0/1> <dict.txt>')
sys.exit()
keep = int(keep)
dicos = sys.argv[2:]
for dico in dicos:
to_mono(dico, keep)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
......
This diff is collapsed.
...@@ -13,21 +13,60 @@ import argparse ...@@ -13,21 +13,60 @@ import argparse
bert = 'bert-base-multilingual-cased' bert = 'bert-base-multilingual-cased'
def anchors(corpus, sset, fmt, vocab_file): def save_embs(anchors, vocab, output_file, output_dir):
#words not in the corpus are not included
dico_anchors = {}
for i, anchor in enumerate(anchors):
if len(np.nonzero(anchor)[0]) > 0:
token = vocab.get_token_from_index(i)
dico_anchors[token] = anchor
nb_embs = len(list(dico_anchors.keys()))
if not os.path.isdir(output_dir):
os.mkdir(output_dir)
#with open((output_file + '.txt'), "w") as wf:
# wf.write(str(nb_embs) + " " + '768\n')
# for token, anchor in dico_anchors.items():
# wf.write(token + ' ' + ' '.join([str(v) for v in anchor]) + '\n')
np.savez_compressed(output_file, anchors = np.array([dico_anchors]))
print(f"Anchors saved at {output_file}.")
return output_file
def anchors(corpus, sset, fmt, vocab_file, save_words=[]):
"""
corpus - corpus for which to compute anchors
sset - train/test/dev
fmt - conllu/tok
vocab_file - path to a text file containing one token per line, including <UNK> unknown token
save_words - a list of words for which to save contextual embeddings, for later visualization (optional)
---
Given a corpus and a set of tokens, compute average contextual embedding representation for each token"""
output_dir = 'saved_anchors'
source = f'{corpus}_{sset}.{fmt}'
output_file = os.path.join(output_dir, f'anchors_{source}.npz')
if os.path.isfile(output_file) and not save_words:
print(f'Anchors already computed at {output_file}.')
return output_file, {}, {}
print(f'Starting computation of anchors for corpus {source} from vocabulary file {vocab_file}.')
vocab = vocabulary.Vocabulary() vocab = vocabulary.Vocabulary()
vocab.set_from_file(vocab_file, oov_token='<UNK>') vocab.set_from_file(vocab_file, oov_token='<UNK>')
oov_id = vocab.get_token_index(vocab._oov_token) oov_id = vocab.get_token_index(vocab._oov_token)
anchors = torch.zeros((vocab.get_vocab_size(), 768)) anchors = torch.zeros((vocab.get_vocab_size(), 768), requires_grad=False)
vocab_size = vocab.get_vocab_size() vocab_size = vocab.get_vocab_size()
num_occurrences = [0] * vocab_size num_occurrences = [0] * vocab_size
print(f'Loaded vocabulary of size {vocab_size} from file {vocab_file}') print(f'Loaded vocabulary of size {vocab_size} from file {vocab_file}.')
if sset == 'all': if sset == 'all':
ssets = ['train', 'test', 'dev'] ssets = ['train', 'test', 'dev']
data = {} data = {}
for s in ssets: for s in ssets:
data[s] = generate_sentence_list(corpus, s, fmt) data[s] = generate_sentence_list(corpus, s, fmt)
data = data['train'] + data['test'] + data['dev'] data = data['train'] + data['test'] + data['dev'] #concatenate all sets
else: else:
data = generate_sentence_list(corpus, sset, fmt) data = generate_sentence_list(corpus, sset, fmt)
batch_size = 64 batch_size = 64
...@@ -35,22 +74,27 @@ def anchors(corpus, sset, fmt, vocab_file): ...@@ -35,22 +74,27 @@ def anchors(corpus, sset, fmt, vocab_file):
dataloader = DataLoader(data, batch_size=batch_size, collate_fn=collate_batch) dataloader = DataLoader(data, batch_size=batch_size, collate_fn=collate_batch)
bert_embeddings = BertModel.from_pretrained(bert) bert_embeddings = BertModel.from_pretrained(bert)
#words for which to collect contextual embeddings in order to plot point clouds #initialize dictionary to store embeddings for a few words (used for visualization afterwards)
clouds = ['le', 'qui'] saved_embs = {}
cloud_embeddings = {} saved_sentences = {}
for cloud in clouds: for word in save_words:
cloud_embeddings[cloud] = [] saved_embs[word] = []
saved_sentences[word] = []
#iterate through batches
for sentence_batch in tqdm(dataloader): for sentence_batch in tqdm(dataloader):
bert_output = bert_embeddings(**sentence_batch.getBatchEncoding()).last_hidden_state bert_output = bert_embeddings(**sentence_batch.getBatchEncoding()).last_hidden_state
#iterate through sentences
for i, sentence in enumerate(sentence_batch.tokens): for i, sentence in enumerate(sentence_batch.tokens):
bert_sentence_output = bert_output[i] bert_sentence_output = bert_output[i]
#iterate through tokens
for j, token in enumerate(sentence): for j, token in enumerate(sentence):
bert_token_output = bert_sentence_output[j] bert_token_output = bert_sentence_output[j]
if token in clouds: if token in save_words:
cloud_embeddings[token].append(bert_token_output) saved_embs[token].append(bert_token_output.detach().numpy()) #save contextual embedding of token
saved_sentences[token].append(sentence) #save sentence context of token
w_id = vocab.get_token_index(token) w_id = vocab.get_token_index(token)
if w_id != oov_id: if w_id != oov_id:
...@@ -58,105 +102,12 @@ def anchors(corpus, sset, fmt, vocab_file): ...@@ -58,105 +102,12 @@ def anchors(corpus, sset, fmt, vocab_file):
anchors[w_id, :] = anchors[w_id, :] * (n / (n + 1)) + bert_token_output[:] / (n + 1) anchors[w_id, :] = anchors[w_id, :] * (n / (n + 1)) + bert_token_output[:] / (n + 1)
num_occurrences[w_id] += 1 num_occurrences[w_id] += 1
print("done computing anchors.")
pca = PCA(n_components=2)
anchors = anchors.detach().numpy() anchors = anchors.detach().numpy()
#save_embs(anchors, vocab, corpus, sset, fmt) #pca = PCA(n_components=2)
#plot_anchors(anchors, pca, vocab, oov_id) #plot_anchors(anchors, pca, vocab, oov_id)
#plot_clouds(cloud_embeddings, anchors, pca, vocab) #plot_clouds(cloud_embeddings, anchors, pca, vocab)
map_anchors(cloud_embeddings) #, anchors, vocab) #map_anchors(cloud_embeddings) #, anchors, vocab)
return save_embs(anchors, vocab, output_file, output_dir), saved_embs, saved_sentences
def save_embs(anchors, vocab, corpus, sset, fmt):
    """Write every non-zero anchor to a text file, one token per line.

    The file is ``saved_anchors/anchors_{corpus}_{sset}.{fmt}.txt``. Its first line
    holds the number of stored vectors and their dimension; each following line
    holds a token followed by the space-separated components of its anchor.

    Parameters
    ----------
    anchors : sequence of 1-D arrays; row ``i`` belongs to vocabulary index ``i``.
    vocab : object exposing ``get_token_from_index(index)``.
    corpus, sset, fmt : strings used only to build the output file name.
    """
    output_dir = 'saved_anchors'
    output_file = os.path.join(output_dir, f'anchors_{corpus}_{sset}.{fmt}')
    text_path = output_file + '.txt'

    # words not in the corpus keep an all-zero row and are not included
    dico_anchors = {
        vocab.get_token_from_index(i): anchor
        for i, anchor in enumerate(anchors)
        if np.any(anchor)
    }

    # Take the dimension from the data; 768 is only the fallback used when
    # nothing is stored.
    dim = len(next(iter(dico_anchors.values()))) if dico_anchors else 768

    # exist_ok avoids failing when the directory appears between a check and mkdir.
    os.makedirs(output_dir, exist_ok=True)

    # Tokens may contain non-ASCII characters, so do not rely on the locale encoding.
    with open(text_path, "w", encoding="utf-8") as wf:
        wf.write(f'{len(dico_anchors)} {dim}\n')
        for token, anchor in dico_anchors.items():
            wf.write(token + ' ' + ' '.join(str(v) for v in anchor) + '\n')

    print(f"anchors saved at {text_path}")
def plot_anchors(anchors, pca, vocab, oov_id):
    """Plot a 2-D projection of the anchors of a fixed list of words.

    ``pca.fit_transform`` is applied to all ``anchors``; every listed word whose
    vocabulary index differs from ``oov_id`` is drawn as a labelled ``+``. The
    figure is written to ``anchors.png`` and then shown.

    Parameters
    ----------
    anchors : array passed to ``pca.fit_transform``; row ``i`` is assumed to be the
        anchor of vocabulary index ``i`` (TODO confirm against the caller).
    pca : object exposing ``fit_transform``.
    vocab : object exposing ``get_token_index(word)``.
    oov_id : index returned by ``vocab`` for unknown words.
    """
    reduced = pca.fit_transform(anchors)
    words = ['le', 'la', 'les', 'à', 'qui', 'du', 'de', 'juillet', 'août', 'juin', 'et', 'ou', 'pour']
    for word in words:
        index = vocab.get_token_index(word)
        if index != oov_id:
            plt.plot(reduced[index][0], reduced[index][1], "+", label=word)
    plt.title('A few word anchors (PCA)')
    plt.legend()
    # Save before showing: once an interactive window has been shown and closed,
    # the figure may no longer hold the drawing and the saved file can be blank.
    plt.savefig('anchors.png')
    plt.show()
    plt.clf()
def plot_clouds(cloud_embeddings, anchors, pca, vocab):
    """Plot the contextual embeddings of two tokens together with their anchors.

    All embeddings of both tokens and the two anchors are projected jointly with
    ``pca.fit_transform``. Embeddings are drawn as dots and anchors as ``+``, one
    colour per token. The figure is written to ``clouds.png`` and then shown.

    The process exits (``sys.exit``) when ``cloud_embeddings`` does not hold exactly
    two tokens or when either token has no embedding.

    Parameters
    ----------
    cloud_embeddings : dict mapping a token to a list of objects exposing
        ``.detach().numpy()`` (presumably torch tensors).
    anchors : indexable by vocabulary index.
    pca : object exposing ``fit_transform``.
    vocab : object exposing ``get_token_index(token)``.
    """
    if len(cloud_embeddings) != 2:
        print('to draw point clouds please enter exactly two tokens')
        sys.exit()

    tok0, tok1 = cloud_embeddings.keys()
    colors = ['b', 'm']

    embs0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok0]])
    embs1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok1]])
    nb0 = embs0.shape[0]
    nb1 = embs1.shape[0]
    if nb0 == 0 or nb1 == 0:
        print('one or more tokens have no occurrences in the corpus. unable to draw point cloud.')
        sys.exit()

    anchor0 = np.array([anchors[vocab.get_token_index(tok0)]])
    anchor1 = np.array([anchors[vocab.get_token_index(tok1)]])

    # Row layout after concatenation: tok0 embeddings, tok1 embeddings, then the two anchors.
    full_embs = np.concatenate((embs0, embs1, anchor0, anchor1), axis=0)
    full_embs_reduced = pca.fit_transform(full_embs)

    # One call per cloud instead of one per point; the '.' format draws markers only.
    cloud0 = full_embs_reduced[:nb0]
    cloud1 = full_embs_reduced[nb0:(nb0 + nb1)]
    plt.plot(cloud0[:, 0], cloud0[:, 1], '.', color=colors[0])
    plt.plot(cloud1[:, 0], cloud1[:, 1], '.', color=colors[1])
    plt.plot(full_embs_reduced[-2][0], full_embs_reduced[-2][1], '+', color=colors[0], label=tok0)
    plt.plot(full_embs_reduced[-1][0], full_embs_reduced[-1][1], '+', color=colors[1], label=tok1)
    plt.title('Contextual embeddings and corresponding anchors (PCA)')
    plt.legend()
    # Save before showing: once an interactive window has been shown and closed,
    # the figure may no longer hold the drawing and the saved file can be blank.
    plt.savefig('clouds.png')
    plt.show()
    plt.clf()
def map_anchors(cloud_embeddings):  # , anchors, vocab):
    """Draw up to ten contextual embeddings of one token as side-by-side heat maps.

    Only the first token of ``cloud_embeddings`` is used. Each embedding becomes a
    single-column image in its own subplot, titled with its rank. The figure is
    written to ``heatmap_anchors.png`` and then shown. Nothing is drawn when there
    is no token or the token has no embedding.

    Parameters
    ----------
    cloud_embeddings : dict mapping a token to a list of objects exposing
        ``.detach().numpy()`` (presumably torch tensors).
    """
    if not cloud_embeddings:
        print('no token given. unable to draw heat maps.')
        return

    tok = next(iter(cloud_embeddings))
    embs = cloud_embeddings[tok][:10]
    n = len(embs)
    if n == 0:
        print(f"no contextual embedding collected for '{tok}'. unable to draw heat maps.")
        return

    # squeeze=False keeps ``axs`` a 2-D array even for a single column, so ``.flat`` works.
    fig, axs = plt.subplots(nrows=1, ncols=n, squeeze=False,
                            subplot_kw={'xticks': [], 'yticks': []})
    for i, ax in enumerate(axs.flat):
        # One column per embedding, whatever its dimension.
        emb = embs[i].detach().numpy().reshape(-1, 1)
        ax.imshow(emb, cmap='Blues', aspect='auto')
        ax.set_title(f'({i+1})')
    # A figure-level title; plt.title would replace the title of the last subplot.
    fig.suptitle(f"A few contextual embeddings for the word '{tok}'")
    plt.tight_layout()
    # Save before showing: once an interactive window has been shown and closed,
    # the figure may no longer hold the drawing and the saved file can be blank.
    plt.savefig('heatmap_anchors.png')
    plt.show()
    plt.clf()
def main(): def main():
parser = argparse.ArgumentParser(description='Compute anchors for a given dataset and a given word list') parser = argparse.ArgumentParser(description='Compute anchors for a given dataset and a given word list')
......
...@@ -22,12 +22,12 @@ def parse(corpus, rich): ...@@ -22,12 +22,12 @@ def parse(corpus, rich):
output_file = os.path.join(output_dir, f'parsed_{corpus}') output_file = os.path.join(output_dir, f'parsed_{corpus}')
if not os.path.isdir(output_dir): if not os.path.isdir(output_dir):
os.mkdir(output_dir) os.mkdir(output_dir)
if os.path.isfile(output_file + '.npz'): #if os.path.isfile(output_file + '.npz'):
print(f'{corpus} already parsed. do you wish to overwrite? (Y/n)') # print(f'{corpus} already parsed. do you wish to overwrite? (Y/n)')
user = input() # user = input()
if not (user == "" or user == "Y" or user == "y"): # if not (user == "" or user == "Y" or user == "y"):
print('done') # print('done')
sys.exit() # sys.exit()
column_names = ['tok_id','tok','lemma','upos','xpos','gram','idhead','deprel','type','label'] column_names = ['tok_id','tok','lemma','upos','xpos','gram','idhead','deprel','type','label']
#if corpus.startswith('eng.rst.rstdt_train') or corpus.startswith('por.rst.cstn'): #if corpus.startswith('eng.rst.rstdt_train') or corpus.startswith('por.rst.cstn'):
......
...@@ -9,13 +9,18 @@ from train_model_baseline import generate_sentence_list, collate_batch ...@@ -9,13 +9,18 @@ from train_model_baseline import generate_sentence_list, collate_batch
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import argparse import argparse
from scipy.spatial.distance import cosine
from scipy.spatial.transform import Rotation
def mapping(src_corpus, tgt_corpus, mapping, sset, fmt): def mapping(src_corpus, tgt_corpus, mapping, sset, fmt):
data = {} data = {}
mapping = torch.load(mapping) mapping = torch.load(mapping)
mapp = torch.from_numpy(mapping) #TEST mapping_tr = mapping.transpose()
print("MAPPING", mapping.shape) rotation = Rotation.from_matrix(mapping)
rotation_inv = rotation.inv()
print("rotation_inv", rotation_inv.as_matrix().shape)
#mapping = - mapping
if sset == 'all': if sset == 'all':
ssets = ['train', 'test', 'dev'] ssets = ['train', 'test', 'dev']
...@@ -35,14 +40,27 @@ def mapping(src_corpus, tgt_corpus, mapping, sset, fmt): ...@@ -35,14 +40,27 @@ def mapping(src_corpus, tgt_corpus, mapping, sset, fmt):
bert_embeddings = BertModel.from_pretrained(bert) bert_embeddings = BertModel.from_pretrained(bert)
#words for which to collect contextual embeddings in order to plot point clouds #words for which to collect contextual embeddings in order to plot point clouds
clouds = ['the', 'who', 'le', 'qui'] clouds = ['est','que','is','that']
clouds_fr = clouds[2:] clouds_en = clouds[2:]
cloud_embeddings = {} cloud_embeddings = {}
aligned_embeddings = {} aligned_embeddings = {}
for cloud in clouds: for cloud in clouds:
cloud_embeddings[cloud] = [] cloud_embeddings[cloud] = []
for cloud in clouds_fr: for cloud in clouds_en:
aligned_embeddings[cloud] = [] aligned_embeddings[cloud] = []
words_fr = ['est','que','le','pour','mais','et']
words_en = ['is','that','the','for','but','and']
fr = {}
fr_al = {}
en = {}
en_al = {}
for word in words_fr:
fr[word] = []
fr_al[word] = []
for word in words_en:
en[word] = []
en_al[word] = []
#write this as function #write this as function
for sentence_batch in tqdm(dataloader): for sentence_batch in tqdm(dataloader):
...@@ -52,50 +70,101 @@ def mapping(src_corpus, tgt_corpus, mapping, sset, fmt): ...@@ -52,50 +70,101 @@ def mapping(src_corpus, tgt_corpus, mapping, sset, fmt):
bert_sentence_output = bert_output[i] bert_sentence_output = bert_output[i]
for j, token in enumerate(sentence): for j, token in enumerate(sentence):
bert_token_output = bert_sentence_output[j] bert_token_output = bert_sentence_output[j].detach().numpy()
if token in clouds: if token in clouds:
cloud_embeddings[token].append(bert_token_output) cloud_embeddings[token].append(bert_token_output)
if token in clouds_fr: if token in clouds_en:
aligned_emb = np.matmul(bert_token_output.detach().numpy(), mapping.transpose()) aligned_emb = np.matmul(bert_token_output, mapping_tr)
aligned_embeddings[token].append(aligned_emb) aligned_embeddings[token].append(aligned_emb)
if token in words_fr:
fr[token].append(bert_token_output)
aligned = np.matmul(bert_token_output, mapping_tr)
#aligned = np.matmul(mapping_tr, bert_token_output.transpose())
#aligned = np.matmul(mapping, bert_token_output.transpose())
fr_al[token].append(aligned)
if token in words_en:
en[token].append(bert_token_output)
aligned = np.matmul(bert_token_output, mapping_tr)
#aligned = np.matmul(mapping_tr, bert_token_output.transpose())
#aligned = np.matmul(mapping, bert_token_output.transpose())
en_al[token].append(aligned)
pca = PCA(n_components=2) pca = PCA(n_components=2)
plot_clouds(cloud_embeddings, pca, 'before') #plot_clouds(cloud_embeddings, pca, 'before')
for cloud in clouds[:2]: for cloud in clouds[:2]:
aligned_embeddings[cloud] = cloud_embeddings[cloud] #add unchanged target vectors aligned_embeddings[cloud] = cloud_embeddings[cloud] #add unchanged target vectors
plot_clouds(aligned_embeddings, pca, 'After') #plot_clouds(aligned_embeddings, pca, 'after')
analyze(words_fr, words_en, fr, en, fr_al, en_al)
def plot_clouds(cloud_embeddings, pca, text): def plot_clouds(cloud_embeddings, pca, text):
tok_en0, tok_en1, tok_fr0, tok_fr1 = cloud_embeddings.keys() tok_en0, tok_en1, tok_fr0, tok_fr1 = cloud_embeddings.keys()
print(f'b= {tok_en0}')
print(f'c= {tok_en1}')
print(f'm= {tok_fr0}')
print(f'r= {tok_fr1}')
colors = ['b', 'c', 'm', 'r'] colors = ['b', 'c', 'm', 'r']
embs_en0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_en0]]) embs_en0 = np.array([emb for emb in cloud_embeddings[tok_en0]])
embs_en1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_en1]]) embs_en1 = np.array([emb for emb in cloud_embeddings[tok_en1]])
embs_fr0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_fr0]]) embs_fr0 = np.array([emb for emb in cloud_embeddings[tok_fr0]])
embs_fr1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok_fr1]]) embs_fr1 = np.array([emb for emb in cloud_embeddings[tok_fr1]])
n_en0 = embs_en0.shape[0] n_en0 = embs_en0.shape[0]
n_en1 = embs_en1.shape[1] n_en1 = embs_en1.shape[0]
n_fr0 = embs_fr0.shape[0] n_fr0 = embs_fr0.shape[0]
n_fr1 = enbs_fr1.shape[1] n_fr1 = embs_fr1.shape[0]
full_embs = np.concatenate((embs_en0, embs_en1, embs_fr0, embs_fr1), axis=0) full_embs = np.concatenate((embs_en0, embs_en1, embs_fr0, embs_fr1), axis=0)
full_embs_reduced = pca.fit_transform(full_embs) embs_reduced = pca.fit_transform(full_embs)
for emb in full_embs_reduced[:n_en0]: print(n_en0, n_en1, n_fr0, n_fr1)
plt.plot(emb[0], emb[1], '.', color=colors[0]) transp = embs_reduced.transpose()
for emb in full_embs_reduced[n_en0:(n_en0+n_en1)]: red_en0 = transp[:,:n_en0]
plt.plot(emb[0], emb[1], '.', color=colors[1]) sep = n_en0 + n_en1
for emb in full_embs_reduced[(n_en0+n_en1):(n_en0+n_en1+n_fr0)]: red_en1 = transp[:,n_en0:sep]
plt.plot(emb[0], emb[1], '.', color=colors[2]) sep1 = sep + n_fr0
for emb in full_embs_reduced[(n_en0+n_en1+n_fr0):]: red_fr0 = transp[:,sep:sep1]
plt.plot(emb[0], emb[1], '.', color=colors[3]) red_fr1 = transp[:,sep1:]
plt.scatter(red_en0[0], red_en0[1], marker='.', color=colors[0], label = tok_en0)
plt.scatter(red_en1[0], red_en1[1], marker='.', color=colors[1], label = tok_en1)
plt.scatter(red_fr0[0], red_fr0[1], marker='.', color=colors[2], label = tok_fr0)
plt.scatter(red_fr1[0], red_fr1[1], marker='.', color=colors[3], label = tok_fr1)
plt.title(f'{text} alignment') plt.title(f'{text} alignment')
plt.legend() plt.legend()
plt.show() plt.show()
plt.savefig(f'{text}.png') plt.savefig(f'{text}.png')
plt.clf() plt.clf()
def analyze(words_fr, words_en, fr, en, fr_al, en_al):
    """Print how the distance between word centroids changes under the alignment.

    words_fr, words_en - iterables of French / English tokens to compare
    fr, en - dictionaries from tokens to the embeddings collected for them
    fr_al, en_al - same dictionaries, holding the aligned embeddings
    ---
    For every (French, English) pair the Euclidean distance between the mean
    embeddings is printed twice: first with only the English side aligned,
    then with only the French side aligned.
    """
    def _centroid(table, word):
        # Mean over all embeddings stored for this token.
        return np.array(table[word]).mean(axis=0)

    def _report(fr_after, en_after):
        # `fr_after` / `en_after` pick which side is taken from the aligned tables.
        for word_fr in words_fr:
            for word_en in words_en:
                gap_before = np.linalg.norm(_centroid(fr, word_fr) - _centroid(en, word_en))
                gap_after = np.linalg.norm(_centroid(fr_after, word_fr) - _centroid(en_after, word_en))
                print(f'{word_fr} -- {word_en}')
                print(f'Distance before = {gap_before}')
                print(f'Distance after = {gap_after}')
                print(f'Deplacement = {gap_before - gap_after}')

    print("EN aligned")
    _report(fr, en_al)
    print('\n========\nFR aligned')
    _report(fr_al, en)
def main(): def main():
parser = argparse.ArgumentParser(description='Compute anchors for a given dataset and a given word list') parser = argparse.ArgumentParser(description='Compute anchors for a given dataset and a given word list')
parser.add_argument('--src_corpus', required=True, help='corpus to align') parser.add_argument('--src_corpus', required=True, help='corpus to align')
......
import pickle
import torch
import numpy as np
import argparse
import sys
import os
from scipy.spatial.distance import cosine
from gen_anchors import anchors
from bil2mono import to_mono
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
#d = 768
#n = 100
#A = np.random.rand(d, n)
#B = np.random.rand(d, n)
def solve_procrustes(A, B):
    """Return W = argmin || WA - B || (Frobenius norm) where W is an orthogonal matrix.

    A, B - matrices with one column per point
    ---
    W is built from the singular vectors of B A^T (W = U V^H).
    """
    left_vectors, _, right_vectors_h = np.linalg.svd(np.matmul(B, A.transpose()))
    return np.matmul(left_vectors, right_vectors_h)
def compute_distance(src, tgt, src_word, tgt_word, W_t):
    """Return the cosine distance between a source and a target word embedding, before and after alignment.

    src, tgt - dictionaries from tokens to embeddings
    W_t - matrix the source embedding is right-multiplied by to align it
    ---
    Returns the pair (distance before alignment, distance after alignment).
    """
    source_vec, target_vec = src[src_word], tgt[tgt_word]
    rotated_vec = np.matmul(source_vec, W_t)
    return cosine(source_vec, target_vec), cosine(rotated_vec, target_vec)
def eval_alignment(aligning_pairs, src, tgt, W_t, show_examples):
    """Compute distance before and after alignment for word translation pairs and non-translation pairs.

    aligning_pairs - list of (source word, target word) translation pairs
    src, tgt - dictionaries from tokens to embeddings
    W_t - matrix the source embeddings are right-multiplied by
    show_examples - number of leading pairs whose distances are printed (0 prints none)
    ---
    Prints the mean reduction in cosine distance (before - after) for the
    translation pairs and for mismatched pairs built by pairing each source
    word with the target word of another pair. Returns nothing.
    """
    n_pairs = len(aligning_pairs)
    if n_pairs == 0:
        # Nothing to average over; the means below would divide by zero.
        print('No aligning pairs: nothing to evaluate.')
        return

    def _gains(pairs, header):
        # Distance reduction for each pair, printing the first `show_examples` of them.
        if show_examples:
            print(header)
        gains = []
        for i, (src_word, tgt_word) in enumerate(pairs):
            dist_before, dist_after = compute_distance(src, tgt, src_word, tgt_word, W_t)
            gains.append(dist_before - dist_after)
            if i < show_examples:
                print(f'{src_word} - {tgt_word}')
                print(f'Cosine distance before alignment: {dist_before}')
                print(f'Cosine distance after alignment: {dist_after}\n')
        return gains

    align_transl = _gains(aligning_pairs, '** Translation pairs **')

    if n_pairs > 1:
        # A shift that is a multiple of n_pairs would pair every source word
        # with its own translation, so fall back to 1 when n_pairs divides 10.
        shift = 10 if 10 % n_pairs != 0 else 1
        mismatched = [
            (src_word, aligning_pairs[(i + shift) % n_pairs][1])
            for i, (src_word, _) in enumerate(aligning_pairs)
        ]
    else:
        # A single pair cannot be mismatched with anything.
        mismatched = []
    align_non_transl = _gains(mismatched, '** Non-translation pairs **')

    print("Mean alignment for translation pairs", sum(align_transl) / n_pairs)
    mean_non_transl = sum(align_non_transl) / len(align_non_transl) if align_non_transl else 'n/a'
    print("Mean alignment for non-translation pairs", mean_non_transl)
def compute_mapping(src_anchors, tgt_anchors, src_corpus, tgt_corpus, voc, show_examples,
                    plot_examples=False, src_saved=None, tgt_saved=None):
    """
    src_anchors - path to .npz file storing a dictionary from tokens to anchors (source language)
    tgt_anchors - path to .npz file storing a dictionary from tokens to anchors (target language)
    src_corpus, tgt_corpus - corpus names, used to name the saved mapping
    voc - path to a bilingual dictionary, one 'src_word tgt_word' pair per line
    show_examples - number of word pairs whose distances are printed by the evaluation
    plot_examples, src_saved, tgt_saved - accepted for backward compatibility; not used here
    ---
    Compute the best orthogonal mapping to align the source vectors to the target vectors.
    The mapping is saved under alignments/ as a torch tensor and returned as a numpy array.
    Exits the program if the dictionary is malformed or no pair has anchors on both sides.
    """
    output_dir = 'alignments'
    output_file = os.path.join(output_dir, f'{src_corpus}_{tgt_corpus}.pth')
    os.makedirs(output_dir, exist_ok=True)
    if os.path.isfile(output_file):
        # Informational only: the mapping is recomputed and the file overwritten.
        print(f'Mapping already computed at {output_file}.')
    print(f'Starting computation of alignment matrix from {src_corpus} to {tgt_corpus}.')

    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # anchor files produced by this project.
    src = np.load(src_anchors, allow_pickle=True)
    src = src[src.files[0]][0]
    tgt = np.load(tgt_anchors, allow_pickle=True)
    tgt = tgt[tgt.files[0]][0]

    # collect word pairs for which we have anchors
    src_align = []
    tgt_align = []
    aligning_pairs = []
    with open(voc, "r") as f:
        for line in f:
            # a line has format 'src_word tgt_word'
            fields = line.strip().split()
            if len(fields) != 2:
                print(f'format error in {voc}')
                sys.exit()
            src_word, tgt_word = fields
            if src_word in src.keys() and tgt_word in tgt.keys():
                src_align.append(src[src_word])
                tgt_align.append(tgt[tgt_word])
                aligning_pairs.append((src_word, tgt_word))

    if not aligning_pairs:
        # Without any pair the matrices below are empty and the SVD cannot run.
        print(f'no word pair from {voc} has anchors in both languages')
        sys.exit()

    # One column per anchored word pair.
    A = np.array(src_align).transpose()
    B = np.array(tgt_align).transpose()
    W = solve_procrustes(A, B)
    torch.save(torch.from_numpy(W), output_file)
    print(f'Mapping saved at {output_file}.')
    W_t = W.transpose()
    eval_alignment(aligning_pairs, src, tgt, W_t, show_examples)
    return W
#if plot_examples:
# #plot_anchor_alignment(aligning_pairs, src, tgt, W_t)
# plot_alignment(src_saved, tgt_saved, W_t)
def _two_letter_code(lang):
    """Convert a 3-character language identifier to its 2-character form; other lengths pass through."""
    # Identifiers whose 2-character code is not simply their first two characters.
    irregular = {'spa': 'es', 'tur': 'tr', 'por': 'pt'}
    if lang in irregular:
        return irregular[lang]
    if len(lang) == 3:
        return lang[:2]
    return lang


def get_dico(src_lang, tgt_lang):
    """
    Convert DISRPT language identifier to 2-character language identifier
    spa -> es
    tur -> tr
    por -> pt
    others -> first two characters
    ---
    Return (path to dictionaries/src-tgt.txt, 2-character source language,
    2-character target language). Exits the program if the dictionary file
    does not exist.
    """
    src_lang = _two_letter_code(src_lang)
    tgt_lang = _two_letter_code(tgt_lang)
    voc_file = os.path.join('dictionaries', f'{src_lang}-{tgt_lang}.txt')
    if not os.path.isfile(voc_file):
        print(f'File {voc_file} not found: either wrong path or this language configuration is unavailable.')
        sys.exit()
    return voc_file, src_lang, tgt_lang
def main():
    """Two uses:
    - pass as argument pre-computed anchors (.npz file)
    - pass as argument corpuses for which to compute anchors first (anchors are not recomputed if they already were). By default anchors are computed on the concatenation of train, test and dev sets in conllu format ('all' format)."""
    parser = argparse.ArgumentParser(description='Compute best rotation from source embeddings to target embeddings')
    parser.add_argument('--src_embs', help='embeddings to align (.npz)')
    parser.add_argument('--tgt_embs', help='embeddings to align to (.npz)')
    parser.add_argument('--src_corpus', required=True, help='corpus for which to compute anchors to align')
    parser.add_argument('--tgt_corpus', required=True, help='corpus for which to compute anchors to align to')
    parser.add_argument('--src_set', default='all', help='train/test/dev/all')
    parser.add_argument('--tgt_set', default='all', help='train/test/dev/all')
    parser.add_argument('--src_format', default='conllu', help='tok/conllu')
    parser.add_argument('--tgt_format', default='conllu', help='tok/conllu')
    parser.add_argument('--src_lang', help='source language for bilingual dictionary (de/en/es/fa/fr/nl/pt/ru/tr/zh or deu/eng/fas/fra/nld/por/rus/spa/tur/zho), assuming its path is dictionaries/src-tgt.txt')
    parser.add_argument('--tgt_lang', help='target language for bilingual dictionary (de/en/es/fa/fr/nl/pt/ru/tr/zh or deu/eng/fas/fra/nld/por/rus/spa/tur/zho)')
    parser.add_argument('--voc', help='bilingual dictionary for alignment (text file with two tokens per line separated by a space)')
    parser.add_argument('--show_examples', type=int, default=3, help='number of word pairs for which to print distance')
    parser.add_argument('--plot_examples', action='store_true', default=False, help='Plot a few word pairs for visualization. Available languages: de, en, fr (to add a language, modify common_words dictionary).')
    params = parser.parse_args()

    common_words = {'de':['ist','ein'],'en':['is','a'],'fr':['est','un']}

    if params.voc:
        voc = params.voc
        # The languages are only needed to pick example words; they may be None here.
        src_lang = params.src_lang
        tgt_lang = params.tgt_lang
    else:
        if params.src_lang and params.tgt_lang:
            src_lang = params.src_lang
            tgt_lang = params.tgt_lang
        else:
            # Fall back to the prefix of the corpus name (text before the first '.').
            src_lang = params.src_corpus.split('.')[0]
            tgt_lang = params.tgt_corpus.split('.')[0]
        voc, src_lang, tgt_lang = get_dico(src_lang, tgt_lang)

    # Only filled in when anchors are computed below; pre-computed anchor
    # files come without saved contextual embeddings.
    src_saved = tgt_saved = None
    src_sents = tgt_sents = None
    if params.src_embs and params.tgt_embs:
        src_embs = params.src_embs
        tgt_embs = params.tgt_embs
    else:
        src_voc, tgt_voc = to_mono(voc)
        if params.plot_examples and src_lang in common_words.keys() and tgt_lang in common_words.keys():
            src_words = common_words[src_lang]
            tgt_words = common_words[tgt_lang]
        else:
            src_words = []
            tgt_words = []
        src_embs, src_saved, src_sents = anchors(params.src_corpus, params.src_set, params.src_format, src_voc, save_words=src_words)
        tgt_embs, tgt_saved, tgt_sents = anchors(params.tgt_corpus, params.tgt_set, params.tgt_format, tgt_voc, save_words=tgt_words)

    # All nine arguments are passed so the call matches compute_mapping's full parameter list.
    mapping = compute_mapping(src_embs, tgt_embs, params.src_corpus, params.tgt_corpus, voc,
                              params.show_examples, params.plot_examples, src_saved, tgt_saved)

    if params.plot_examples:
        if not src_saved or not tgt_saved:
            print('No saved contextual embeddings available: skipping plots.')
        else:
            # NOTE(review): plot_heatmap and plot_rotation are not among the
            # imports visible at the top of this file -- confirm they are imported.
            src_word = next(iter(src_saved.keys()))
            plot_heatmap(src_word, src_saved, params.src_corpus, src_sents)
            tgt_word = next(iter(tgt_saved.keys()))
            plot_heatmap(tgt_word, tgt_saved, params.tgt_corpus, tgt_sents)
            plot_rotation(src_saved, tgt_saved, mapping, params.src_corpus, params.tgt_corpus, plot_anchors=True)


if __name__ == '__main__':
    main()
...@@ -11,10 +11,16 @@ from train_model_rich import RichLSTM, generate_rich_sentence_list, collate_rich ...@@ -11,10 +11,16 @@ from train_model_rich import RichLSTM, generate_rich_sentence_list, collate_rich
bert = 'bert-base-multilingual-cased' bert = 'bert-base-multilingual-cased'
def test(model_path, model_type, corpus, test_set, fmt, show_errors): def test(model_path, model_type, corpus, test_set, fmt, alignment, show_errors):
model = torch.load(model_path) model = torch.load(model_path)
print(f'Model:\t{model_path}\nType:\t{model_type}\nEval:\t{corpus}_{test_set}\nFormat:\t{fmt}') print(f'Model:\t{model_path}\nType:\t{model_type}\nEval:\t{corpus}_{test_set}\nFormat:\t{fmt}')
if alignment:
alignment = torch.load(alignment)
model.loadAlignment(alignment)
else:
model.alignment = None
batch_size = 32 batch_size = 32
if model_type == 'baseline': if model_type == 'baseline':
data = generate_sentence_list(corpus, test_set, fmt) data = generate_sentence_list(corpus, test_set, fmt)
...@@ -57,9 +63,9 @@ def test(model_path, model_type, corpus, test_set, fmt, show_errors): ...@@ -57,9 +63,9 @@ def test(model_path, model_type, corpus, test_set, fmt, show_errors):
total_acc += sum_score total_acc += sum_score
total_loss += loss.item() #*label_batch.size(0) total_loss += loss.item() #*label_batch.size(0)
precision = tp / (tp + fp) precision = tp / (tp + fp) if (tp + fp != 0) else 'n/a'
recall = tp / (tp + fn) recall = tp / (tp + fn) if (tp + fn != 0) else 'n/a'
f1 = 2 * (precision * recall) / (precision + recall) f1 = 2 * (precision * recall) / (precision + recall) if (precision != 'n/a' and recall != 'n/a') else 'n/a'
if show_errors > 0: if show_errors > 0:
print_errors(errors, data, max_print=show_errors) print_errors(errors, data, max_print=show_errors)
...@@ -94,6 +100,7 @@ def main(): ...@@ -94,6 +100,7 @@ def main():
parser.add_argument('--corpus', help='corpus to test on') parser.add_argument('--corpus', help='corpus to test on')
parser.add_argument('--set', default='test', help='portion of the corpus to test on') parser.add_argument('--set', default='test', help='portion of the corpus to test on')
parser.add_argument('--errors', type=int, default=0, help='number of prediction errors to display on standard output') parser.add_argument('--errors', type=int, default=0, help='number of prediction errors to display on standard output')
parser.add_argument('--alignment', help='alignment matrix')
params = parser.parse_args() params = parser.parse_args()
...@@ -103,7 +110,7 @@ def main(): ...@@ -103,7 +110,7 @@ def main():
if params.type == 'rich' and params.format == 'tok': if params.type == 'rich' and params.format == 'tok':
print('a rich model requires a .conllu file') print('a rich model requires a .conllu file')
sys.exit() sys.exit()
test(params.model, params.type, params.corpus, params.set, params.format, params.errors) test(params.model, params.type, params.corpus, params.set, params.format, params.alignment, params.errors)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -24,10 +24,13 @@ class LSTM(nn.Module): ...@@ -24,10 +24,13 @@ class LSTM(nn.Module):
d = 2 if bidirectional else 1 d = 2 if bidirectional else 1
self.hiddenToLabel = nn.Linear(d * hidden_size, 1) self.hiddenToLabel = nn.Linear(d * hidden_size, 1)
self.act = nn.Sigmoid() self.act = nn.Sigmoid()
self.alignment = None
def forward(self, batch): def forward(self, batch):
output = self.bert_embeddings(**batch) output = self.bert_embeddings(**batch)
output768 = output.last_hidden_state output768 = output.last_hidden_state
if self.alignment is not None:
output768 = torch.matmul(output768, self.alignment)
output64, (last_hidden_state, last_cell_state) = self.lstm(output768) output64, (last_hidden_state, last_cell_state) = self.lstm(output768)
#output64, self.hidden = self.lstm(output768, self.hidden) #output64, self.hidden = self.lstm(output768, self.hidden)
#print("output64=", output64.shape) #print("output64=", output64.shape)
...@@ -37,7 +40,9 @@ class LSTM(nn.Module): ...@@ -37,7 +40,9 @@ class LSTM(nn.Module):
#print("output1=", output1.shape) #print("output1=", output1.shape)
return self.act(output1[:,:,0]) return self.act(output1[:,:,0])
#lstm_out, self.hidden = self.lstm(output, self.hidden) def loadAlignment(self, alignment):
#alignment = torch.tensor(alignment)
self.alignment = torch.transpose(alignment, 0, 1)
class SentenceBatch(): class SentenceBatch():
def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = None, deprels = None, dheads = None): def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, labels, uposes = None, deprels = None, dheads = None):
...@@ -193,8 +198,10 @@ def train(corpus, fmt): ...@@ -193,8 +198,10 @@ def train(corpus, fmt):
total_acc += sum_score total_acc += sum_score
total_loss += loss.item() #*label_batch.size(0) total_loss += loss.item() #*label_batch.size(0)
f1 = tp / (tp + (fp + fn) / 2) precision = tp / (tp + fp) if (tp + fp != 0) else 'n/a'
print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}") recall = tp / (tp + fn) if (tp + fn != 0) else 'n/a'
f1 = 2 * (precision * recall) / (precision + recall) if (precision != 'n/a' and recall != 'n/a') else 'n/a'
print(f"Epoch {epoch}\nAcc\t{total_acc/l}\nLoss\t{total_loss/l}\nP\t{precision}\nR\t{recall}\nF1\t{f1}\n")
print('done training') print('done training')
output_file = save_model(model, 'baseline', corpus, params) output_file = save_model(model, 'baseline', corpus, params)
...@@ -223,7 +230,7 @@ def main(): ...@@ -223,7 +230,7 @@ def main():
parser = argparse.ArgumentParser(description='Train baseline model') parser = argparse.ArgumentParser(description='Train baseline model')
parser.add_argument('--corpus', required=True, help='corpus to train') parser.add_argument('--corpus', required=True, help='corpus to train')
parser.add_argument('--format', default='conllu', help='tok or conllu') parser.add_argument('--format', default='conllu', help='tok or conllu')
params = parse.parse_args() params = parser.parse_args()
train(params.corpus, params.format) train(params.corpus, params.format)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -39,6 +39,8 @@ class RichLSTM(nn.Module): ...@@ -39,6 +39,8 @@ class RichLSTM(nn.Module):
self.bert_upos = self.bert_emb_dim + self.upos_emb_dim self.bert_upos = self.bert_emb_dim + self.upos_emb_dim
self.bert_upos_deprel = self.bert_emb_dim + self.upos_emb_dim + self.deprel_emb_dim self.bert_upos_deprel = self.bert_emb_dim + self.upos_emb_dim + self.deprel_emb_dim
self.alignment = None
def forward(self, tok_batch, upos_batch, deprel_batch, dhead_batch): def forward(self, tok_batch, upos_batch, deprel_batch, dhead_batch):
# batch: [B x L], where # batch: [B x L], where
# B = batch_size # B = batch_size
...@@ -46,6 +48,8 @@ class RichLSTM(nn.Module): ...@@ -46,6 +48,8 @@ class RichLSTM(nn.Module):
bert_output = self.bert_embeddings(**tok_batch) bert_output = self.bert_embeddings(**tok_batch)
bert_output = bert_output.last_hidden_state bert_output = bert_output.last_hidden_state
if self.alignment is not None:
bert_output = torch.matmul(bert_output, self.alignment)
# bert_output: [B x L x 768] # bert_output: [B x L x 768]
upos_output = self.upos_embeddings(upos_batch) upos_output = self.upos_embeddings(upos_batch)
# upos_output: [B x L x U] # upos_output: [B x L x U]
...@@ -66,6 +70,9 @@ class RichLSTM(nn.Module): ...@@ -66,6 +70,9 @@ class RichLSTM(nn.Module):
#lstm_out, self.hidden = self.lstm(output, self.hidden) #lstm_out, self.hidden = self.lstm(output, self.hidden)
def loadAlignment(self, alignment):
    """Store the transpose of the 2-D tensor `alignment` as ``self.alignment``."""
    swapped = torch.transpose(alignment, dim0=0, dim1=1)
    self.alignment = swapped
#class RichSentenceBatch(): #class RichSentenceBatch():
# #
# def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, upos, deprels, dheads, labels): # def __init__(self, sentence_ids, tokens, tok_ids, tok_types, tok_masks, upos, deprels, dheads, labels):
...@@ -215,9 +222,11 @@ def train(corpus): ...@@ -215,9 +222,11 @@ def train(corpus):
total_acc += sum_score total_acc += sum_score
total_loss += loss.item() #*label_batch.size(0) total_loss += loss.item() #*label_batch.size(0)
f1 = tp / (tp + (fp + fn) / 2) precision = tp / (tp + fp) if (tp + fp != 0) else 'n/a'
print(f"Epoch {epoch} Accuracy {total_acc/l} Loss {total_loss/l} F1 {f1}") recall = tp / (tp + fn) if (tp + fn != 0) else 'n/a'
f1 = 2 * (precision * recall) / (precision + recall) if (precision != 'n/a' and recall != 'n/a') else 'n/a'
print(f"Epoch {epoch}\nAcc\t{total_acc/l}\nLoss\t{total_loss/l}\nP\t{precision}\nR\t{recall}\nF1\t{f1}\n")
print('done training') print('done training')
output_file = save_model(model, 'rich', corpus, params) output_file = save_model(model, 'rich', corpus, params)
print(f'model saved at {output_file}') print(f'model saved at {output_file}')
......
import os
import torch
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import pickle
# Matplotlib single-letter colour codes, cycled through when plotting token clouds.
colors = [
    'c', 'm', 'g', 'r',
    'b', 'y', 'k', 'w',
]
n_colors = len(colors)

# Directory the figures are written to; created at import time if missing.
output_dir = 'images'
os.makedirs(output_dir, exist_ok=True)
def plot_emb_matrix(M, pca, tsne, n_pairs, src_words, tgt_words, align):
    """
    M - embedding matrix: rows 0..n_pairs-1 are source words, the next n_pairs rows their target words
    pca, tsne - reducers applied in that order (both are fit on M here)
    n_pairs - number of word pairs to plot
    src_words, tgt_words - labels for the source / target points
    align - caption, used in the title and in the image file name
    ---
    Plot each word pair as two points in 2D and save the figure under images/.
    """
    M_pca = pca.fit_transform(M)
    M_tsne = tsne.fit_transform(M_pca)
    print("M_tsne", M_tsne.shape)
    for i in range(n_pairs):
        plt.plot(M_tsne[i][0], M_tsne[i][1], '.', color='m', label=src_words[i])
        plt.plot(M_tsne[i+n_pairs][0], M_tsne[i+n_pairs][1], '.', color='b', label=tgt_words[i])
    plt.title(f'{align} alignment')
    plt.legend()
    # Save before showing: once an interactive window is closed the figure
    # can be empty, which would write a blank image.
    plt.savefig(f'images/n{align}.png')
    plt.show()
    plt.clf()
def plot_anchor_alignment(aligning_pairs, src, tgt, W_t):
    """
    aligning_pairs - list of (source word, target word) pairs
    src, tgt - dictionaries from tokens to anchors
    W_t - matrix the source anchors are right-multiplied by
    ---
    Plot the anchors of the first four word pairs, before and after
    applying the alignment to the source anchors.
    """
    n_pairs = 4
    if len(aligning_pairs) < n_pairs:
        print("not enough pairs to plot")
        return
    n_words = n_pairs * 2
    # Take the embedding width from the data instead of assuming 768.
    emb_dim = np.asarray(src[aligning_pairs[0][0]]).shape[-1]
    before = np.zeros((n_words, emb_dim))
    after = np.zeros((n_words, emb_dim))
    src_words = [0] * n_pairs
    tgt_words = [0] * n_pairs
    for i, (src_word, tgt_word) in enumerate(aligning_pairs[:n_pairs]):
        src_emb = src[src_word]
        tgt_emb = tgt[tgt_word]
        # Source rows first, matching target rows n_pairs further down.
        before[i] = src_emb
        before[i+n_pairs] = tgt_emb
        after[i] = np.matmul(src_emb, W_t)
        after[i+n_pairs] = tgt_emb
        src_words[i] = src_word
        tgt_words[i] = tgt_word
    pca = PCA(n_components=n_words)
    # t-SNE needs perplexity < number of samples; the default (30) is too
    # large for the n_words points plotted here.
    tsne = TSNE(n_components=2, perplexity=n_words - 1)
    plot_emb_matrix(before, pca, tsne, n_pairs, src_words, tgt_words, 'Before')
    plot_emb_matrix(after, pca, tsne, n_pairs, src_words, tgt_words, 'After')
def plot_alignment(src, tgt, W_t):
    """
    src, tgt - dictionaries from tokens to a list of embeddings (source / target language)
    W_t - matrix the source embeddings are right-multiplied by
    ---
    Reduce all embeddings with PCA then t-SNE and save one scatter plot before
    and one after applying the alignment to the source embeddings, with one
    colour per token. Tokens are taken pairwise in dictionary order; extra
    tokens on the longer side are ignored.
    """
    # Keep only the tokens that actually get paired, so that labels,
    # boundaries and clouds stay in step when the sizes differ.
    n_pairs = min(len(src), len(tgt))
    src_words = list(src.keys())[:n_pairs]
    tgt_words = list(tgt.keys())[:n_pairs]
    before_src = []
    before_tgt = []
    # Running end index of each token's cloud within its own language.
    src_boundaries = []
    tgt_boundaries = []
    for src_word, tgt_word in zip(src_words, tgt_words):
        before_src.extend(src[src_word])
        before_tgt.extend(tgt[tgt_word])
        src_boundaries.append(len(before_src))
        tgt_boundaries.append(len(before_tgt))
    full_before = before_src + before_tgt
    after_src = np.matmul(np.array(before_src), W_t)
    after_tgt = np.array(before_tgt)
    src_tgt_boundary = len(before_src)
    before_embs = np.array(full_before)
    after_embs = np.concatenate((after_src, after_tgt))
    # Target boundaries are shifted past the source block of the stacked matrix.
    boundaries = src_boundaries + [end + src_tgt_boundary for end in tgt_boundaries]
    print("boundaries", boundaries)
    print("len full", len(full_before))
    n_comp = min(len(full_before), 50)
    pca = PCA(n_components=n_comp)
    tsne = TSNE(n_components=2)
    before_embs_pca = pca.fit_transform(before_embs)
    before_embs_tsne = tsne.fit_transform(before_embs_pca)
    after_embs_pca = pca.fit_transform(after_embs)
    after_embs_tsne = tsne.fit_transform(after_embs_pca)
    palette = ['b','c','m','r']

    def plot(reduced, caption):
        for i, word in enumerate(src_words + tgt_words):
            beg = boundaries[i-1] if i >= 1 else 0
            end = boundaries[i]
            cloud = reduced[beg:end]
            cloud = cloud.transpose()
            print("cloud", cloud.shape)
            # Wrap around so more tokens than colours cannot overflow the palette.
            plt.scatter(cloud[0], cloud[1], marker='.', color=palette[i % len(palette)], label=word)
        plt.legend()
        plt.savefig(f'images/{caption}.png')
        plt.clf()

    plot(before_embs_tsne, 'beforee')
    plot(after_embs_tsne, 'afteree')
def plot_anchors(anchors, pca, vocab, oov_id):
    """
    anchors - matrix of anchors, one row per vocabulary index
    pca - reducer fit on the anchors here
    vocab - object whose get_token_index(word) gives the row of a word
    oov_id - index returned for unknown words; such words are skipped
    ---
    Plot the anchors of a fixed list of French words and save the figure as anchors.png.
    """
    reduced = pca.fit_transform(anchors)
    words = ['le','la','les','à','qui','du','de','juillet','août','juin','et','ou','pour']
    for word in words:
        index = vocab.get_token_index(word)
        if index != oov_id: # and np.sum(anchors[index] != 0):
            plt.plot(reduced[index][0], reduced[index][1], "+", label = word)
    plt.title('A few word anchors (PCA)')
    plt.legend()
    # Save before showing: once an interactive window is closed the figure
    # can be empty, which would write a blank image.
    plt.savefig('anchors.png')
    plt.show()
    plt.clf()
def plot_clouds(cloud_embeddings, anchors, pca, vocab):
    """
    cloud_embeddings - a dictionary with exactly two tokens, each mapped to a list of tensors
    anchors - matrix of anchors, indexed by vocab.get_token_index(token)
    pca - reducer fit here on the embeddings and the two anchors together
    vocab - object whose get_token_index(token) gives the anchor row of a token
    ---
    Plot the contextual embeddings of two tokens as dots and their anchors as
    crosses, then save the figure as clouds.png. Exits the program if the
    number of tokens is not two or a token has no embedding.

    NOTE(review): a later definition in this file reuses the name plot_clouds
    and appears to replace this one at import time -- confirm which is intended.
    """
    # sys is not among this module's top-level imports.
    import sys

    if len(cloud_embeddings.keys()) != 2:
        print('to draw point clouds please enter exactly two tokens')
        sys.exit()
    tok0, tok1 = cloud_embeddings.keys()
    colors = ['b', 'm']
    embs0 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok0]])
    embs1 = np.array([emb.detach().numpy() for emb in cloud_embeddings[tok1]])
    nb0 = embs0.shape[0]
    nb1 = embs1.shape[0]
    if nb0 == 0 or nb1 == 0:
        print('one or more tokens have no occurrences in the corpus. unable to draw point cloud.')
        sys.exit()
    anchor0 = np.array([anchors[vocab.get_token_index(tok0)]])
    anchor1 = np.array([anchors[vocab.get_token_index(tok1)]])
    # The two anchors are appended last, so they end up in the final two rows.
    full_embs = np.concatenate((embs0, embs1, anchor0, anchor1), axis=0)
    full_embs_reduced = pca.fit_transform(full_embs)
    for emb in full_embs_reduced[:nb0]:
        plt.plot(emb[0], emb[1], '.', color=colors[0])
    for emb in full_embs_reduced[nb0:(nb0+nb1)]:
        plt.plot(emb[0], emb[1], '.', color=colors[1])
    plt.plot(full_embs_reduced[-2][0], full_embs_reduced[-2][1], '+', color=colors[0], label = tok0)
    plt.plot(full_embs_reduced[-1][0], full_embs_reduced[-1][1], '+', color=colors[1], label = tok1)
    plt.title('Contextual embeddings and corresponding anchors (PCA)')
    plt.legend()
    # Save before showing: once an interactive window is closed the figure
    # can be empty, which would write a blank image.
    plt.savefig('clouds.png')
    plt.show()
    plt.clf()
def plot_heatmap(word, embeddings, corpus, sentences):
    """
    word - a word for which to plot heatmap
    embeddings - a dictionary from tokens to a list of contextual embeddings
    sentences - a dictionary from tokens to the sentences they appeared in
    corpus - a string for naming the image file
    ---
    Plot the first n_dim dimensions of a few embeddings for a word as a heatmap.
    The image and a text file listing the matching sentences are written to
    the output directory. Nothing is written if the word has no embedding.
    """
    embs = embeddings[word]
    sents = sentences[word]
    n_dim = 100
    n = 5
    if len(embs) > n:
        embs = embs[:n]
        sents = sents[:n]
    else:
        n = len(embs)
    if n == 0:
        # No panel can be drawn and there would be no image for the colorbar.
        print(f'no embeddings saved for {word}: unable to draw heatmap.')
        return
    # squeeze=False keeps `axs` an array even when there is a single panel.
    fig, axs = plt.subplots(nrows=1, ncols=n, squeeze=False, subplot_kw={'xticks': [], 'yticks': []})
    for i, ax in enumerate(axs.flat):
        # Column vector of whatever width the embedding has, cut to n_dim rows.
        emb = embs[i].reshape(-1, 1)
        emb = emb[:n_dim,:]
        im = ax.imshow(emb, cmap='Blues', aspect='auto')
        ax.set_title(f'({i+1})')
    # Leave room on the right for a colorbar (scaled by the last panel drawn).
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    fig.colorbar(im, cax=cbar_ax)
    file_name = f'heatmap_{corpus}_{word}'
    # Save before showing: once an interactive window is closed the figure
    # can be empty, which would write a blank image.
    plt.savefig(os.path.join(output_dir, file_name + '.png'))
    plt.show()
    plt.clf()
    with open(os.path.join(output_dir, file_name + '.txt'), "w") as f:
        f.write(f'>> {word} <<\n')
        for i, sent in enumerate(sents):
            # Drops the first and last token of each sentence
            # (presumably special tokens -- TODO confirm).
            sent = ' '.join(sent[1:-1])
            f.write(f'({i+1}). {sent}\n')
def plot_clouds(src_embeddings, tgt_embeddings=None, pca=None, plot_anchors=False, caption=None):
    """
    src_embeddings - a dictionary from tokens to a list of contextual embeddings
    tgt_embeddings - a second dictionary (optional)
    pca - already fit pca (optional)
    plot_anchors - whether to plot the mean of each token cloud (default: False)
    caption - saved image name
    ---
    Plot embeddings as points in a 2D space, using a color for each token.
    Colors are taken from the module-level palette and wrap around when
    there are more tokens than colors.
    """
    #if no pca was given: first concatenate all embeddings as a single matrix, then fit PCA
    if pca is None:
        full_embs = []
        for embs in src_embeddings.values():
            full_embs += embs
        if tgt_embeddings:
            for embs in tgt_embeddings.values():
                full_embs += embs
        pca = PCA(n_components=2).fit(np.array(full_embs))

    def _draw(embeddings, color_offset):
        # Scatter one cloud per token; `color_offset` shifts the palette so
        # that target tokens continue after the source tokens' colors.
        for i, (word, embs) in enumerate(embeddings.items()):
            embs = np.array(embs)
            reduced = pca.transform(embs).transpose()
            plt.scatter(reduced[0], reduced[1], color=colors[(i + color_offset) % n_colors], label=word, s=1)
            if plot_anchors:
                anchor = np.array([np.mean(embs, axis=0)])
                reduced_anchor = pca.transform(anchor)[0]
                plt.plot(reduced_anchor[0], reduced_anchor[1], marker='X', color='k', ms=10)

    #go through embedding clouds, apply pca reduction, then plot
    _draw(src_embeddings, 0)
    if tgt_embeddings:
        _draw(tgt_embeddings, len(src_embeddings))
    plt.legend()
    output_file = f'contextual_embeddings_{caption}.png' if caption else 'contextual_embeddings.png'
    plt.savefig(os.path.join(output_dir, output_file))
    plt.show()
    plt.clf()
def plot_rotation(src_embeddings, tgt_embeddings, W, src_corpus, tgt_corpus, plot_anchors=False):
    """
    src_embeddings - a dictionary from tokens to a list of contextual embeddings (source language)
    tgt_embeddings - a dictionary from tokens to a list of contextual embeddings (target language)
    W - a rotation to apply to source embeddings
    src_corpus, tgt_corpus - strings for naming the image file
    plot_anchors - whether to plot the mean of each token cloud (default: False)
    ---
    Draw two 2D figures of the token clouds: one with the original source
    embeddings and one with the rotated ones. A single PCA, fit on original,
    target and rotated embeddings together, is shared by both figures.
    """
    W_t = W.transpose()

    # Rotate every source cloud, keeping one list of vectors per token.
    src_embeddings_aligned = {}
    for word, embs in src_embeddings.items():
        rotated = list(np.matmul(np.array(embs), W_t))
        src_embeddings_aligned[word] = rotated
        print(len(rotated))

    #fit pca
    stacked = []
    for table in (src_embeddings, tgt_embeddings, src_embeddings_aligned):
        for embs in table.values():
            stacked += embs
    pca = PCA(n_components=2).fit(np.array(stacked))

    #plot before alignment, then after alignment
    plot_clouds(src_embeddings, tgt_embeddings, pca=pca, plot_anchors=plot_anchors,
                caption=f'before_{src_corpus}_{tgt_corpus}')
    plot_clouds(src_embeddings_aligned, tgt_embeddings, pca=pca, plot_anchors=plot_anchors,
                caption=f'after_{src_corpus}_{tgt_corpus}')
#with open ('src_saved', 'rb') as f:
# clouds = pickle.load(f)
##with open('src_sents', 'rb') as f:
## sents = pickle.load(f)
##plot_heatmap(list(clouds.keys())[0], clouds, sents)
#
#with open('tgt_saved', 'rb') as f:
# clouds_t = pickle.load(f)
#
#W = torch.load('alignments/fra.sdrt.annodis_eng.rst.gum.pth').detach().numpy()
#plot_rotation(clouds, clouds_t, W, plot_anchors=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment