diff --git a/.gitignore b/.gitignore index 8142d3669783655a35b7135fb4d7ef1db56a4ff2..c722019d59c7d8a5ffc1c5644442a7112c14d6e6 100644 --- a/.gitignore +++ b/.gitignore @@ -176,6 +176,10 @@ models/CamemBERT-base/* !models/CamemBERT-base/.gitkeep models/Fine-tuned_CamemBERT-base/* !models/Fine-tuned_CamemBERT-base/.gitkeep +models/LegalCamemBERT-base/* +!models/LegalCamemBERT-base/.gitkeep +models/Fine-tuned_LegalCamemBERT-base/* +!models/Fine-tuned_LegalCamemBERT-base/.gitkeep modules/llm/Mixtral-8x7b/results/* modules/llm/Mixtral-8x7b/.lock_preprocessing diff --git a/models/download_model.py b/models/download_model.py index 26ac06316fb64ad545c458af981f98a62b221495..72530a12e91a5b23121fc9b4ea5ffb0ec4e7c14f 100644 --- a/models/download_model.py +++ b/models/download_model.py @@ -3,4 +3,5 @@ from huggingface_hub import snapshot_download #snapshot_download(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", cache_dir="../temp", local_dir="./Mixtral-8x7B-Instruct-v0.1") #snapshot_download(repo_id="mistralai/Mistral-7B-Instruct-v0.2", cache_dir="../temp", local_dir="./Mistral-7B-Instruct-v0.2") #snapshot_download(repo_id="152334H/miqu-1-70b-sf", cache_dir="../temp", local_dir="./Miqu-1-70b") -snapshot_download(repo_id="camembert/camembert-base", cache_dir="../temp", local_dir="./CamemBERT-base") \ No newline at end of file +#snapshot_download(repo_id="camembert/camembert-base", cache_dir="../temp", local_dir="./CamemBERT-base") +snapshot_download(repo_id="maastrichtlawtech/legal-camembert-base", cache_dir="../temp", local_dir="./LegalCamemBERT-base") \ No newline at end of file diff --git a/modules/camembert/finetuned-legal-camembert.py b/modules/bert/finetuned-legal-camembert.py similarity index 100% rename from modules/camembert/finetuned-legal-camembert.py rename to modules/bert/finetuned-legal-camembert.py diff --git a/modules/bert/finetuning-legal-bert-classifier.ipynb b/modules/bert/finetuning-legal-bert-classifier.ipynb new file mode 100644 index 
# Span categories annotated in the corpus. Ids start at 1 because id 0 is
# reserved for the "outside" label added below.
tag2id = dict(zip(
    ('action', 'actor', 'artifact', 'condition', 'location', 'modality', 'reference', 'time'),
    range(1, 9),
))
id2tag = {idx: tag for tag, idx in tag2id.items()}

# Label space of the classifier head: 'O' first, then every tag with its id.
label2id = {'O': 0}
label2id.update(tag2id)

id2label = {idx: label for label, idx in label2id.items()}
def get_token_role_in_span(token_start: int, token_end: int, span_start: int, span_end: int) -> str:
    """
    Classify a token's position relative to an annotated span.

    Args:
        - token_start, token_end: Start and end character offsets of the token
        - span_start, span_end: Start and end character offsets of the span
    Returns:
        - "N" if the token has an empty offset range (token_end <= token_start),
          e.g. the (0, 0) offsets reported for <s> and </s>
        - "O" if any part of the token lies outside the span
        - "I" if the token lies entirely inside the span. The first token of a
          span is also "I"; no separate "B" role is produced.
    """
    if token_end <= token_start:
        return "N"
    if token_start < span_start or token_end > span_end:
        return "O"
    return "I"


MAX_LENGTH = 256


def tokenize_and_adjust_labels(sample, *, max_length: int = MAX_LENGTH):
    """
    Tokenize one annotated sample and build a multi-hot label row per token.

    Args:
        - sample (dict): {"id": "...", "text": "...", "tags": [{"start": ..., "end": ..., "tag": ...}, ...]}
        - max_length (int): Fixed sequence length used for padding and truncation.
          Defaults to MAX_LENGTH, which reproduces the previous behaviour.
    Returns:
        - The tokenizer output for `sample["text"]` plus a "labels" entry holding
          `max_length` rows of `len(label2id)` zeros/ones.
    """
    # Keep the character offsets of every token (`return_offsets_mapping`) so
    # that tokens can be matched against the character-level spans, and
    # pad/truncate to a fixed length so every sample has the same shape.
    tokenized = tokenizer(sample["text"],
                          return_offsets_mapping=True,
                          padding="max_length",
                          max_length=max_length,
                          truncation=True)

    # Multi-label classification at each token: one slot per entry of label2id.
    # Slot 0 ("O") is never set here; a token outside all spans stays all-zero.
    n_classes = len(label2id)
    labels = [[0] * n_classes for _ in range(max_length)]

    # Spans may overlap, so a token can receive several labels.
    for (token_start, token_end), token_labels in zip(tokenized["offset_mapping"], labels):
        for span in sample["tags"]:
            role = get_token_role_in_span(token_start, token_end, span["start"], span["end"])
            if role == "I":
                token_labels[label2id[str(span["tag"])]] = 1

    return {**tokenized, "labels": labels}
| ['actor']\n", + " la | [77, 79] | []\n", + " mise | [80, 84] | []\n", + " en | [85, 87] | []\n", + " circulation | [88, 99] | []\n", + " d | [100, 101] | []\n", + " | [102, 103] | []\n", + " ' | [102, 103] | []\n", + " un | [104, 106] | []\n", + " véhicule | [107, 115] | []\n", + " sur | [116, 119] | ['location']\n", + " les | [120, 123] | ['location']\n", + " voies | [124, 129] | ['location']\n", + " publiques | [130, 139] | ['location']\n", + " par | [140, 143] | []\n", + " une | [144, 147] | ['actor']\n", + " personne | [148, 156] | ['actor']\n", + " non | [157, 160] | ['condition']\n", + " titulaire | [161, 170] | ['condition']\n", + " d | [171, 172] | ['condition']\n", + " | [173, 174] | ['condition']\n", + " ' | [173, 174] | ['condition']\n", + " un | [175, 177] | ['condition']\n", + " permis | [178, 184] | ['condition']\n", + " de | [185, 187] | ['condition']\n", + " conduire | [188, 196] | ['condition']\n", + " valable | [197, 204] | ['condition']\n", + " | [205, 206] | []\n", + " . 
def divide(a: int, b: int):
    """Return a / b, or 0 when the denominator is not strictly positive."""
    if b > 0:
        return a / b
    return 0


def compute_metrics(p):
    """
    Metric hook for the `transformers` Trainer: per-label F1 and its macro average.

    Args:
        - p (tuple): 2 numpy arrays: predictions (raw logits) and true_labels
    Returns:
        - metrics (dict): "f1_<label>" for every label except "O", plus "macro_f1"
    """
    logits, gold = p

    # Thresholding the logit at 0 is the same as thresholding the sigmoid at 0.5.
    decisions = (logits > 0).astype(np.float64)

    # One 2x2 confusion matrix per label, computed over all token positions.
    per_label_cm = multilabel_confusion_matrix(
        gold.reshape(-1, n_labels),
        decisions.reshape(-1, n_labels),
    )

    metrics = {}
    # Index 0 is the label "O", which is not scored.
    for label_idx in range(1, len(per_label_cm)):
        matrix = per_label_cm[label_idx]
        true_pos = matrix[1, 1]
        false_pos = matrix[0, 1]
        false_neg = matrix[1, 0]
        precision = divide(true_pos, true_pos + false_pos)
        recall = divide(true_pos, true_pos + false_neg)
        metrics[f"f1_{id2label[label_idx]}"] = divide(2 * precision * recall, precision + recall)

    # Macro average over the scored labels (everything except "O").
    metrics["macro_f1"] = sum(metrics.values()) / (n_labels - 1)
    return metrics
class RobertaForSpanCategorization(RobertaPreTrainedModel):
    """
    RoBERTa encoder with a linear head that scores every label independently
    for every token (multi-label token classification, trained with BCE).
    """
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*):
            Multi-hot targets for the per-token multi-label loss: entry `[b, t, l]` is 1 when token `t` of
            sequence `b` carries label `l`, otherwise 0. Values are cast to float before the loss is computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Element-wise loss so that padded positions can be excluded below.
            loss_fct = nn.BCEWithLogitsLoss(reduction="none")
            token_loss = loss_fct(logits, labels.float())
            if attention_mask is not None:
                # Padding positions have all-zero targets. Averaging over them
                # lets them dominate the loss when sequences are padded to a
                # long fixed length, so only attended positions are averaged.
                mask = attention_mask.unsqueeze(-1).to(token_loss.dtype)
                n_scored = (mask.sum() * token_loss.size(-1)).clamp(min=1)
                loss = (token_loss * mask).sum() / n_scored
            else:
                # Without a mask every position counts (plain mean).
                loss = token_loss.mean()
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class TrainingMetricsCallback(TrainerCallback):
    """Record the evaluation macro-F1 together with the index of the evaluation pass."""

    def __init__(self):
        self.macro_f1 = []  # one value per evaluation that reported 'eval_macro_f1'
        self.steps = []     # 1-based evaluation index matching each macro_f1 entry
        self.counter = 0    # number of evaluation passes seen so far

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # Count every evaluation pass, but only record a step when a value is
        # recorded too, so `steps` and `macro_f1` always have the same length.
        self.counter += 1
        if metrics is None or 'eval_macro_f1' not in metrics:
            return
        self.macro_f1.append(metrics['eval_macro_f1'])
        self.steps.append(self.counter)
learning_rate=2.5e-4,\n", + " per_device_train_batch_size=16,\n", + " per_device_eval_batch_size=16,\n", + " num_train_epochs=100,\n", + " weight_decay=0.01,\n", + " logging_steps = 100,\n", + " save_strategy='epoch',\n", + " save_total_limit=2,\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model='macro_f1',\n", + " log_level='critical',\n", + " seed=12345\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "931792b554582a9f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " <div>\n", + " \n", + " <progress value='1300' max='1300' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", + " [1300/1300 09:06, Epoch 100/100]\n", + " </div>\n", + " <table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: left;\">\n", + " <th>Epoch</th>\n", + " <th>Training Loss</th>\n", + " <th>Validation Loss</th>\n", + " <th>F1 Action</th>\n", + " <th>F1 Actor</th>\n", + " <th>F1 Artifact</th>\n", + " <th>F1 Condition</th>\n", + " <th>F1 Location</th>\n", + " <th>F1 Modality</th>\n", + " <th>F1 Reference</th>\n", + " <th>F1 Time</th>\n", + " <th>Macro F1</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <td>1</td>\n", + " <td>No log</td>\n", + " <td>0.356145</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>2</td>\n", + " <td>No log</td>\n", + " <td>0.279355</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>3</td>\n", + " <td>No log</td>\n", + " <td>0.221275</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " 
<td>0</td>\n", + " <td>0.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>4</td>\n", + " <td>No log</td>\n", + " <td>0.178497</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>5</td>\n", + " <td>No log</td>\n", + " <td>0.149304</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>6</td>\n", + " <td>No log</td>\n", + " <td>0.127901</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>7</td>\n", + " <td>No log</td>\n", + " <td>0.114968</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>8</td>\n", + " <td>0.213000</td>\n", + " <td>0.102522</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <td>9</td>\n", + " <td>0.213000</td>\n", + " <td>0.097100</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.176965</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.022121</td>\n", + " </tr>\n", + " <tr>\n", + " <td>10</td>\n", + " <td>0.213000</td>\n", + " <td>0.090857</td>\n", + " <td>0.346816</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.779005</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.140728</td>\n", + " 
</tr>\n", + " <tr>\n", + " <td>11</td>\n", + " <td>0.213000</td>\n", + " <td>0.086320</td>\n", + " <td>0.157806</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.771041</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.116106</td>\n", + " </tr>\n", + " <tr>\n", + " <td>12</td>\n", + " <td>0.213000</td>\n", + " <td>0.083295</td>\n", + " <td>0.398428</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.749960</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.143549</td>\n", + " </tr>\n", + " <tr>\n", + " <td>13</td>\n", + " <td>0.213000</td>\n", + " <td>0.082080</td>\n", + " <td>0.347058</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.779387</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.140806</td>\n", + " </tr>\n", + " <tr>\n", + " <td>14</td>\n", + " <td>0.213000</td>\n", + " <td>0.079545</td>\n", + " <td>0.432345</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.787969</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.152539</td>\n", + " </tr>\n", + " <tr>\n", + " <td>15</td>\n", + " <td>0.213000</td>\n", + " <td>0.076363</td>\n", + " <td>0.371108</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.781354</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.144058</td>\n", + " </tr>\n", + " <tr>\n", + " <td>16</td>\n", + " <td>0.066000</td>\n", + " <td>0.075020</td>\n", + " <td>0.554452</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.759754</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.164276</td>\n", + " </tr>\n", + " <tr>\n", + " <td>17</td>\n", + " <td>0.066000</td>\n", + " <td>0.076133</td>\n", + " <td>0.518700</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.783399</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " 
<td>0</td>\n", + " <td>0.162762</td>\n", + " </tr>\n", + " <tr>\n", + " <td>18</td>\n", + " <td>0.066000</td>\n", + " <td>0.072838</td>\n", + " <td>0.531309</td>\n", + " <td>0.197055</td>\n", + " <td>0</td>\n", + " <td>0.786690</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.540024</td>\n", + " <td>0</td>\n", + " <td>0.256885</td>\n", + " </tr>\n", + " <tr>\n", + " <td>19</td>\n", + " <td>0.066000</td>\n", + " <td>0.075695</td>\n", + " <td>0.445422</td>\n", + " <td>0.035851</td>\n", + " <td>0</td>\n", + " <td>0.733345</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.315789</td>\n", + " <td>0</td>\n", + " <td>0.191301</td>\n", + " </tr>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>0.066000</td>\n", + " <td>0.072801</td>\n", + " <td>0.506315</td>\n", + " <td>0.060453</td>\n", + " <td>0</td>\n", + " <td>0.762457</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.600897</td>\n", + " <td>0</td>\n", + " <td>0.241265</td>\n", + " </tr>\n", + " <tr>\n", + " <td>21</td>\n", + " <td>0.066000</td>\n", + " <td>0.072491</td>\n", + " <td>0.466798</td>\n", + " <td>0.264026</td>\n", + " <td>0</td>\n", + " <td>0.779796</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.583524</td>\n", + " <td>0</td>\n", + " <td>0.261768</td>\n", + " </tr>\n", + " <tr>\n", + " <td>22</td>\n", + " <td>0.066000</td>\n", + " <td>0.068607</td>\n", + " <td>0.531232</td>\n", + " <td>0.440191</td>\n", + " <td>0.034570</td>\n", + " <td>0.786928</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.707031</td>\n", + " <td>0</td>\n", + " <td>0.312494</td>\n", + " </tr>\n", + " <tr>\n", + " <td>23</td>\n", + " <td>0.066000</td>\n", + " <td>0.067250</td>\n", + " <td>0.503735</td>\n", + " <td>0.086420</td>\n", + " <td>0.001042</td>\n", + " <td>0.797254</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.713208</td>\n", + " <td>0.208232</td>\n", + " <td>0.288736</td>\n", + " </tr>\n", + " <tr>\n", + " <td>24</td>\n", + " <td>0.039600</td>\n", + " <td>0.069559</td>\n", + " 
<td>0.518097</td>\n", + " <td>0.466793</td>\n", + " <td>0.047141</td>\n", + " <td>0.787783</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.608506</td>\n", + " <td>0.204489</td>\n", + " <td>0.329101</td>\n", + " </tr>\n", + " <tr>\n", + " <td>25</td>\n", + " <td>0.039600</td>\n", + " <td>0.068075</td>\n", + " <td>0.486739</td>\n", + " <td>0.497164</td>\n", + " <td>0.061265</td>\n", + " <td>0.774566</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.736209</td>\n", + " <td>0.555336</td>\n", + " <td>0.388910</td>\n", + " </tr>\n", + " <tr>\n", + " <td>26</td>\n", + " <td>0.039600</td>\n", + " <td>0.065591</td>\n", + " <td>0.526296</td>\n", + " <td>0.539711</td>\n", + " <td>0.176184</td>\n", + " <td>0.803005</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.743191</td>\n", + " <td>0.538232</td>\n", + " <td>0.415827</td>\n", + " </tr>\n", + " <tr>\n", + " <td>27</td>\n", + " <td>0.039600</td>\n", + " <td>0.073091</td>\n", + " <td>0.450490</td>\n", + " <td>0.438735</td>\n", + " <td>0.021572</td>\n", + " <td>0.770273</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.600897</td>\n", + " <td>0.234597</td>\n", + " <td>0.314570</td>\n", + " </tr>\n", + " <tr>\n", + " <td>28</td>\n", + " <td>0.039600</td>\n", + " <td>0.067930</td>\n", + " <td>0.492051</td>\n", + " <td>0.509091</td>\n", + " <td>0.106906</td>\n", + " <td>0.803842</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0.674923</td>\n", + " <td>0.699515</td>\n", + " <td>0.410791</td>\n", + " </tr>\n", + " <tr>\n", + " <td>29</td>\n", + " <td>0.039600</td>\n", + " <td>0.064891</td>\n", + " <td>0.517025</td>\n", + " <td>0.592593</td>\n", + " <td>0.030334</td>\n", + " <td>0.820184</td>\n", + " <td>0</td>\n", + " <td>0.013333</td>\n", + " <td>0.713996</td>\n", + " <td>0.685200</td>\n", + " <td>0.421583</td>\n", + " </tr>\n", + " <tr>\n", + " <td>30</td>\n", + " <td>0.039600</td>\n", + " <td>0.063571</td>\n", + " <td>0.558310</td>\n", + " <td>0.576125</td>\n", + " <td>0.140530</td>\n", + " 
<td>0.803007</td>\n", + " <td>0</td>\n", + " <td>0.215569</td>\n", + " <td>0.723773</td>\n", + " <td>0.666118</td>\n", + " <td>0.460429</td>\n", + " </tr>\n", + " <tr>\n", + " <td>31</td>\n", + " <td>0.025000</td>\n", + " <td>0.064738</td>\n", + " <td>0.525939</td>\n", + " <td>0.589438</td>\n", + " <td>0.172210</td>\n", + " <td>0.818668</td>\n", + " <td>0</td>\n", + " <td>0.482412</td>\n", + " <td>0.718967</td>\n", + " <td>0.702580</td>\n", + " <td>0.501277</td>\n", + " </tr>\n", + " <tr>\n", + " <td>32</td>\n", + " <td>0.025000</td>\n", + " <td>0.065563</td>\n", + " <td>0.549824</td>\n", + " <td>0.636741</td>\n", + " <td>0.091671</td>\n", + " <td>0.806373</td>\n", + " <td>0</td>\n", + " <td>0.548077</td>\n", + " <td>0.687243</td>\n", + " <td>0.676329</td>\n", + " <td>0.499532</td>\n", + " </tr>\n", + " <tr>\n", + " <td>33</td>\n", + " <td>0.025000</td>\n", + " <td>0.062232</td>\n", + " <td>0.558872</td>\n", + " <td>0.690293</td>\n", + " <td>0.130246</td>\n", + " <td>0.823770</td>\n", + " <td>0.153846</td>\n", + " <td>0.561905</td>\n", + " <td>0.744412</td>\n", + " <td>0.719551</td>\n", + " <td>0.547862</td>\n", + " </tr>\n", + " <tr>\n", + " <td>34</td>\n", + " <td>0.025000</td>\n", + " <td>0.063367</td>\n", + " <td>0.542108</td>\n", + " <td>0.695394</td>\n", + " <td>0.170620</td>\n", + " <td>0.828333</td>\n", + " <td>0.253968</td>\n", + " <td>0.583333</td>\n", + " <td>0.734452</td>\n", + " <td>0.733387</td>\n", + " <td>0.567700</td>\n", + " </tr>\n", + " <tr>\n", + " <td>35</td>\n", + " <td>0.025000</td>\n", + " <td>0.061890</td>\n", + " <td>0.562546</td>\n", + " <td>0.707951</td>\n", + " <td>0.167960</td>\n", + " <td>0.821657</td>\n", + " <td>0.480769</td>\n", + " <td>0.574074</td>\n", + " <td>0.732314</td>\n", + " <td>0.736597</td>\n", + " <td>0.597983</td>\n", + " </tr>\n", + " <tr>\n", + " <td>36</td>\n", + " <td>0.025000</td>\n", + " <td>0.060461</td>\n", + " <td>0.589995</td>\n", + " <td>0.683121</td>\n", + " <td>0.262919</td>\n", + " <td>0.822199</td>\n", 
+ " <td>0.517799</td>\n", + " <td>0.580645</td>\n", + " <td>0.736641</td>\n", + " <td>0.745312</td>\n", + " <td>0.617329</td>\n", + " </tr>\n", + " <tr>\n", + " <td>37</td>\n", + " <td>0.025000</td>\n", + " <td>0.064956</td>\n", + " <td>0.548047</td>\n", + " <td>0.685271</td>\n", + " <td>0.175471</td>\n", + " <td>0.827378</td>\n", + " <td>0.547771</td>\n", + " <td>0.618834</td>\n", + " <td>0.742968</td>\n", + " <td>0.736089</td>\n", + " <td>0.610229</td>\n", + " </tr>\n", + " <tr>\n", + " <td>38</td>\n", + " <td>0.025000</td>\n", + " <td>0.063804</td>\n", + " <td>0.539435</td>\n", + " <td>0.701887</td>\n", + " <td>0.270973</td>\n", + " <td>0.813606</td>\n", + " <td>0.541966</td>\n", + " <td>0.600000</td>\n", + " <td>0.745540</td>\n", + " <td>0.732331</td>\n", + " <td>0.618217</td>\n", + " </tr>\n", + " <tr>\n", + " <td>39</td>\n", + " <td>0.015200</td>\n", + " <td>0.067724</td>\n", + " <td>0.533308</td>\n", + " <td>0.642036</td>\n", + " <td>0.190905</td>\n", + " <td>0.821376</td>\n", + " <td>0.548896</td>\n", + " <td>0.600000</td>\n", + " <td>0.743961</td>\n", + " <td>0.718573</td>\n", + " <td>0.599882</td>\n", + " </tr>\n", + " <tr>\n", + " <td>40</td>\n", + " <td>0.015200</td>\n", + " <td>0.068417</td>\n", + " <td>0.533283</td>\n", + " <td>0.650718</td>\n", + " <td>0.182653</td>\n", + " <td>0.812395</td>\n", + " <td>0.601190</td>\n", + " <td>0.593607</td>\n", + " <td>0.744186</td>\n", + " <td>0.722561</td>\n", + " <td>0.605074</td>\n", + " </tr>\n", + " <tr>\n", + " <td>41</td>\n", + " <td>0.015200</td>\n", + " <td>0.068958</td>\n", + " <td>0.546763</td>\n", + " <td>0.678769</td>\n", + " <td>0.234627</td>\n", + " <td>0.772878</td>\n", + " <td>0.596859</td>\n", + " <td>0.593607</td>\n", + " <td>0.757374</td>\n", + " <td>0.730038</td>\n", + " <td>0.613864</td>\n", + " </tr>\n", + " <tr>\n", + " <td>42</td>\n", + " <td>0.015200</td>\n", + " <td>0.067570</td>\n", + " <td>0.550757</td>\n", + " <td>0.704247</td>\n", + " <td>0.213687</td>\n", + " <td>0.826460</td>\n", + 
" <td>0.561404</td>\n", + " <td>0.606335</td>\n", + " <td>0.748047</td>\n", + " <td>0.718676</td>\n", + " <td>0.616201</td>\n", + " </tr>\n", + " <tr>\n", + " <td>43</td>\n", + " <td>0.015200</td>\n", + " <td>0.069913</td>\n", + " <td>0.527344</td>\n", + " <td>0.673583</td>\n", + " <td>0.235346</td>\n", + " <td>0.813295</td>\n", + " <td>0.568966</td>\n", + " <td>0.593607</td>\n", + " <td>0.762808</td>\n", + " <td>0.707539</td>\n", + " <td>0.610311</td>\n", + " </tr>\n", + " <tr>\n", + " <td>44</td>\n", + " <td>0.015200</td>\n", + " <td>0.066919</td>\n", + " <td>0.548339</td>\n", + " <td>0.702537</td>\n", + " <td>0.250538</td>\n", + " <td>0.811425</td>\n", + " <td>0.594164</td>\n", + " <td>0.576744</td>\n", + " <td>0.759615</td>\n", + " <td>0.744470</td>\n", + " <td>0.623479</td>\n", + " </tr>\n", + " <tr>\n", + " <td>45</td>\n", + " <td>0.015200</td>\n", + " <td>0.070344</td>\n", + " <td>0.538619</td>\n", + " <td>0.684670</td>\n", + " <td>0.205105</td>\n", + " <td>0.812809</td>\n", + " <td>0.589474</td>\n", + " <td>0.589862</td>\n", + " <td>0.756964</td>\n", + " <td>0.735669</td>\n", + " <td>0.614147</td>\n", + " </tr>\n", + " <tr>\n", + " <td>46</td>\n", + " <td>0.015200</td>\n", + " <td>0.071540</td>\n", + " <td>0.509841</td>\n", + " <td>0.705255</td>\n", + " <td>0.239027</td>\n", + " <td>0.813519</td>\n", + " <td>0.552279</td>\n", + " <td>0.603604</td>\n", + " <td>0.756757</td>\n", + " <td>0.727416</td>\n", + " <td>0.613462</td>\n", + " </tr>\n", + " <tr>\n", + " <td>47</td>\n", + " <td>0.010700</td>\n", + " <td>0.070772</td>\n", + " <td>0.548339</td>\n", + " <td>0.693146</td>\n", + " <td>0.244704</td>\n", + " <td>0.818304</td>\n", + " <td>0.559767</td>\n", + " <td>0.587156</td>\n", + " <td>0.756757</td>\n", + " <td>0.723200</td>\n", + " <td>0.616422</td>\n", + " </tr>\n", + " <tr>\n", + " <td>48</td>\n", + " <td>0.010700</td>\n", + " <td>0.070882</td>\n", + " <td>0.547571</td>\n", + " <td>0.685490</td>\n", + " <td>0.221344</td>\n", + " <td>0.810738</td>\n", + " 
<td>0.541176</td>\n", + " <td>0.606335</td>\n", + " <td>0.762357</td>\n", + " <td>0.725745</td>\n", + " <td>0.612595</td>\n", + " </tr>\n", + " <tr>\n", + " <td>49</td>\n", + " <td>0.010700</td>\n", + " <td>0.070016</td>\n", + " <td>0.543567</td>\n", + " <td>0.702128</td>\n", + " <td>0.270270</td>\n", + " <td>0.810131</td>\n", + " <td>0.594059</td>\n", + " <td>0.606335</td>\n", + " <td>0.759924</td>\n", + " <td>0.748638</td>\n", + " <td>0.629382</td>\n", + " </tr>\n", + " <tr>\n", + " <td>50</td>\n", + " <td>0.010700</td>\n", + " <td>0.072628</td>\n", + " <td>0.542216</td>\n", + " <td>0.679186</td>\n", + " <td>0.249678</td>\n", + " <td>0.808700</td>\n", + " <td>0.567568</td>\n", + " <td>0.616071</td>\n", + " <td>0.763810</td>\n", + " <td>0.738609</td>\n", + " <td>0.620730</td>\n", + " </tr>\n", + " <tr>\n", + " <td>51</td>\n", + " <td>0.010700</td>\n", + " <td>0.071187</td>\n", + " <td>0.553411</td>\n", + " <td>0.674033</td>\n", + " <td>0.268354</td>\n", + " <td>0.805851</td>\n", + " <td>0.582888</td>\n", + " <td>0.600000</td>\n", + " <td>0.760994</td>\n", + " <td>0.726422</td>\n", + " <td>0.621494</td>\n", + " </tr>\n", + " <tr>\n", + " <td>52</td>\n", + " <td>0.010700</td>\n", + " <td>0.071963</td>\n", + " <td>0.550415</td>\n", + " <td>0.685848</td>\n", + " <td>0.259574</td>\n", + " <td>0.813526</td>\n", + " <td>0.593750</td>\n", + " <td>0.584475</td>\n", + " <td>0.757116</td>\n", + " <td>0.727129</td>\n", + " <td>0.621479</td>\n", + " </tr>\n", + " <tr>\n", + " <td>53</td>\n", + " <td>0.010700</td>\n", + " <td>0.071476</td>\n", + " <td>0.561372</td>\n", + " <td>0.694253</td>\n", + " <td>0.276007</td>\n", + " <td>0.809990</td>\n", + " <td>0.573727</td>\n", + " <td>0.609865</td>\n", + " <td>0.760994</td>\n", + " <td>0.745843</td>\n", + " <td>0.629006</td>\n", + " </tr>\n", + " <tr>\n", + " <td>54</td>\n", + " <td>0.008000</td>\n", + " <td>0.072140</td>\n", + " <td>0.555779</td>\n", + " <td>0.685848</td>\n", + " <td>0.263158</td>\n", + " <td>0.809853</td>\n", + " 
<td>0.584767</td>\n", + " <td>0.587156</td>\n", + " <td>0.759661</td>\n", + " <td>0.737405</td>\n", + " <td>0.622953</td>\n", + " </tr>\n", + " <tr>\n", + " <td>55</td>\n", + " <td>0.008000</td>\n", + " <td>0.071148</td>\n", + " <td>0.563626</td>\n", + " <td>0.693846</td>\n", + " <td>0.268644</td>\n", + " <td>0.812698</td>\n", + " <td>0.614583</td>\n", + " <td>0.669528</td>\n", + " <td>0.753946</td>\n", + " <td>0.736518</td>\n", + " <td>0.639174</td>\n", + " </tr>\n", + " <tr>\n", + " <td>56</td>\n", + " <td>0.008000</td>\n", + " <td>0.072347</td>\n", + " <td>0.563206</td>\n", + " <td>0.685490</td>\n", + " <td>0.278057</td>\n", + " <td>0.813593</td>\n", + " <td>0.597468</td>\n", + " <td>0.606335</td>\n", + " <td>0.753541</td>\n", + " <td>0.732778</td>\n", + " <td>0.628809</td>\n", + " </tr>\n", + " <tr>\n", + " <td>57</td>\n", + " <td>0.008000</td>\n", + " <td>0.072100</td>\n", + " <td>0.552941</td>\n", + " <td>0.684251</td>\n", + " <td>0.298728</td>\n", + " <td>0.821290</td>\n", + " <td>0.596859</td>\n", + " <td>0.612613</td>\n", + " <td>0.752363</td>\n", + " <td>0.731707</td>\n", + " <td>0.631344</td>\n", + " </tr>\n", + " <tr>\n", + " <td>58</td>\n", + " <td>0.008000</td>\n", + " <td>0.072843</td>\n", + " <td>0.549244</td>\n", + " <td>0.699237</td>\n", + " <td>0.290456</td>\n", + " <td>0.821850</td>\n", + " <td>0.588542</td>\n", + " <td>0.603604</td>\n", + " <td>0.751174</td>\n", + " <td>0.732103</td>\n", + " <td>0.629526</td>\n", + " </tr>\n", + " <tr>\n", + " <td>59</td>\n", + " <td>0.008000</td>\n", + " <td>0.073446</td>\n", + " <td>0.545723</td>\n", + " <td>0.706677</td>\n", + " <td>0.277801</td>\n", + " <td>0.808947</td>\n", + " <td>0.590674</td>\n", + " <td>0.606335</td>\n", + " <td>0.760037</td>\n", + " <td>0.731915</td>\n", + " <td>0.628514</td>\n", + " </tr>\n", + " <tr>\n", + " <td>60</td>\n", + " <td>0.008000</td>\n", + " <td>0.068292</td>\n", + " <td>0.571429</td>\n", + " <td>0.710546</td>\n", + " <td>0.323144</td>\n", + " <td>0.806643</td>\n", + " 
<td>0.625323</td>\n", + " <td>0.596330</td>\n", + " <td>0.753176</td>\n", + " <td>0.735294</td>\n", + " <td>0.640236</td>\n", + " </tr>\n", + " <tr>\n", + " <td>61</td>\n", + " <td>0.008000</td>\n", + " <td>0.069890</td>\n", + " <td>0.621654</td>\n", + " <td>0.701807</td>\n", + " <td>0.293851</td>\n", + " <td>0.789474</td>\n", + " <td>0.576819</td>\n", + " <td>0.666667</td>\n", + " <td>0.747331</td>\n", + " <td>0.627998</td>\n", + " <td>0.628200</td>\n", + " </tr>\n", + " <tr>\n", + " <td>62</td>\n", + " <td>0.006900</td>\n", + " <td>0.077804</td>\n", + " <td>0.556911</td>\n", + " <td>0.691603</td>\n", + " <td>0.256562</td>\n", + " <td>0.792987</td>\n", + " <td>0.570605</td>\n", + " <td>0.600000</td>\n", + " <td>0.774443</td>\n", + " <td>0.664879</td>\n", + " <td>0.613499</td>\n", + " </tr>\n", + " <tr>\n", + " <td>63</td>\n", + " <td>0.006900</td>\n", + " <td>0.075310</td>\n", + " <td>0.545927</td>\n", + " <td>0.676012</td>\n", + " <td>0.246670</td>\n", + " <td>0.813548</td>\n", + " <td>0.587601</td>\n", + " <td>0.602740</td>\n", + " <td>0.693416</td>\n", + " <td>0.682464</td>\n", + " <td>0.606047</td>\n", + " </tr>\n", + " <tr>\n", + " <td>64</td>\n", + " <td>0.006900</td>\n", + " <td>0.075709</td>\n", + " <td>0.567001</td>\n", + " <td>0.693215</td>\n", + " <td>0.254287</td>\n", + " <td>0.810577</td>\n", + " <td>0.600567</td>\n", + " <td>0.600000</td>\n", + " <td>0.779727</td>\n", + " <td>0.733489</td>\n", + " <td>0.629858</td>\n", + " </tr>\n", + " <tr>\n", + " <td>65</td>\n", + " <td>0.006900</td>\n", + " <td>0.074874</td>\n", + " <td>0.555064</td>\n", + " <td>0.700515</td>\n", + " <td>0.267559</td>\n", + " <td>0.807587</td>\n", + " <td>0.584856</td>\n", + " <td>0.600000</td>\n", + " <td>0.785579</td>\n", + " <td>0.719940</td>\n", + " <td>0.627638</td>\n", + " </tr>\n", + " <tr>\n", + " <td>66</td>\n", + " <td>0.006900</td>\n", + " <td>0.076945</td>\n", + " <td>0.555310</td>\n", + " <td>0.683230</td>\n", + " <td>0.259244</td>\n", + " <td>0.809295</td>\n", + " 
<td>0.587342</td>\n", + " <td>0.600000</td>\n", + " <td>0.777567</td>\n", + " <td>0.734819</td>\n", + " <td>0.625851</td>\n", + " </tr>\n", + " <tr>\n", + " <td>67</td>\n", + " <td>0.006900</td>\n", + " <td>0.079016</td>\n", + " <td>0.537566</td>\n", + " <td>0.693313</td>\n", + " <td>0.269021</td>\n", + " <td>0.809382</td>\n", + " <td>0.594458</td>\n", + " <td>0.600000</td>\n", + " <td>0.767677</td>\n", + " <td>0.735791</td>\n", + " <td>0.625901</td>\n", + " </tr>\n", + " <tr>\n", + " <td>68</td>\n", + " <td>0.006900</td>\n", + " <td>0.076534</td>\n", + " <td>0.545013</td>\n", + " <td>0.688752</td>\n", + " <td>0.255588</td>\n", + " <td>0.803952</td>\n", + " <td>0.614973</td>\n", + " <td>0.606335</td>\n", + " <td>0.759070</td>\n", + " <td>0.742642</td>\n", + " <td>0.627041</td>\n", + " </tr>\n", + " <tr>\n", + " <td>69</td>\n", + " <td>0.006900</td>\n", + " <td>0.079277</td>\n", + " <td>0.540419</td>\n", + " <td>0.705444</td>\n", + " <td>0.253413</td>\n", + " <td>0.804418</td>\n", + " <td>0.597884</td>\n", + " <td>0.634361</td>\n", + " <td>0.751456</td>\n", + " <td>0.732752</td>\n", + " <td>0.627518</td>\n", + " </tr>\n", + " <tr>\n", + " <td>70</td>\n", + " <td>0.006300</td>\n", + " <td>0.078259</td>\n", + " <td>0.547517</td>\n", + " <td>0.696510</td>\n", + " <td>0.283223</td>\n", + " <td>0.811007</td>\n", + " <td>0.579088</td>\n", + " <td>0.612613</td>\n", + " <td>0.753623</td>\n", + " <td>0.729430</td>\n", + " <td>0.626626</td>\n", + " </tr>\n", + " <tr>\n", + " <td>71</td>\n", + " <td>0.006300</td>\n", + " <td>0.078232</td>\n", + " <td>0.546977</td>\n", + " <td>0.703488</td>\n", + " <td>0.248394</td>\n", + " <td>0.807770</td>\n", + " <td>0.598425</td>\n", + " <td>0.649573</td>\n", + " <td>0.752381</td>\n", + " <td>0.737819</td>\n", + " <td>0.630603</td>\n", + " </tr>\n", + " <tr>\n", + " <td>72</td>\n", + " <td>0.006300</td>\n", + " <td>0.077648</td>\n", + " <td>0.554119</td>\n", + " <td>0.699552</td>\n", + " <td>0.256694</td>\n", + " <td>0.806284</td>\n", + " 
<td>0.598425</td>\n", + " <td>0.603604</td>\n", + " <td>0.753098</td>\n", + " <td>0.742378</td>\n", + " <td>0.626769</td>\n", + " </tr>\n", + " <tr>\n", + " <td>73</td>\n", + " <td>0.006300</td>\n", + " <td>0.077358</td>\n", + " <td>0.557540</td>\n", + " <td>0.701728</td>\n", + " <td>0.284053</td>\n", + " <td>0.805705</td>\n", + " <td>0.605598</td>\n", + " <td>0.607143</td>\n", + " <td>0.753976</td>\n", + " <td>0.737557</td>\n", + " <td>0.631662</td>\n", + " </tr>\n", + " <tr>\n", + " <td>74</td>\n", + " <td>0.006300</td>\n", + " <td>0.077561</td>\n", + " <td>0.564262</td>\n", + " <td>0.702176</td>\n", + " <td>0.278954</td>\n", + " <td>0.806734</td>\n", + " <td>0.591623</td>\n", + " <td>0.609865</td>\n", + " <td>0.757835</td>\n", + " <td>0.740061</td>\n", + " <td>0.631439</td>\n", + " </tr>\n", + " <tr>\n", + " <td>75</td>\n", + " <td>0.006300</td>\n", + " <td>0.078310</td>\n", + " <td>0.565003</td>\n", + " <td>0.701201</td>\n", + " <td>0.263357</td>\n", + " <td>0.803329</td>\n", + " <td>0.589474</td>\n", + " <td>0.600897</td>\n", + " <td>0.755471</td>\n", + " <td>0.742726</td>\n", + " <td>0.627682</td>\n", + " </tr>\n", + " <tr>\n", + " <td>76</td>\n", + " <td>0.006300</td>\n", + " <td>0.078413</td>\n", + " <td>0.567366</td>\n", + " <td>0.703148</td>\n", + " <td>0.270180</td>\n", + " <td>0.802114</td>\n", + " <td>0.576000</td>\n", + " <td>0.607143</td>\n", + " <td>0.757776</td>\n", + " <td>0.746544</td>\n", + " <td>0.628784</td>\n", + " </tr>\n", + " <tr>\n", + " <td>77</td>\n", + " <td>0.005000</td>\n", + " <td>0.078019</td>\n", + " <td>0.564796</td>\n", + " <td>0.705438</td>\n", + " <td>0.273219</td>\n", + " <td>0.805654</td>\n", + " <td>0.596306</td>\n", + " <td>0.616071</td>\n", + " <td>0.757974</td>\n", + " <td>0.740684</td>\n", + " <td>0.632518</td>\n", + " </tr>\n", + " <tr>\n", + " <td>78</td>\n", + " <td>0.005000</td>\n", + " <td>0.077703</td>\n", + " <td>0.566892</td>\n", + " <td>0.707646</td>\n", + " <td>0.275145</td>\n", + " <td>0.805477</td>\n", + " 
<td>0.587302</td>\n", + " <td>0.616071</td>\n", + " <td>0.756808</td>\n", + " <td>0.742857</td>\n", + " <td>0.632275</td>\n", + " </tr>\n", + " <tr>\n", + " <td>79</td>\n", + " <td>0.005000</td>\n", + " <td>0.078168</td>\n", + " <td>0.569659</td>\n", + " <td>0.704460</td>\n", + " <td>0.265828</td>\n", + " <td>0.806930</td>\n", + " <td>0.577540</td>\n", + " <td>0.603604</td>\n", + " <td>0.755005</td>\n", + " <td>0.750964</td>\n", + " <td>0.629249</td>\n", + " </tr>\n", + " <tr>\n", + " <td>80</td>\n", + " <td>0.005000</td>\n", + " <td>0.078695</td>\n", + " <td>0.568617</td>\n", + " <td>0.707391</td>\n", + " <td>0.269616</td>\n", + " <td>0.808101</td>\n", + " <td>0.579787</td>\n", + " <td>0.628319</td>\n", + " <td>0.755258</td>\n", + " <td>0.743570</td>\n", + " <td>0.632582</td>\n", + " </tr>\n", + " <tr>\n", + " <td>81</td>\n", + " <td>0.005000</td>\n", + " <td>0.079567</td>\n", + " <td>0.563530</td>\n", + " <td>0.704545</td>\n", + " <td>0.258365</td>\n", + " <td>0.808841</td>\n", + " <td>0.582011</td>\n", + " <td>0.622222</td>\n", + " <td>0.756705</td>\n", + " <td>0.743169</td>\n", + " <td>0.629924</td>\n", + " </tr>\n", + " <tr>\n", + " <td>82</td>\n", + " <td>0.005000</td>\n", + " <td>0.079116</td>\n", + " <td>0.566155</td>\n", + " <td>0.704012</td>\n", + " <td>0.264024</td>\n", + " <td>0.809779</td>\n", + " <td>0.579787</td>\n", + " <td>0.622222</td>\n", + " <td>0.760267</td>\n", + " <td>0.738413</td>\n", + " <td>0.630582</td>\n", + " </tr>\n", + " <tr>\n", + " <td>83</td>\n", + " <td>0.005000</td>\n", + " <td>0.079035</td>\n", + " <td>0.570076</td>\n", + " <td>0.703563</td>\n", + " <td>0.266835</td>\n", + " <td>0.808732</td>\n", + " <td>0.587302</td>\n", + " <td>0.634361</td>\n", + " <td>0.759542</td>\n", + " <td>0.739812</td>\n", + " <td>0.633778</td>\n", + " </tr>\n", + " <tr>\n", + " <td>84</td>\n", + " <td>0.005000</td>\n", + " <td>0.079356</td>\n", + " <td>0.568981</td>\n", + " <td>0.704545</td>\n", + " <td>0.266217</td>\n", + " <td>0.807873</td>\n", + " 
<td>0.587302</td>\n", + " <td>0.660944</td>\n", + " <td>0.761450</td>\n", + " <td>0.738824</td>\n", + " <td>0.637017</td>\n", + " </tr>\n", + " <tr>\n", + " <td>85</td>\n", + " <td>0.004500</td>\n", + " <td>0.079107</td>\n", + " <td>0.569973</td>\n", + " <td>0.707207</td>\n", + " <td>0.259619</td>\n", + " <td>0.807485</td>\n", + " <td>0.591623</td>\n", + " <td>0.655172</td>\n", + " <td>0.764212</td>\n", + " <td>0.740566</td>\n", + " <td>0.636982</td>\n", + " </tr>\n", + " <tr>\n", + " <td>86</td>\n", + " <td>0.004500</td>\n", + " <td>0.079421</td>\n", + " <td>0.568827</td>\n", + " <td>0.705971</td>\n", + " <td>0.258694</td>\n", + " <td>0.809644</td>\n", + " <td>0.585752</td>\n", + " <td>0.652361</td>\n", + " <td>0.765363</td>\n", + " <td>0.737743</td>\n", + " <td>0.635544</td>\n", + " </tr>\n", + " <tr>\n", + " <td>87</td>\n", + " <td>0.004500</td>\n", + " <td>0.079341</td>\n", + " <td>0.567676</td>\n", + " <td>0.707554</td>\n", + " <td>0.257106</td>\n", + " <td>0.811289</td>\n", + " <td>0.578947</td>\n", + " <td>0.649351</td>\n", + " <td>0.762984</td>\n", + " <td>0.737984</td>\n", + " <td>0.634112</td>\n", + " </tr>\n", + " <tr>\n", + " <td>88</td>\n", + " <td>0.004500</td>\n", + " <td>0.079288</td>\n", + " <td>0.567780</td>\n", + " <td>0.706497</td>\n", + " <td>0.261311</td>\n", + " <td>0.811236</td>\n", + " <td>0.582677</td>\n", + " <td>0.646552</td>\n", + " <td>0.764428</td>\n", + " <td>0.739705</td>\n", + " <td>0.635023</td>\n", + " </tr>\n", + " <tr>\n", + " <td>89</td>\n", + " <td>0.004500</td>\n", + " <td>0.079334</td>\n", + " <td>0.567676</td>\n", + " <td>0.708520</td>\n", + " <td>0.270475</td>\n", + " <td>0.811135</td>\n", + " <td>0.579634</td>\n", + " <td>0.646552</td>\n", + " <td>0.762085</td>\n", + " <td>0.739130</td>\n", + " <td>0.635651</td>\n", + " </tr>\n", + " <tr>\n", + " <td>90</td>\n", + " <td>0.004500</td>\n", + " <td>0.079473</td>\n", + " <td>0.567153</td>\n", + " <td>0.708520</td>\n", + " <td>0.267452</td>\n", + " <td>0.810241</td>\n", + " 
<td>0.582677</td>\n", + " <td>0.652361</td>\n", + " <td>0.763981</td>\n", + " <td>0.738962</td>\n", + " <td>0.636418</td>\n", + " </tr>\n", + " <tr>\n", + " <td>91</td>\n", + " <td>0.004500</td>\n", + " <td>0.079833</td>\n", + " <td>0.567153</td>\n", + " <td>0.710645</td>\n", + " <td>0.264247</td>\n", + " <td>0.811074</td>\n", + " <td>0.580475</td>\n", + " <td>0.652361</td>\n", + " <td>0.763258</td>\n", + " <td>0.737984</td>\n", + " <td>0.635899</td>\n", + " </tr>\n", + " <tr>\n", + " <td>92</td>\n", + " <td>0.004500</td>\n", + " <td>0.080144</td>\n", + " <td>0.567937</td>\n", + " <td>0.709434</td>\n", + " <td>0.260722</td>\n", + " <td>0.810212</td>\n", + " <td>0.576720</td>\n", + " <td>0.652361</td>\n", + " <td>0.764706</td>\n", + " <td>0.739130</td>\n", + " <td>0.635153</td>\n", + " </tr>\n", + " <tr>\n", + " <td>93</td>\n", + " <td>0.004200</td>\n", + " <td>0.080148</td>\n", + " <td>0.567050</td>\n", + " <td>0.706505</td>\n", + " <td>0.261349</td>\n", + " <td>0.810666</td>\n", + " <td>0.576720</td>\n", + " <td>0.652361</td>\n", + " <td>0.764259</td>\n", + " <td>0.740452</td>\n", + " <td>0.634920</td>\n", + " </tr>\n", + " <tr>\n", + " <td>94</td>\n", + " <td>0.004200</td>\n", + " <td>0.080076</td>\n", + " <td>0.567202</td>\n", + " <td>0.704012</td>\n", + " <td>0.263625</td>\n", + " <td>0.811107</td>\n", + " <td>0.576720</td>\n", + " <td>0.643478</td>\n", + " <td>0.763533</td>\n", + " <td>0.737743</td>\n", + " <td>0.633428</td>\n", + " </tr>\n", + " <tr>\n", + " <td>95</td>\n", + " <td>0.004200</td>\n", + " <td>0.080163</td>\n", + " <td>0.568563</td>\n", + " <td>0.702580</td>\n", + " <td>0.260465</td>\n", + " <td>0.810775</td>\n", + " <td>0.585752</td>\n", + " <td>0.631579</td>\n", + " <td>0.762808</td>\n", + " <td>0.737500</td>\n", + " <td>0.632503</td>\n", + " </tr>\n", + " <tr>\n", + " <td>96</td>\n", + " <td>0.004200</td>\n", + " <td>0.080179</td>\n", + " <td>0.568931</td>\n", + " <td>0.703196</td>\n", + " <td>0.259949</td>\n", + " <td>0.812063</td>\n", + " 
<td>0.583554</td>\n", + " <td>0.625551</td>\n", + " <td>0.762808</td>\n", + " <td>0.736513</td>\n", + " <td>0.631571</td>\n", + " </tr>\n", + " <tr>\n", + " <td>97</td>\n", + " <td>0.004200</td>\n", + " <td>0.080223</td>\n", + " <td>0.568774</td>\n", + " <td>0.703196</td>\n", + " <td>0.260575</td>\n", + " <td>0.812604</td>\n", + " <td>0.587302</td>\n", + " <td>0.631579</td>\n", + " <td>0.763981</td>\n", + " <td>0.737500</td>\n", + " <td>0.633189</td>\n", + " </tr>\n", + " <tr>\n", + " <td>98</td>\n", + " <td>0.004200</td>\n", + " <td>0.080253</td>\n", + " <td>0.568513</td>\n", + " <td>0.702662</td>\n", + " <td>0.261935</td>\n", + " <td>0.812875</td>\n", + " <td>0.591029</td>\n", + " <td>0.631579</td>\n", + " <td>0.762535</td>\n", + " <td>0.738486</td>\n", + " <td>0.633702</td>\n", + " </tr>\n", + " <tr>\n", + " <td>99</td>\n", + " <td>0.004200</td>\n", + " <td>0.080378</td>\n", + " <td>0.567834</td>\n", + " <td>0.702128</td>\n", + " <td>0.261200</td>\n", + " <td>0.812427</td>\n", + " <td>0.585752</td>\n", + " <td>0.631579</td>\n", + " <td>0.763705</td>\n", + " <td>0.739062</td>\n", + " <td>0.632961</td>\n", + " </tr>\n", + " <tr>\n", + " <td>100</td>\n", + " <td>0.004000</td>\n", + " <td>0.080378</td>\n", + " <td>0.567834</td>\n", + " <td>0.703113</td>\n", + " <td>0.261200</td>\n", + " <td>0.812563</td>\n", + " <td>0.585752</td>\n", + " <td>0.631579</td>\n", + " <td>0.763705</td>\n", + " <td>0.739062</td>\n", + " <td>0.633101</td>\n", + " </tr>\n", + " </tbody>\n", + "</table><p>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=1300, training_loss=0.03139937967061997, metrics={'train_runtime': 548.1168, 'train_samples_per_second': 37.583, 'train_steps_per_second': 2.372, 'total_flos': 2691526921113600.0, 'train_loss': 0.03139937967061997, 'epoch': 100.0})" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "# Initialize the callback\n", + "metrics_callback = TrainingMetricsCallback()\n", + "\n", + "def model_init():\n", + " # For reproducibility\n", + " return RobertaForSpanCategorization.from_pretrained(modelId, id2label=id2label, label2id=label2id)\n", + "\n", + "trainer = Trainer(\n", + " model_init=model_init,\n", + " args=training_args,\n", + " train_dataset=tokenized_train_ds,\n", + " eval_dataset=tokenized_val_ds,\n", + " data_collator=data_collator,\n", + " tokenizer=tokenizer,\n", + " compute_metrics=compute_metrics,\n", + " callbacks=[metrics_callback]\n", + ")\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "6b029310-5257-41d8-a5f7-3bbb2021ebae", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 1000x500 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot the training and evaluation metrics\n", + "df = pd.DataFrame({'step': metrics_callback.steps})\n", + "df['macro_f1'] = pd.Series(metrics_callback.macro_f1)\n", + "\n", + "plt.figure(figsize=(10, 5))\n", + "plt.plot(df['step'], df['macro_f1'], label='Macro F1')\n", + "plt.xlabel('Step')\n", + "plt.ylabel('Metrics')\n", + "plt.legend()\n", + "plt.title('Training and Evaluation Metrics over Steps')\n", + "\n", + "plt.savefig('../../results/BERT/legal_bert.png')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "89c997e6a944bc70", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.model.save_pretrained(\"../../models/Fine-tuned_LegalCamemBERT-base\")" + ] + }, + { + "cell_type": "markdown", + "id": "05f2cf51-2a76-4de6-9b22-af7dda1eb805", + "metadata": {}, + "source": [ + "# Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e48a67b7-814d-46d5-a079-fafc5b6adf86", + "metadata": {}, + "outputs": [], + "source": [ + "model = 
RobertaForSpanCategorization.from_pretrained(\"../../models/Fine-tuned_LegalCamemBERT-base\")\n", + "tokenizer = AutoTokenizer.from_pretrained(modelId)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "4cb75af5-234f-4863-94f8-45a364e1b877", + "metadata": {}, + "outputs": [], + "source": [ + "def get_offsets_and_predicted_tags(example: str, model, tokenizer, threshold=0):\n", + " \"\"\"\n", + " Get prediction of model on example, using tokenizer\n", + " Args:\n", + " - example (str): The input text\n", + " - model: The span categorizer\n", + " - tokenizer: The tokenizer\n", + " - threshold: The threshold to decide whether the token should belong to the label. Default to 0, which corresponds to probability 0.5.\n", + " Returns:\n", + " - List of (token, tags, offset) for each token.\n", + " \"\"\"\n", + " # Tokenize the sentence to retrieve the tokens and offset mappings\n", + " raw_encoded_example = tokenizer(example, return_offsets_mapping=True)\n", + " encoded_example = tokenizer(example, return_tensors=\"pt\")\n", + " \n", + " # Call the model. 
The output LxK-tensor where L is the number of tokens, K is the number of classes\n", + " out = model(**encoded_example)[\"logits\"][0]\n", + " \n", + " # We assign to each token the classes whose logit is positive\n", + " predicted_tags = [[i for i, l in enumerate(logit) if l > threshold] for logit in out]\n", + " \n", + " return [{\"token\": token, \"tags\": tag, \"offset\": offset} for (token, tag, offset) \n", + " in zip(tokenizer.batch_decode(raw_encoded_example[\"input_ids\"]), \n", + " predicted_tags, \n", + " raw_encoded_example[\"offset_mapping\"])]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "314d7a28-ef8b-4367-8e6d-79a55951fcec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<s> - []\n", + "afin - []\n", + "de - []\n", + "vérifier - []\n", + "le - []\n", + "kilométrage - []\n", + " - []\n", + ", - []\n", + "pour - []\n", + "les - [3]\n", + "véhicules - [3]\n", + "équipés - [4]\n", + "d - [4]\n", + " - [4]\n", + "' - [4]\n", + "un - [4]\n", + "compteur - [4]\n", + "kilo - [4]\n", + "métrique - [4]\n", + " - []\n", + ", - []\n", + "les - [3]\n", + "informations - [3]\n", + "communiquées - [4]\n", + "lors - [4, 8]\n", + "du - [8]\n", + "précédent - [8]\n", + "contrôle - [8]\n", + "technique - [8]\n", + "sont - [1]\n", + "mises - [1]\n", + "à - [1]\n", + "la - [1]\n", + "disposition - [1]\n", + "des - [1]\n", + "organismes - [1, 2]\n", + "de - [1, 2]\n", + "contrôle - [1, 2]\n", + "technique - [1, 2]\n", + "dès - [4]\n", + "qu - [4]\n", + " - [4]\n", + "' - [4]\n", + "elles - [4]\n", + "sont - [4]\n", + "disponibles - [4]\n", + "par - [4]\n", + "voie - [4]\n", + "électronique - [4]\n", + " - []\n", + ". 
- []\n", + "</s> - []\n" + ] + } + ], + "source": [ + "example = \"afin de vérifier le kilométrage , pour les véhicules équipés d ' un compteur kilométrique , les informations communiquées lors du précédent contrôle technique sont mises à la disposition des organismes de contrôle technique dès qu ' elles sont disponibles par voie électronique .\"\n", + "for item in get_offsets_and_predicted_tags(example, model, tokenizer):\n", + " print(f\"\"\"{item[\"token\"]:15} - {item[\"tags\"]}\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b0164373-ee70-4d63-a0c2-ac2197b37b22", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "afin de vérifier le kilométrage , pour les véhicules équipés d ' un compteur kilométrique , les informations communiquées lors du précédent contrôle technique sont mises à la disposition des organismes de contrôle technique dès qu ' elles sont disponibles par voie électronique .\n" + ] + }, + { + "data": { + "text/plain": [ + "[{'start': 39, 'end': 52, 'tag': 'artifact', 'text': 'les véhicules'},\n", + " {'start': 53,\n", + " 'end': 89,\n", + " 'tag': 'condition',\n", + " 'text': \"équipés d ' un compteur kilométrique\"},\n", + " {'start': 92, 'end': 108, 'tag': 'artifact', 'text': 'les informations'},\n", + " {'start': 109, 'end': 126, 'tag': 'condition', 'text': 'communiquées lors'},\n", + " {'start': 122,\n", + " 'end': 158,\n", + " 'tag': 'time',\n", + " 'text': 'lors du précédent contrôle technique'},\n", + " {'start': 159,\n", + " 'end': 223,\n", + " 'tag': 'action',\n", + " 'text': 'sont mises à la disposition des organismes de contrôle technique'},\n", + " {'start': 191,\n", + " 'end': 223,\n", + " 'tag': 'actor',\n", + " 'text': 'organismes de contrôle technique'},\n", + " {'start': 224,\n", + " 'end': 277,\n", + " 'tag': 'condition',\n", + " 'text': \"dès qu ' elles sont disponibles par voie électronique\"}]" + ] + }, + "execution_count": 19, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "def get_tagged_groups(example: str, model, tokenizer):\n", + " \"\"\"\n", + " Get prediction of model on example, using tokenizer\n", + " Returns:\n", + " - List of spans under offset format {\"start\": ..., \"end\": ..., \"tag\": ...}, sorted by start, end then tag.\n", + " \"\"\"\n", + " offsets_and_tags = get_offsets_and_predicted_tags(example, model, tokenizer)\n", + " predicted_offsets = {l: [] for l in tag2id}\n", + " last_token_tags = []\n", + " for item in offsets_and_tags:\n", + " (start, end), tags = item[\"offset\"], item[\"tags\"]\n", + " \n", + " for label_id in tags:\n", + " tag = id2label[label_id]\n", + " if label_id not in last_token_tags and label2id[f\"{tag}\"] not in last_token_tags:\n", + " predicted_offsets[tag].append({\"start\": start, \"end\": end})\n", + " else:\n", + " predicted_offsets[tag][-1][\"end\"] = end\n", + " \n", + " last_token_tags = tags\n", + " \n", + " flatten_predicted_offsets = [{**v, \"tag\": k, \"text\": example[v[\"start\"]:v[\"end\"]]} \n", + " for k, v_list in predicted_offsets.items() for v in v_list if v[\"end\"] - v[\"start\"] >= 3]\n", + " flatten_predicted_offsets = sorted(flatten_predicted_offsets, \n", + " key = lambda row: (row[\"start\"], row[\"end\"], row[\"tag\"]))\n", + " return flatten_predicted_offsets\n", + "\n", + "print(example)\n", + "get_tagged_groups(example, model, tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54213be5-8221-40a2-ac60-04670d83babc", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + 
"nbformat_minor": 5 +} diff --git a/modules/camembert/camembert-classifier.ipynb b/modules/camembert/camembert-classifier.ipynb deleted file mode 100644 index 83785f3536a8ce5bf9e8fb74a26f50fd2b3ddc39..0000000000000000000000000000000000000000 --- a/modules/camembert/camembert-classifier.ipynb +++ /dev/null @@ -1,662 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "8ced2e3ca31fb46c", - "metadata": {}, - "source": [ - "# Dataset" - ] - }, - { - "cell_type": "code", - "id": "757a8bf026156e77", - "metadata": { - "ExecuteTime": { - "end_time": "2024-06-27T14:45:25.410725Z", - "start_time": "2024-06-27T14:45:25.404357Z" - } - }, - "source": [ - "tag2id = {'action': 1, 'actor': 2, 'artifact': 3, 'condition': 4, 'location': 5, 'modality': 6, 'reference': 7, 'time': 8}\n", - "id2tag = {v:k for k, v in tag2id.items()}" - ], - "outputs": [], - "execution_count": 1 - }, - { - "cell_type": "code", - "id": "be3a4c320f9d4a5", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [], - "ExecuteTime": { - "end_time": "2024-06-27T14:45:35.889919Z", - "start_time": "2024-06-27T14:45:35.885841Z" - } - }, - "source": [ - "label2id = {\n", - " 'O': 0,\n", - " **{f'B-{k}': 2*v - 1 for k, v in tag2id.items()},\n", - " **{f'I-{k}': 2*v for k, v in tag2id.items()}\n", - "}\n", - "\n", - "id2label = {v:k for k, v in label2id.items()}" - ], - "outputs": [], - "execution_count": 2 - }, - { - "cell_type": "code", - "id": "2aa2fefac95e7f04", - "metadata": { - "ExecuteTime": { - "end_time": "2024-06-27T14:45:46.507242Z", - "start_time": "2024-06-27T14:45:38.632569Z" - } - }, - "source": [ - "from datasets import Dataset\n", - "train_ds = Dataset.from_json(\"../../data/annotations.train.jsonlines\")\n", - "val_ds = Dataset.from_json(\"../../data/annotations.eval.jsonlines\")" - ], - "outputs": [ - { - "data": { - "text/plain": [ - "Generating train split: 0 examples [00:00, ? 
examples/s]" - ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "c0f11a82572440d9bd1405c2e6ea6d2a" - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Generating train split: 0 examples [00:00, ? examples/s]" - ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "08366b1257d34fcd97ace0017fd3b395" - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "execution_count": 3 - }, - { - "cell_type": "code", - "id": "9e0a21356e7701a1", - "metadata": { - "ExecuteTime": { - "end_time": "2024-06-27T14:45:52.188551Z", - "start_time": "2024-06-27T14:45:52.185648Z" - } - }, - "source": [ - "modelId = '../../models/CamemBERT-base'" - ], - "outputs": [], - "execution_count": 4 - }, - { - "cell_type": "markdown", - "id": "66e00d5a79a66753", - "metadata": {}, - "source": [ - "# Tokenization" - ] - }, - { - "cell_type": "code", - "id": "e6459259f5ab2d98", - "metadata": { - "ExecuteTime": { - "end_time": "2024-06-27T14:46:02.189231Z", - "start_time": "2024-06-27T14:45:56.737350Z" - } - }, - "source": [ - "from transformers import AutoTokenizer\n", - "tokenizer = AutoTokenizer.from_pretrained(modelId)" - ], - "outputs": [ - { - "ename": "OSError", - "evalue": "../../models/CamemBERT-base does not appear to have a file named config.json. 
Checkout 'https://huggingface.co/../../models/CamemBERT-base/None' for available files.", - "output_type": "error", - "traceback": [ - "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[1;31mOSError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[1;32mIn[5], line 2\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtransformers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m AutoTokenizer\n\u001B[1;32m----> 2\u001B[0m tokenizer \u001B[38;5;241m=\u001B[39m \u001B[43mAutoTokenizer\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_pretrained\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmodelId\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\transformers\\models\\auto\\tokenization_auto.py:773\u001B[0m, in \u001B[0;36mAutoTokenizer.from_pretrained\u001B[1;34m(cls, pretrained_model_name_or_path, *inputs, **kwargs)\u001B[0m\n\u001B[0;32m 771\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m config_tokenizer_class \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 772\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(config, PretrainedConfig):\n\u001B[1;32m--> 773\u001B[0m config \u001B[38;5;241m=\u001B[39m \u001B[43mAutoConfig\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfrom_pretrained\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 774\u001B[0m \u001B[43m \u001B[49m\u001B[43mpretrained_model_name_or_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrust_remote_code\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtrust_remote_code\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\n\u001B[0;32m 775\u001B[0m \u001B[43m 
\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 776\u001B[0m config_tokenizer_class \u001B[38;5;241m=\u001B[39m config\u001B[38;5;241m.\u001B[39mtokenizer_class\n\u001B[0;32m 777\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mhasattr\u001B[39m(config, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mauto_map\u001B[39m\u001B[38;5;124m\"\u001B[39m) \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mAutoTokenizer\u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;129;01min\u001B[39;00m config\u001B[38;5;241m.\u001B[39mauto_map:\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\transformers\\models\\auto\\configuration_auto.py:1100\u001B[0m, in \u001B[0;36mAutoConfig.from_pretrained\u001B[1;34m(cls, pretrained_model_name_or_path, **kwargs)\u001B[0m\n\u001B[0;32m 1097\u001B[0m trust_remote_code \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtrust_remote_code\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;28;01mNone\u001B[39;00m)\n\u001B[0;32m 1098\u001B[0m code_revision \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcode_revision\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;28;01mNone\u001B[39;00m)\n\u001B[1;32m-> 1100\u001B[0m config_dict, unused_kwargs \u001B[38;5;241m=\u001B[39m \u001B[43mPretrainedConfig\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mget_config_dict\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpretrained_model_name_or_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 1101\u001B[0m has_remote_code \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mauto_map\u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;129;01min\u001B[39;00m config_dict 
\u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mAutoConfig\u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;129;01min\u001B[39;00m config_dict[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mauto_map\u001B[39m\u001B[38;5;124m\"\u001B[39m]\n\u001B[0;32m 1102\u001B[0m has_local_code \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mmodel_type\u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;129;01min\u001B[39;00m config_dict \u001B[38;5;129;01mand\u001B[39;00m config_dict[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mmodel_type\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;129;01min\u001B[39;00m CONFIG_MAPPING\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\transformers\\configuration_utils.py:634\u001B[0m, in \u001B[0;36mPretrainedConfig.get_config_dict\u001B[1;34m(cls, pretrained_model_name_or_path, **kwargs)\u001B[0m\n\u001B[0;32m 632\u001B[0m original_kwargs \u001B[38;5;241m=\u001B[39m copy\u001B[38;5;241m.\u001B[39mdeepcopy(kwargs)\n\u001B[0;32m 633\u001B[0m \u001B[38;5;66;03m# Get config dict associated with the base config file\u001B[39;00m\n\u001B[1;32m--> 634\u001B[0m config_dict, kwargs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mcls\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_get_config_dict\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpretrained_model_name_or_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 635\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m_commit_hash\u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;129;01min\u001B[39;00m config_dict:\n\u001B[0;32m 636\u001B[0m original_kwargs[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m_commit_hash\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;241m=\u001B[39m 
config_dict[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m_commit_hash\u001B[39m\u001B[38;5;124m\"\u001B[39m]\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\transformers\\configuration_utils.py:689\u001B[0m, in \u001B[0;36mPretrainedConfig._get_config_dict\u001B[1;34m(cls, pretrained_model_name_or_path, **kwargs)\u001B[0m\n\u001B[0;32m 685\u001B[0m configuration_file \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m_configuration_file\u001B[39m\u001B[38;5;124m\"\u001B[39m, CONFIG_NAME)\n\u001B[0;32m 687\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m 688\u001B[0m \u001B[38;5;66;03m# Load from local folder or from cache or download from model Hub and cache\u001B[39;00m\n\u001B[1;32m--> 689\u001B[0m resolved_config_file \u001B[38;5;241m=\u001B[39m \u001B[43mcached_file\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 690\u001B[0m \u001B[43m \u001B[49m\u001B[43mpretrained_model_name_or_path\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 691\u001B[0m \u001B[43m \u001B[49m\u001B[43mconfiguration_file\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 692\u001B[0m \u001B[43m \u001B[49m\u001B[43mcache_dir\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcache_dir\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 693\u001B[0m \u001B[43m \u001B[49m\u001B[43mforce_download\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mforce_download\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 694\u001B[0m \u001B[43m \u001B[49m\u001B[43mproxies\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mproxies\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 695\u001B[0m \u001B[43m \u001B[49m\u001B[43mresume_download\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mresume_download\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 696\u001B[0m \u001B[43m 
\u001B[49m\u001B[43mlocal_files_only\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mlocal_files_only\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 697\u001B[0m \u001B[43m \u001B[49m\u001B[43mtoken\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtoken\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 698\u001B[0m \u001B[43m \u001B[49m\u001B[43muser_agent\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43muser_agent\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 699\u001B[0m \u001B[43m \u001B[49m\u001B[43mrevision\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mrevision\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 700\u001B[0m \u001B[43m \u001B[49m\u001B[43msubfolder\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msubfolder\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 701\u001B[0m \u001B[43m \u001B[49m\u001B[43m_commit_hash\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcommit_hash\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 702\u001B[0m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 703\u001B[0m commit_hash \u001B[38;5;241m=\u001B[39m extract_commit_hash(resolved_config_file, commit_hash)\n\u001B[0;32m 704\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mEnvironmentError\u001B[39;00m:\n\u001B[0;32m 705\u001B[0m \u001B[38;5;66;03m# Raise any environment error raise by `cached_file`. 
It will have a helpful error message adapted to\u001B[39;00m\n\u001B[0;32m 706\u001B[0m \u001B[38;5;66;03m# the original exception.\u001B[39;00m\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\transformers\\utils\\hub.py:356\u001B[0m, in \u001B[0;36mcached_file\u001B[1;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)\u001B[0m\n\u001B[0;32m 354\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m os\u001B[38;5;241m.\u001B[39mpath\u001B[38;5;241m.\u001B[39misfile(resolved_file):\n\u001B[0;32m 355\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m _raise_exceptions_for_missing_entries:\n\u001B[1;32m--> 356\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mEnvironmentError\u001B[39;00m(\n\u001B[0;32m 357\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpath_or_repo_id\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m does not appear to have a file named \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mfull_filename\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m. 
Checkout \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 358\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mhttps://huggingface.co/\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mpath_or_repo_id\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m/\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mrevision\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m for available files.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 359\u001B[0m )\n\u001B[0;32m 360\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m 361\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m\n", - "\u001B[1;31mOSError\u001B[0m: ../../models/CamemBERT-base does not appear to have a file named config.json. Checkout 'https://huggingface.co/../../models/CamemBERT-base/None' for available files." - ] - } - ], - "execution_count": 5 - }, - { - "cell_type": "code", - "id": "8c96680645f077fb", - "metadata": { - "ExecuteTime": { - "end_time": "2024-06-27T14:46:08.307140Z", - "start_time": "2024-06-27T14:46:08.301542Z" - } - }, - "source": [ - "def get_token_role_in_span(token_start: int, token_end: int, span_start: int, span_end: int):\n", - " \"\"\"\n", - " Check if the token is inside a span.\n", - " Args:\n", - " - token_start, token_end: Start and end offset of the token\n", - " - span_start, span_end: Start and end of the span\n", - " Returns:\n", - " - \"B\" if beginning\n", - " - \"I\" if inner\n", - " - \"O\" if outer\n", - " - \"N\" if not valid token (like <SEP>, <CLS>, <UNK>)\n", - " \"\"\"\n", - " if token_end <= token_start:\n", - " return \"N\"\n", - " if token_start < span_start or token_end > span_end:\n", - " return \"O\"\n", - " if token_start > span_start:\n", - " return \"I\"\n", - " else:\n", - " return \"B\"\n", - "\n", - "MAX_LENGTH = 256\n", - "\n", - "def tokenize_and_adjust_labels(sample):\n", - " \"\"\"\n", - " Args:\n", - " - sample (dict): {\"id\": \"...\", 
\"text\": \"...\", \"tags\": [{\"start\": ..., \"end\": ..., \"tag\": ...}, ...]\n", - " Returns:\n", - " - The tokenized version of `sample` and the labels of each token.\n", - " \"\"\"\n", - " # Tokenize the text, keep the start and end positions of tokens with `return_offsets_mapping` option\n", - " # Use max_length and truncation to ajust the text length\n", - " tokenized = tokenizer(sample[\"text\"],\n", - " return_offsets_mapping=True,\n", - " padding=\"max_length\",\n", - " max_length=MAX_LENGTH,\n", - " truncation=True)\n", - "\n", - " # We are doing a multilabel classification task at each token, we create a list of size len(label2id)=13 \n", - " # for the 13 labels\n", - " labels = [[0 for _ in label2id.keys()] for _ in range(MAX_LENGTH)]\n", - "\n", - " # Scan all the tokens and spans, assign 1 to the corresponding label if the token lies at the beginning\n", - " # or inside the spans\n", - " for (token_start, token_end), token_labels in zip(tokenized[\"offset_mapping\"], labels):\n", - " #print(token_start, token_end)\n", - " for span in sample[\"tags\"]:\n", - " role = get_token_role_in_span(token_start, token_end, span[\"start\"], span[\"end\"])\n", - " if role == \"B\":\n", - " token_labels[label2id[f\"B-{span['tag']}\"]] = 1\n", - " elif role == \"I\":\n", - " token_labels[label2id[f\"I-{span['tag']}\"]] = 1\n", - "\n", - " return {**tokenized, \"labels\": labels}" - ], - "outputs": [], - "execution_count": 6 - }, - { - "cell_type": "code", - "id": "53310845f13e9d70", - "metadata": { - "ExecuteTime": { - "end_time": "2024-06-27T14:46:13.910535Z", - "start_time": "2024-06-27T14:46:13.689422Z" - } - }, - "source": [ - "tokenized_train_ds = train_ds.map(tokenize_and_adjust_labels, remove_columns=train_ds.column_names)\n", - "tokenized_val_ds = val_ds.map(tokenize_and_adjust_labels, remove_columns=val_ds.column_names)" - ], - "outputs": [ - { - "data": { - "text/plain": [ - "Map: 0%| | 0/206 [00:00<?, ? 
examples/s]" - ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "fc1766f4288040c0abe62a22f76de5b3" - } - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "ename": "NameError", - "evalue": "name 'tokenizer' is not defined", - "output_type": "error", - "traceback": [ - "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[1;31mNameError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[1;32mIn[7], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m tokenized_train_ds \u001B[38;5;241m=\u001B[39m \u001B[43mtrain_ds\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mmap\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtokenize_and_adjust_labels\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mremove_columns\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtrain_ds\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcolumn_names\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 2\u001B[0m tokenized_val_ds \u001B[38;5;241m=\u001B[39m val_ds\u001B[38;5;241m.\u001B[39mmap(tokenize_and_adjust_labels, remove_columns\u001B[38;5;241m=\u001B[39mval_ds\u001B[38;5;241m.\u001B[39mcolumn_names)\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\datasets\\arrow_dataset.py:592\u001B[0m, in \u001B[0;36mtransmit_tasks.<locals>.wrapper\u001B[1;34m(*args, **kwargs)\u001B[0m\n\u001B[0;32m 590\u001B[0m \u001B[38;5;28mself\u001B[39m: \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDataset\u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mself\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 591\u001B[0m \u001B[38;5;66;03m# apply actual function\u001B[39;00m\n\u001B[1;32m--> 592\u001B[0m out: Union[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDataset\u001B[39m\u001B[38;5;124m\"\u001B[39m, 
\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDatasetDict\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 593\u001B[0m datasets: List[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDataset\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mlist\u001B[39m(out\u001B[38;5;241m.\u001B[39mvalues()) \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(out, \u001B[38;5;28mdict\u001B[39m) \u001B[38;5;28;01melse\u001B[39;00m [out]\n\u001B[0;32m 594\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m dataset \u001B[38;5;129;01min\u001B[39;00m datasets:\n\u001B[0;32m 595\u001B[0m \u001B[38;5;66;03m# Remove task templates if a column mapping of the template is no longer valid\u001B[39;00m\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\datasets\\arrow_dataset.py:557\u001B[0m, in \u001B[0;36mtransmit_format.<locals>.wrapper\u001B[1;34m(*args, **kwargs)\u001B[0m\n\u001B[0;32m 550\u001B[0m self_format \u001B[38;5;241m=\u001B[39m {\n\u001B[0;32m 551\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtype\u001B[39m\u001B[38;5;124m\"\u001B[39m: \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_format_type,\n\u001B[0;32m 552\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mformat_kwargs\u001B[39m\u001B[38;5;124m\"\u001B[39m: \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_format_kwargs,\n\u001B[0;32m 553\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcolumns\u001B[39m\u001B[38;5;124m\"\u001B[39m: \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_format_columns,\n\u001B[0;32m 
554\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124moutput_all_columns\u001B[39m\u001B[38;5;124m\"\u001B[39m: \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_output_all_columns,\n\u001B[0;32m 555\u001B[0m }\n\u001B[0;32m 556\u001B[0m \u001B[38;5;66;03m# apply actual function\u001B[39;00m\n\u001B[1;32m--> 557\u001B[0m out: Union[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDataset\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDatasetDict\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 558\u001B[0m datasets: List[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mDataset\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mlist\u001B[39m(out\u001B[38;5;241m.\u001B[39mvalues()) \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(out, \u001B[38;5;28mdict\u001B[39m) \u001B[38;5;28;01melse\u001B[39;00m [out]\n\u001B[0;32m 559\u001B[0m \u001B[38;5;66;03m# re-apply format to the output\u001B[39;00m\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\datasets\\arrow_dataset.py:3093\u001B[0m, in \u001B[0;36mDataset.map\u001B[1;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)\u001B[0m\n\u001B[0;32m 3087\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m transformed_dataset \u001B[38;5;129;01mis\u001B[39;00m 
\u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 3088\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m hf_tqdm(\n\u001B[0;32m 3089\u001B[0m unit\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m examples\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 3090\u001B[0m total\u001B[38;5;241m=\u001B[39mpbar_total,\n\u001B[0;32m 3091\u001B[0m desc\u001B[38;5;241m=\u001B[39mdesc \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mMap\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 3092\u001B[0m ) \u001B[38;5;28;01mas\u001B[39;00m pbar:\n\u001B[1;32m-> 3093\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43;01mfor\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[43mrank\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdone\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcontent\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;129;43;01min\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[43mDataset\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_map_single\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mdataset_kwargs\u001B[49m\u001B[43m)\u001B[49m\u001B[43m:\u001B[49m\n\u001B[0;32m 3094\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43;01mif\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[43mdone\u001B[49m\u001B[43m:\u001B[49m\n\u001B[0;32m 3095\u001B[0m \u001B[43m \u001B[49m\u001B[43mshards_done\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m+\u001B[39;49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;241;43m1\u001B[39;49m\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\datasets\\arrow_dataset.py:3446\u001B[0m, in \u001B[0;36mDataset._map_single\u001B[1;34m(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, 
new_fingerprint, rank, offset)\u001B[0m\n\u001B[0;32m 3444\u001B[0m _time \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mtime()\n\u001B[0;32m 3445\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i, example \u001B[38;5;129;01min\u001B[39;00m shard_iterable:\n\u001B[1;32m-> 3446\u001B[0m example \u001B[38;5;241m=\u001B[39m \u001B[43mapply_function_on_filtered_inputs\u001B[49m\u001B[43m(\u001B[49m\u001B[43mexample\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mi\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43moffset\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43moffset\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 3447\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m update_data:\n\u001B[0;32m 3448\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m i \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n", - "File \u001B[1;32m~\\Desktop\\Sync\\Berger-Levrault\\Codes\\Ala\\Legal Concepts Extraction\\venv\\Lib\\site-packages\\datasets\\arrow_dataset.py:3349\u001B[0m, in \u001B[0;36mDataset._map_single.<locals>.apply_function_on_filtered_inputs\u001B[1;34m(pa_inputs, indices, check_same_num_examples, offset)\u001B[0m\n\u001B[0;32m 3347\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m with_rank:\n\u001B[0;32m 3348\u001B[0m additional_args \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m (rank,)\n\u001B[1;32m-> 3349\u001B[0m processed_inputs \u001B[38;5;241m=\u001B[39m \u001B[43mfunction\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mfn_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43madditional_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mfn_kwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 3350\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(processed_inputs, LazyDict):\n\u001B[0;32m 3351\u001B[0m processed_inputs \u001B[38;5;241m=\u001B[39m 
{\n\u001B[0;32m 3352\u001B[0m k: v \u001B[38;5;28;01mfor\u001B[39;00m k, v \u001B[38;5;129;01min\u001B[39;00m processed_inputs\u001B[38;5;241m.\u001B[39mdata\u001B[38;5;241m.\u001B[39mitems() \u001B[38;5;28;01mif\u001B[39;00m k \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;129;01min\u001B[39;00m processed_inputs\u001B[38;5;241m.\u001B[39mkeys_to_format\n\u001B[0;32m 3353\u001B[0m }\n", - "Cell \u001B[1;32mIn[6], line 33\u001B[0m, in \u001B[0;36mtokenize_and_adjust_labels\u001B[1;34m(sample)\u001B[0m\n\u001B[0;32m 25\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 26\u001B[0m \u001B[38;5;124;03mArgs:\u001B[39;00m\n\u001B[0;32m 27\u001B[0m \u001B[38;5;124;03m - sample (dict): {\"id\": \"...\", \"text\": \"...\", \"tags\": [{\"start\": ..., \"end\": ..., \"tag\": ...}, ...]\u001B[39;00m\n\u001B[0;32m 28\u001B[0m \u001B[38;5;124;03mReturns:\u001B[39;00m\n\u001B[0;32m 29\u001B[0m \u001B[38;5;124;03m - The tokenized version of `sample` and the labels of each token.\u001B[39;00m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 31\u001B[0m \u001B[38;5;66;03m# Tokenize the text, keep the start and end positions of tokens with `return_offsets_mapping` option\u001B[39;00m\n\u001B[0;32m 32\u001B[0m \u001B[38;5;66;03m# Use max_length and truncation to ajust the text length\u001B[39;00m\n\u001B[1;32m---> 33\u001B[0m tokenized \u001B[38;5;241m=\u001B[39m \u001B[43mtokenizer\u001B[49m(sample[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtext\u001B[39m\u001B[38;5;124m\"\u001B[39m],\n\u001B[0;32m 34\u001B[0m return_offsets_mapping\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m,\n\u001B[0;32m 35\u001B[0m padding\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mmax_length\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 36\u001B[0m max_length\u001B[38;5;241m=\u001B[39mMAX_LENGTH,\n\u001B[0;32m 37\u001B[0m 
truncation\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m)\n\u001B[0;32m 39\u001B[0m \u001B[38;5;66;03m# We are doing a multilabel classification task at each token, we create a list of size len(label2id)=13 \u001B[39;00m\n\u001B[0;32m 40\u001B[0m \u001B[38;5;66;03m# for the 13 labels\u001B[39;00m\n\u001B[0;32m 41\u001B[0m labels \u001B[38;5;241m=\u001B[39m [[\u001B[38;5;241m0\u001B[39m \u001B[38;5;28;01mfor\u001B[39;00m _ \u001B[38;5;129;01min\u001B[39;00m label2id\u001B[38;5;241m.\u001B[39mkeys()] \u001B[38;5;28;01mfor\u001B[39;00m _ \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(MAX_LENGTH)]\n", - "\u001B[1;31mNameError\u001B[0m: name 'tokenizer' is not defined" - ] - } - ], - "execution_count": 7 - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a654180a-a536-43fa-8984-c7b87b419f93", - "metadata": {}, - "source": [ - "sample = tokenized_train_ds[0]\n", - "print(\"--------Token---------|--------Offset----------|--------Labels----------\")\n", - "for token_id, token_labels, offset in zip(sample[\"input_ids\"], sample[\"labels\"], sample[\"offset_mapping\"]):\n", - " # Decode the token_id into text\n", - " token_text = tokenizer.decode(token_id)\n", - " \n", - " # Retrieve all the indices corresponding to the \"1\" at each token, decode them to label name\n", - " labels = [id2label[label_index] for label_index, value in enumerate(token_labels) if value==1]\n", - " \n", - " # Decode those indices into label name\n", - " print(f\" {token_text:20} | {offset} | {labels}\")\n", - " \n", - " # Finish when we meet the end of sentence.\n", - " if token_text == \"</s>\": \n", - " break" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "6990d89800dbb440", - "metadata": {}, - "source": [ - "from transformers import DataCollatorWithPadding\n", - "data_collator = DataCollatorWithPadding(tokenizer, padding=True)" - ], - "outputs": [] - }, - { - "cell_type": "markdown", - "id": 
"668dcf9750404d1c", - "metadata": {}, - "source": [ - "# Adapt the model" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "7bd0cddab7ddb448", - "metadata": {}, - "source": [ - "import numpy as np\n", - "from sklearn.metrics import multilabel_confusion_matrix\n", - "\n", - "n_labels = len(id2label)\n", - "\n", - "def divide(a: int, b: int):\n", - " return a / b if b > 0 else 0\n", - "\n", - "def compute_metrics(p):\n", - " \"\"\"\n", - " Customize the `compute_metrics` of `transformers`\n", - " Args:\n", - " - p (tuple): 2 numpy arrays: predictions and true_labels\n", - " Returns:\n", - " - metrics (dict): f1 score on \n", - " \"\"\"\n", - " # (1)\n", - " predictions, true_labels = p\n", - "\n", - " # (2)\n", - " predicted_labels = np.where(predictions > 0, np.ones(predictions.shape), np.zeros(predictions.shape))\n", - " metrics = {}\n", - "\n", - " # (3)\n", - " cm = multilabel_confusion_matrix(true_labels.reshape(-1, n_labels), predicted_labels.reshape(-1, n_labels))\n", - "\n", - " # (4) \n", - " for label_idx, matrix in enumerate(cm):\n", - " if label_idx == 0:\n", - " continue # We don't care about the label \"O\"\n", - " tp, fp, fn = matrix[1, 1], matrix[0, 1], matrix[1, 0]\n", - " precision = divide(tp, tp + fp)\n", - " recall = divide(tp, tp + fn)\n", - " f1 = divide(2 * precision * recall, precision + recall)\n", - " metrics[f\"f1_{id2label[label_idx]}\"] = f1\n", - "\n", - " # (5)\n", - " macro_f1 = sum(list(metrics.values())) / (n_labels - 1)\n", - " metrics[\"macro_f1\"] = macro_f1\n", - "\n", - " return metrics" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "ea5d16f59728e2b9", - "metadata": {}, - "source": [ - "from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer\n", - "from transformers import RobertaPreTrainedModel, RobertaModel\n", - "from transformers.utils import (\n", - " add_code_sample_docstrings,\n", - " add_start_docstrings,\n", - " 
add_start_docstrings_to_model_forward,\n", - " logging,\n", - " replace_return_docstrings,\n", - ")\n", - "from transformers.models.roberta.modeling_roberta import (\n", - " ROBERTA_INPUTS_DOCSTRING,\n", - " ROBERTA_START_DOCSTRING,\n", - " RobertaEmbeddings,\n", - ")\n", - "from typing import Optional, Union, Tuple\n", - "from transformers.modeling_outputs import TokenClassifierOutput\n", - "import torch\n", - "from torch import nn\n", - "\n", - "class RobertaForSpanCategorization(RobertaPreTrainedModel):\n", - " _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n", - " _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n", - "\n", - " def __init__(self, config):\n", - " super().__init__(config)\n", - " self.num_labels = config.num_labels\n", - " self.roberta = RobertaModel(config, add_pooling_layer=False)\n", - " classifier_dropout = (\n", - " config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob\n", - " )\n", - " self.dropout = nn.Dropout(classifier_dropout)\n", - " self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n", - " # Initialize weights and apply final processing\n", - " self.post_init()\n", - "\n", - " @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n", - " def forward(\n", - " self,\n", - " input_ids: Optional[torch.LongTensor] = None,\n", - " attention_mask: Optional[torch.FloatTensor] = None,\n", - " token_type_ids: Optional[torch.LongTensor] = None,\n", - " position_ids: Optional[torch.LongTensor] = None,\n", - " head_mask: Optional[torch.FloatTensor] = None,\n", - " inputs_embeds: Optional[torch.FloatTensor] = None,\n", - " labels: Optional[torch.LongTensor] = None,\n", - " output_attentions: Optional[bool] = None,\n", - " output_hidden_states: Optional[bool] = None,\n", - " return_dict: Optional[bool] = None,\n", - " ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:\n", - " r\"\"\"\n", - " labels (`torch.LongTensor` of 
shape `(batch_size, sequence_length)`, *optional*):\n", - " Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.\n", - " \"\"\"\n", - " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", - " outputs = self.roberta(\n", - " input_ids,\n", - " attention_mask=attention_mask,\n", - " token_type_ids=token_type_ids,\n", - " position_ids=position_ids,\n", - " head_mask=head_mask,\n", - " inputs_embeds=inputs_embeds,\n", - " output_attentions=output_attentions,\n", - " output_hidden_states=output_hidden_states,\n", - " return_dict=return_dict,\n", - " )\n", - " sequence_output = outputs[0]\n", - " sequence_output = self.dropout(sequence_output)\n", - " logits = self.classifier(sequence_output)\n", - "\n", - " loss = None\n", - " if labels is not None:\n", - " loss_fct = nn.BCEWithLogitsLoss()\n", - " loss = loss_fct(logits, labels.float())\n", - " if not return_dict:\n", - " output = (logits,) + outputs[2:]\n", - " return ((loss,) + output) if loss is not None else output\n", - " return TokenClassifierOutput(\n", - " loss=loss,\n", - " logits=logits,\n", - " hidden_states=outputs.hidden_states,\n", - " attentions=outputs.attentions,\n", - " )" - ], - "outputs": [] - }, - { - "cell_type": "markdown", - "id": "77f4fc68394aa754", - "metadata": {}, - "source": [ - "# Fine-tuning" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "79161ed938cad895", - "metadata": {}, - "source": [ - "training_args = TrainingArguments(\n", - " output_dir=\"./models/fine_tune_bert_output_span_cat\",\n", - " evaluation_strategy=\"epoch\",\n", - " learning_rate=2.5e-4,\n", - " per_device_train_batch_size=16,\n", - " per_device_eval_batch_size=16,\n", - " num_train_epochs=100,\n", - " weight_decay=0.01,\n", - " logging_steps = 100,\n", - " save_strategy='epoch',\n", - " save_total_limit=2,\n", - " load_best_model_at_end=True,\n", - " metric_for_best_model='macro_f1',\n", - " 
log_level='critical',\n", - " seed=12345\n", - ")" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "931792b554582a9f", - "metadata": {}, - "source": [ - "def model_init():\n", - " # For reproducibility\n", - " return RobertaForSpanCategorization.from_pretrained(modelId, id2label=id2label, label2id=label2id)\n", - "\n", - "trainer = Trainer(\n", - " model_init=model_init,\n", - " args=training_args,\n", - " train_dataset=tokenized_train_ds,\n", - " eval_dataset=tokenized_val_ds,\n", - " data_collator=data_collator,\n", - " tokenizer=tokenizer,\n", - " compute_metrics=compute_metrics\n", - ")\n", - "trainer.train()" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "89c997e6a944bc70", - "metadata": {}, - "source": [ - "trainer.model.save_pretrained(\"../../models/Fine-tuned_CamemBERT-base\")" - ], - "outputs": [] - }, - { - "cell_type": "markdown", - "id": "05f2cf51-2a76-4de6-9b22-af7dda1eb805", - "metadata": {}, - "source": [ - "# Inference" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "e48a67b7-814d-46d5-a079-fafc5b6adf86", - "metadata": {}, - "source": [ - "model = RobertaForSpanCategorization.from_pretrained(\"../../models/Fine-tuned_CamemBERT-base\")\n", - "tokenizer = AutoTokenizer.from_pretrained(modelId)" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "4cb75af5-234f-4863-94f8-45a364e1b877", - "metadata": {}, - "source": [ - "def get_offsets_and_predicted_tags(example: str, model, tokenizer, threshold=0):\n", - " \"\"\"\n", - " Get prediction of model on example, using tokenizer\n", - " Args:\n", - " - example (str): The input text\n", - " - model: The span categorizer\n", - " - tokenizer: The tokenizer\n", - " - threshold: The threshold to decide whether the token should belong to the label. 
Default to 0, which corresponds to probability 0.5.\n", - " Returns:\n", - " - List of (token, tags, offset) for each token.\n", - " \"\"\"\n", - " # Tokenize the sentence to retrieve the tokens and offset mappings\n", - " raw_encoded_example = tokenizer(example, return_offsets_mapping=True)\n", - " encoded_example = tokenizer(example, return_tensors=\"pt\")\n", - " \n", - " # Call the model. The output LxK-tensor where L is the number of tokens, K is the number of classes\n", - " out = model(**encoded_example)[\"logits\"][0]\n", - " \n", - " # We assign to each token the classes whose logit is positive\n", - " predicted_tags = [[i for i, l in enumerate(logit) if l > threshold] for logit in out]\n", - " \n", - " return [{\"token\": token, \"tags\": tag, \"offset\": offset} for (token, tag, offset) \n", - " in zip(tokenizer.batch_decode(raw_encoded_example[\"input_ids\"]), \n", - " predicted_tags, \n", - " raw_encoded_example[\"offset_mapping\"])]" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "314d7a28-ef8b-4367-8e6d-79a55951fcec", - "metadata": {}, - "source": [ - "example = \"afin de vérifier le kilométrage , pour les véhicules équipés d ' un compteur kilométrique , les informations communiquées lors du précédent contrôle technique sont mises à la disposition des organismes de contrôle technique dès qu ' elles sont disponibles par voie électronique .\"\n", - "for item in get_offsets_and_predicted_tags(example, model, tokenizer):\n", - " print(f\"\"\"{item[\"token\"]:15} - {item[\"tags\"]}\"\"\")" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "b0164373-ee70-4d63-a0c2-ac2197b37b22", - "metadata": {}, - "source": [ - "def get_tagged_groups(example: str, model, tokenizer):\n", - " \"\"\"\n", - " Get prediction of model on example, using tokenizer\n", - " Returns:\n", - " - List of spans under offset format {\"start\": ..., \"end\": ..., \"tag\": ...}, sorted by start, end then tag.\n", - 
" \"\"\"\n", - " offsets_and_tags = get_offsets_and_predicted_tags(example, model, tokenizer)\n", - " predicted_offsets = {l: [] for l in tag2id}\n", - " last_token_tags = []\n", - " for item in offsets_and_tags:\n", - " (start, end), tags = item[\"offset\"], item[\"tags\"]\n", - " \n", - " for label_id in tags:\n", - " label = id2label[label_id]\n", - " tag = label[2:] # \"I-PER\" => \"PER\"\n", - " if label.startswith(\"B-\"):\n", - " predicted_offsets[tag].append({\"start\": start, \"end\": end})\n", - " elif label.startswith(\"I-\"):\n", - " # If \"B-\" and \"I-\" both appear in the same tag, ignore as we already processed it\n", - " if label2id[f\"B-{tag}\"] in tags:\n", - " continue\n", - " \n", - " if label_id not in last_token_tags and label2id[f\"B-{tag}\"] not in last_token_tags:\n", - " predicted_offsets[tag].append({\"start\": start, \"end\": end})\n", - " else:\n", - " predicted_offsets[tag][-1][\"end\"] = end\n", - " \n", - " last_token_tags = tags\n", - " \n", - " flatten_predicted_offsets = [{**v, \"tag\": k, \"text\": example[v[\"start\"]:v[\"end\"]]} \n", - " for k, v_list in predicted_offsets.items() for v in v_list if v[\"end\"] - v[\"start\"] >= 3]\n", - " flatten_predicted_offsets = sorted(flatten_predicted_offsets, \n", - " key = lambda row: (row[\"start\"], row[\"end\"], row[\"tag\"]))\n", - " return flatten_predicted_offsets\n", - "\n", - "print(example)\n", - "get_tagged_groups(example, model, tokenizer)" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "54213be5-8221-40a2-ac60-04670d83babc", - "metadata": {}, - "source": [], - "outputs": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - 
"version": "3.11.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/results/BERT/legal_bert.png b/results/BERT/legal_bert.png new file mode 100644 index 0000000000000000000000000000000000000000..d4b4efe0501748329313152190ee4ae114d41077 Binary files /dev/null and b/results/BERT/legal_bert.png differ diff --git a/results/Hybridation/performance.png b/results/BERT/performance.png similarity index 100% rename from results/Hybridation/performance.png rename to results/BERT/performance.png diff --git a/temp.ipynb b/temp.ipynb deleted file mode 100644 index 3556b7d5bd3c4abfcd82e261a6f1c05d624cd8e5..0000000000000000000000000000000000000000 --- a/temp.ipynb +++ /dev/null @@ -1,125 +0,0 @@ -{ - "cells": [ - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-06-11T09:21:56.804294Z", - "start_time": "2024-06-11T09:20:21.553571Z" - } - }, - "cell_type": "code", - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "\n", - "# Données d'exemple\n", - "categories = ['Cat1', 'Cat2', 'Cat3', 'Cat4', 'Cat5'] # Exemples de catégories\n", - "metrics = ['Rappel', 'Précision', 'F-mesure']\n", - "n_metrics = len(metrics)\n", - "\n", - "# Générer des performances aléatoires pour le système 1\n", - "performance_système_1 = np.random.uniform(0.6, 0.8, size=(len(categories), n_metrics))\n", - "\n", - "# Générer des améliorations ou dégradations aléatoires\n", - "amélioration_dégradation = np.random.uniform(-0.1, 0.1, size=(len(categories), n_metrics))\n", - "\n", - "# Calculer les nouvelles performances\n", - "performance_système_2 = performance_système_1 + amélioration_dégradation\n", - "\n", - "# Créer le barplot\n", - "fig, ax = plt.subplots(figsize=(15, 8))\n", - "\n", - "bar_width = 0.25\n", - "index = np.arange(len(categories))\n", - "\n", - "# Couleurs pour les barres en fonction de l'amélioration ou de la dégradation\n", - "colors = [['green' if diff > 0 else 'red' for diff in cat_diff] for cat_diff in amélioration_dégradation]\n", - 
"\n", - "# Tracer les barres pour chaque métrique\n", - "for i, metric in enumerate(metrics):\n", - " bars1 = ax.bar(index + i * bar_width, performance_système_1[:, i], bar_width, label=f'{metric}', alpha=0.7, color='lightgray')\n", - " bars2 = ax.bar(index + i * bar_width, amélioration_dégradation[:, i], bar_width, bottom=performance_système_1[:, i], color=[colors[j][i] for j in range(len(categories))], alpha=0.7)\n", - "\n", - " # Ajouter les valeurs sur les barres du système 1\n", - " for bar1, perf1 in zip(bars1, performance_système_1[:, i]):\n", - " height = bar1.get_height()\n", - " ax.annotate(f'{perf1:.2f}',\n", - " xy=(bar1.get_x() + bar1.get_width() / 2, height),\n", - " xytext=(0, 3), # 3 points de décalage vertical\n", - " textcoords=\"offset points\",\n", - " ha='center', va='bottom', color='blue')\n", - "\n", - " # Ajouter les valeurs sur les barres du système 2\n", - " for bar1, bar2, perf1, diff in zip(bars1, bars2, performance_système_1[:, i], amélioration_dégradation[:, i]):\n", - " height = bar1.get_height() + bar2.get_height()\n", - " ax.annotate(f'{perf1 + diff:.2f}',\n", - " xy=(bar2.get_x() + bar2.get_width() / 2, height),\n", - " xytext=(0, 3), # 3 points de décalage vertical\n", - " textcoords=\"offset points\",\n", - " ha='center', va='bottom', color='black')\n", - "\n", - "# Ajouter des labels et une légende\n", - "ax.set_xlabel('Catégories')\n", - "ax.set_ylabel('Performance')\n", - "ax.set_title('Performance du Système avec Améliorations/Dégradations')\n", - "ax.set_xticks(index + bar_width)\n", - "ax.set_xticklabels(categories)\n", - "ax.legend()\n", - "\n", - "# Afficher le graphique\n", - "plt.show()\n" - ], - "id": "c1616f9091e76aaf", - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", - "Cell 
\u001B[1;32mIn[12], line 16\u001B[0m\n\u001B[0;32m 13\u001B[0m amélioration_dégradation \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mrandom\u001B[38;5;241m.\u001B[39muniform(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m0.1\u001B[39m, \u001B[38;5;241m0.1\u001B[39m, size\u001B[38;5;241m=\u001B[39m(\u001B[38;5;28mlen\u001B[39m(categories), n_metrics))\n\u001B[0;32m 15\u001B[0m \u001B[38;5;66;03m# Calculer les nouvelles performances\u001B[39;00m\n\u001B[1;32m---> 16\u001B[0m performance_système_2 \u001B[38;5;241m=\u001B[39m \u001B[43mperformance_système_1\u001B[49m \u001B[38;5;241m+\u001B[39m amélioration_dégradation\n\u001B[0;32m 18\u001B[0m \u001B[38;5;66;03m# Créer le barplot\u001B[39;00m\n\u001B[0;32m 19\u001B[0m fig, ax \u001B[38;5;241m=\u001B[39m plt\u001B[38;5;241m.\u001B[39msubplots(figsize\u001B[38;5;241m=\u001B[39m(\u001B[38;5;241m15\u001B[39m, \u001B[38;5;241m8\u001B[39m))\n", - "File \u001B[1;32m_pydevd_bundle\\\\pydevd_cython_win32_311_64.pyx:1187\u001B[0m, in \u001B[0;36m_pydevd_bundle.pydevd_cython_win32_311_64.SafeCallWrapper.__call__\u001B[1;34m()\u001B[0m\n", - "File \u001B[1;32m_pydevd_bundle\\\\pydevd_cython_win32_311_64.pyx:627\u001B[0m, in \u001B[0;36m_pydevd_bundle.pydevd_cython_win32_311_64.PyDBFrame.trace_dispatch\u001B[1;34m()\u001B[0m\n", - "File \u001B[1;32m_pydevd_bundle\\\\pydevd_cython_win32_311_64.pyx:937\u001B[0m, in \u001B[0;36m_pydevd_bundle.pydevd_cython_win32_311_64.PyDBFrame.trace_dispatch\u001B[1;34m()\u001B[0m\n", - "File \u001B[1;32m_pydevd_bundle\\\\pydevd_cython_win32_311_64.pyx:928\u001B[0m, in \u001B[0;36m_pydevd_bundle.pydevd_cython_win32_311_64.PyDBFrame.trace_dispatch\u001B[1;34m()\u001B[0m\n", - "File \u001B[1;32m_pydevd_bundle\\\\pydevd_cython_win32_311_64.pyx:585\u001B[0m, in \u001B[0;36m_pydevd_bundle.pydevd_cython_win32_311_64.PyDBFrame.do_wait_suspend\u001B[1;34m()\u001B[0m\n", - "File \u001B[1;32mC:\\Program Files\\JetBrains\\DataSpell 2023.1.2\\plugins\\python-ce\\helpers\\pydev\\pydevd.py:1185\u001B[0m, 
in \u001B[0;36mPyDB.do_wait_suspend\u001B[1;34m(self, thread, frame, event, arg, send_suspend_message, is_unhandled_exception)\u001B[0m\n\u001B[0;32m 1182\u001B[0m from_this_thread\u001B[38;5;241m.\u001B[39mappend(frame_id)\n\u001B[0;32m 1184\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_threads_suspended_single_notification\u001B[38;5;241m.\u001B[39mnotify_thread_suspended(thread_id, stop_reason):\n\u001B[1;32m-> 1185\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_do_wait_suspend\u001B[49m\u001B[43m(\u001B[49m\u001B[43mthread\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mframe\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mevent\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43marg\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msuspend_type\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfrom_this_thread\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[1;32mC:\\Program Files\\JetBrains\\DataSpell 2023.1.2\\plugins\\python-ce\\helpers\\pydev\\pydevd.py:1200\u001B[0m, in \u001B[0;36mPyDB._do_wait_suspend\u001B[1;34m(self, thread, frame, event, arg, suspend_type, from_this_thread)\u001B[0m\n\u001B[0;32m 1197\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_mpl_hook()\n\u001B[0;32m 1199\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mprocess_internal_commands()\n\u001B[1;32m-> 1200\u001B[0m \u001B[43mtime\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msleep\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m0.01\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[0;32m 1202\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcancel_async_evaluation(get_current_thread_id(thread), \u001B[38;5;28mstr\u001B[39m(\u001B[38;5;28mid\u001B[39m(frame)))\n\u001B[0;32m 1204\u001B[0m \u001B[38;5;66;03m# process any stepping instructions\u001B[39;00m\n", - 
"\u001B[1;31mKeyboardInterrupt\u001B[0m: " - ] - } - ], - "execution_count": 12 - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": " ", - "id": "c7044d0f9235b5c" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}