# Review notes (preamble text; patch tools skip lines that precede the first
# "diff --git" header).
#
# Purpose: this patch renames the module-level placeholders in
# microservices/trainer/trainer.py (tag2id, id2tag, label2id, id2label,
# tokenizer, n_labels, fondation_model_id) to global_* names and updates the
# references in training_process, model_init, tokenize_and_adjust_labels and
# compute_metrics. It also replaces one blank line in training_process with a
# print("post load tokenizer") call.
#
# NOTE(review): in the training_process hunk the assignments to
# global_fondation_model_id, global_tag2id, global_id2tag, global_label2id and
# global_id2label follow the def line directly, with no `global` statement, so
# they bind function locals; the module-level global_* names keep the None
# assigned in the first hunk, and model_init reads those module-level names.
#
# NOTE(review): training_process still assigns plain `tokenizer` and
# `n_labels`, while tokenize_and_adjust_labels calls global_tokenizer and
# compute_metrics reshapes with global_n_labels; no assignment to either of
# those two names is visible anywhere in this patch.
#
# NOTE(review): the patch text below is collapsed onto three physical lines;
# the original line breaks have to be restored before it can be applied.
# Presumably an extraction artifact -- confirm against the original commit.
diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py index 7e7919e5949c3f01b46074ed97366668a8a1914f..49d8941ddf46e89518f06a2075f61c548206b72c 100644 --- a/microservices/trainer/trainer.py +++ b/microservices/trainer/trainer.py @@ -21,7 +21,7 @@ import trainer_pb2_grpc is_busy = False MAX_LENGTH = 256 -tag2id = id2tag = label2id = id2label = tokenizer = n_labels = fondation_model_id = None +global_tag2id = global_id2tag = global_label2id = global_id2label = global_tokenizer = global_n_labels = global_fondation_model_id = None class TrainerServicer(trainer_pb2_grpc.TrainerServicer): @@ -56,27 +56,27 @@ def serve(): def training_process(training_data, fondation_model_id, finetuned_repo_name, huggingface_token): - fondation_model_id = fondation_model_id - tag2id = {'action': 1, 'actor': 2, 'artifact': 3, 'condition': 4, 'location': 5, 'modality': 6, 'reference': 7, + global_fondation_model_id = fondation_model_id + global_tag2id = {'action': 1, 'actor': 2, 'artifact': 3, 'condition': 4, 'location': 5, 'modality': 6, 'reference': 7, 'time': 8} - id2tag = {v: k for k, v in tag2id.items()} - label2id = { + global_id2tag = {v: k for k, v in global_tag2id.items()} + global_label2id = { 'O': 0, - **{f'{k}': v for k, v in tag2id.items()} + **{f'{k}': v for k, v in global_tag2id.items()} } - id2label = {v: k for k, v in label2id.items()} + global_id2label = {v: k for k, v in global_label2id.items()} train_ds = Dataset.from_list(training_data) from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(fondation_model_id) - + tokenizer = AutoTokenizer.from_pretrained(global_fondation_model_id) + print("post load tokenizer") tokenized_train_ds = train_ds.map(tokenize_and_adjust_labels, remove_columns=train_ds.column_names) from transformers import DataCollatorWithPadding data_collator = DataCollatorWithPadding(tokenizer, padding=True) - n_labels = len(id2label) + n_labels = len(global_id2label) training_args = TrainingArguments( 
output_dir="./models/fine_tune_bert_output_span_cat", @@ -113,7 +113,7 @@ def training_process(training_data, fondation_model_id, finetuned_repo_name, hug def model_init(): - return RobertaForSpanCategorization.from_pretrained(fondation_model_id, id2label=id2label, label2id=label2id) + return RobertaForSpanCategorization.from_pretrained(global_fondation_model_id, id2label=global_id2label, label2id=global_label2id) def get_token_role_in_span(token_start: int, token_end: int, span_start: int, span_end: int): @@ -126,19 +126,19 @@ def get_token_role_in_span(token_start: int, token_end: int, span_start: int, sp def tokenize_and_adjust_labels(sample): - tokenized = tokenizer(sample["text"], - return_offsets_mapping=True, - padding="max_length", - max_length=MAX_LENGTH, - truncation=True) + tokenized = global_tokenizer(sample["text"], + return_offsets_mapping=True, + padding="max_length", + max_length=MAX_LENGTH, + truncation=True) - labels = [[0 for _ in label2id.keys()] for _ in range(MAX_LENGTH)] + labels = [[0 for _ in global_label2id.keys()] for _ in range(MAX_LENGTH)] for (token_start, token_end), token_labels in zip(tokenized["offset_mapping"], labels): for span in sample["tags"]: role = get_token_role_in_span(token_start, token_end, span["start"], span["end"]) if role == "I": - token_labels[label2id[f"{span['tag']}"]] = 1 + token_labels[global_label2id[f"{span['tag']}"]] = 1 return {**tokenized, "labels": labels} @@ -153,7 +153,7 @@ def compute_metrics(p): predicted_labels = np.where(predictions > 0, np.ones(predictions.shape), np.zeros(predictions.shape)) metrics = {} - cm = multilabel_confusion_matrix(true_labels.reshape(-1, n_labels), predicted_labels.reshape(-1, n_labels)) + cm = multilabel_confusion_matrix(true_labels.reshape(-1, global_n_labels), predicted_labels.reshape(-1, global_n_labels)) for label_idx, matrix in enumerate(cm): if label_idx == 0: @@ -162,9 +162,9 @@ def compute_metrics(p): precision = divide(tp, tp + fp) recall = divide(tp, tp + fn) f1 
= divide(2 * precision * recall, precision + recall) - metrics[f"recall_{id2label[label_idx]}"] = recall - metrics[f"precision_{id2label[label_idx]}"] = precision - metrics[f"f1_{id2label[label_idx]}"] = f1 + metrics[f"recall_{global_id2label[label_idx]}"] = recall + metrics[f"precision_{global_id2label[label_idx]}"] = precision + metrics[f"f1_{global_id2label[label_idx]}"] = f1 f1_values = {k: v for k, v in metrics.items() if k.startswith('f1_')} macro_f1 = sum(f1_values.values()) / len(f1_values)