Commit 03959833 authored by Julien B.

fix(trainer): change var to global

parent 7e74d72e
@@ -21,7 +21,7 @@ import trainer_pb2_grpc
 is_busy = False
 MAX_LENGTH = 256
-tag2id = id2tag = label2id = id2label = tokenizer = n_labels = fondation_model_id = None
+global_tag2id = global_id2tag = global_label2id = global_id2label = global_tokenizer = global_n_labels = global_fondation_model_id = None
 
 class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
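
These module-level names are shared state: training_process writes them and the helper callbacks below (tokenize_and_adjust_labels, model_init, compute_metrics) read them back, since Dataset.map and the Trainer invoke those callbacks without extra arguments. One caveat the rename alone does not address: in Python, assigning to a name inside a function binds a local unless the function declares the name global first. A minimal illustration of that rule (not project code):

    g = None

    def set_without_global():
        g = "local"   # binds a new local; the module-level g is untouched

    def set_with_global():
        global g      # required for the assignment to reach module scope
        g = "module"

    set_without_global()
    print(g)  # -> None
    set_with_global()
    print(g)  # -> module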
@@ -56,27 +56,27 @@ def serve():
 def training_process(training_data, fondation_model_id, finetuned_repo_name, huggingface_token):
-    fondation_model_id = fondation_model_id
-    tag2id = {'action': 1, 'actor': 2, 'artifact': 3, 'condition': 4, 'location': 5, 'modality': 6, 'reference': 7,
+    global_fondation_model_id = fondation_model_id
+    global_tag2id = {'action': 1, 'actor': 2, 'artifact': 3, 'condition': 4, 'location': 5, 'modality': 6, 'reference': 7,
               'time': 8}
-    id2tag = {v: k for k, v in tag2id.items()}
-    label2id = {
+    global_id2tag = {v: k for k, v in global_tag2id.items()}
+    global_label2id = {
         'O': 0,
-        **{f'{k}': v for k, v in tag2id.items()}
+        **{f'{k}': v for k, v in global_tag2id.items()}
     }
-    id2label = {v: k for k, v in label2id.items()}
+    global_id2label = {v: k for k, v in global_label2id.items()}
     train_ds = Dataset.from_list(training_data)
     from transformers import AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained(fondation_model_id)
-    print("post load tokenizer")
+    tokenizer = AutoTokenizer.from_pretrained(global_fondation_model_id)
     tokenized_train_ds = train_ds.map(tokenize_and_adjust_labels, remove_columns=train_ds.column_names)
     from transformers import DataCollatorWithPadding
     data_collator = DataCollatorWithPadding(tokenizer, padding=True)
-    n_labels = len(id2label)
+    n_labels = len(global_id2label)
     training_args = TrainingArguments(
         output_dir="./models/fine_tune_bert_output_span_cat",
@@ -113,7 +113,7 @@ def training_process(training_data, fondation_model_id, finetuned_repo_name, hug
 def model_init():
-    return RobertaForSpanCategorization.from_pretrained(fondation_model_id, id2label=id2label, label2id=label2id)
+    return RobertaForSpanCategorization.from_pretrained(global_fondation_model_id, id2label=global_id2label, label2id=global_label2id)
 
 def get_token_role_in_span(token_start: int, token_end: int, span_start: int, span_end: int):
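
model_init is the factory the Hugging Face Trainer calls to build (and, during hyperparameter search, rebuild) the model. The wiring sits in collapsed context; a plausible sketch of it, reusing names from the hunks above (an assumption, not the committed code):

    from transformers import Trainer

    trainer = Trainer(
        model_init=model_init,             # called by the Trainer to construct the model
        args=training_args,
        train_dataset=tokenized_train_ds,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()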
@@ -126,19 +126,19 @@ def get_token_role_in_span(token_start: int, token_end: int, span_start: int, sp
 def tokenize_and_adjust_labels(sample):
-    tokenized = tokenizer(sample["text"],
+    tokenized = global_tokenizer(sample["text"],
                           return_offsets_mapping=True,
                           padding="max_length",
                           max_length=MAX_LENGTH,
                           truncation=True)
-    labels = [[0 for _ in label2id.keys()] for _ in range(MAX_LENGTH)]
+    labels = [[0 for _ in global_label2id.keys()] for _ in range(MAX_LENGTH)]
     for (token_start, token_end), token_labels in zip(tokenized["offset_mapping"], labels):
         for span in sample["tags"]:
             role = get_token_role_in_span(token_start, token_end, span["start"], span["end"])
             if role == "I":
-                token_labels[label2id[f"{span['tag']}"]] = 1
+                token_labels[global_label2id[f"{span['tag']}"]] = 1
     return {**tokenized, "labels": labels}
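
Each sample thus yields a MAX_LENGTH x n_labels multi-hot matrix: row t flags every tag whose span covers token t, which is what the span-categorization head trains against. If the goal is to stop relying on module globals here, Dataset.map can also forward the tokenizer and label map explicitly through fn_kwargs; a sketch of that alternative (not the committed approach):

    def tokenize_and_adjust_labels(sample, tokenizer, label2id):
        tokenized = tokenizer(sample["text"],
                              return_offsets_mapping=True,
                              padding="max_length",
                              max_length=MAX_LENGTH,
                              truncation=True)
        labels = [[0 for _ in label2id] for _ in range(MAX_LENGTH)]
        for (token_start, token_end), token_labels in zip(tokenized["offset_mapping"], labels):
            for span in sample["tags"]:
                if get_token_role_in_span(token_start, token_end, span["start"], span["end"]) == "I":
                    token_labels[label2id[span["tag"]]] = 1
        return {**tokenized, "labels": labels}

    tokenized_train_ds = train_ds.map(
        tokenize_and_adjust_labels,
        fn_kwargs={"tokenizer": tokenizer, "label2id": global_label2id},
        remove_columns=train_ds.column_names,
    )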
@@ -153,7 +153,7 @@ def compute_metrics(p):
     predicted_labels = np.where(predictions > 0, np.ones(predictions.shape), np.zeros(predictions.shape))
     metrics = {}
-    cm = multilabel_confusion_matrix(true_labels.reshape(-1, n_labels), predicted_labels.reshape(-1, n_labels))
+    cm = multilabel_confusion_matrix(true_labels.reshape(-1, global_n_labels), predicted_labels.reshape(-1, global_n_labels))
     for label_idx, matrix in enumerate(cm):
         if label_idx == 0:
@@ -162,9 +162,9 @@ def compute_metrics(p):
         precision = divide(tp, tp + fp)
         recall = divide(tp, tp + fn)
         f1 = divide(2 * precision * recall, precision + recall)
-        metrics[f"recall_{id2label[label_idx]}"] = recall
-        metrics[f"precision_{id2label[label_idx]}"] = precision
-        metrics[f"f1_{id2label[label_idx]}"] = f1
+        metrics[f"recall_{global_id2label[label_idx]}"] = recall
+        metrics[f"precision_{global_id2label[label_idx]}"] = precision
+        metrics[f"f1_{global_id2label[label_idx]}"] = f1
     f1_values = {k: v for k, v in metrics.items() if k.startswith('f1_')}
     macro_f1 = sum(f1_values.values()) / len(f1_values)
...
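
For context on the metrics hunk: sklearn's multilabel_confusion_matrix returns one 2x2 matrix per label, laid out as [[tn, fp], [fn, tp]], so the collapsed lines above presumably unpack each matrix before computing precision, recall and F1. A small self-contained check:

    import numpy as np
    from sklearn.metrics import multilabel_confusion_matrix

    y_true = np.array([[1, 0, 1], [0, 1, 0]])
    y_pred = np.array([[1, 0, 0], [0, 1, 1]])
    for label_idx, matrix in enumerate(multilabel_confusion_matrix(y_true, y_pred)):
        tn, fp, fn, tp = matrix.ravel()  # row-major order of [[tn, fp], [fn, tp]]
        print(label_idx, tn, fp, fn, tp)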