From be18ab98ec4b5707dfc083da5c384f9948be0fe5 Mon Sep 17 00:00:00 2001
From: "Julien B." <xm9q8f80@jlnbrtn.me>
Date: Sun, 25 Aug 2024 01:12:23 +0200
Subject: [PATCH] fix(trainer): change trainer function

---
 api/internal_services/trainer.py | 6 ++----
 microservices/trainer/trainer.py | 5 +++++
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/api/internal_services/trainer.py b/api/internal_services/trainer.py
index 575756a..9b4ab7e 100644
--- a/api/internal_services/trainer.py
+++ b/api/internal_services/trainer.py
@@ -1,5 +1,4 @@
 import json
-
 import grpc
 from api.internal_services import neo4j
 from api.internal_services.logger import logger
@@ -51,9 +50,8 @@ def start_training(job):
         finetuned_repo_name=job.job_data['finetuned_repo_name'],
         huggingface_token=job.job_data['huggingface_token'],
     )
-    responses = stub.StartTraining(request)
-    for response in responses:
-        logger.debug(f"Incoming gRPC message : {response.status}")
+    response = stub.StartTraining(request)
+    logger.debug(f"Incoming gRPC message : {response.status}")
     logger.debug(f"fin de la connexion gRPC")
 
     # passer toutes les origines des concepts en BERT
diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py
index c532800..1b4c051 100644
--- a/microservices/trainer/trainer.py
+++ b/microservices/trainer/trainer.py
@@ -40,6 +40,7 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
             except Exception as e:
                 print(f"Error : {e}")
                 is_busy = False
+                torch.cuda.empty_cache()
         else:
             print(f"gRPC server is already busy")
         return trainer_pb2.TrainingEvent(status="Training ended successfully !")
@@ -241,6 +242,10 @@ def training_process(training_data, fondation_model_id, finetuned_repo_name, hug
     trainer.model.push_to_hub(finetuned_repo_name, token=huggingface_token)
     tokenizer.push_to_hub(finetuned_repo_name, token=huggingface_token)
 
+    torch.cuda.empty_cache()
+
+    del trainer, tokenizer, train_ds
+
 
 if __name__ == '__main__':
     serve()
--
GitLab
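
For reference, and not part of the patch above: a minimal sketch of the CUDA cleanup
pattern the second hunk relies on, assuming only PyTorch and the standard library.
torch.cuda.empty_cache() only returns cached allocator blocks that no live tensor
still references, so dropping references (del) and running garbage collection first
is what makes the call effective. The helper name release_gpu_memory is hypothetical.

    import gc

    import torch


    def release_gpu_memory():
        """Release PyTorch's cached CUDA memory once nothing references it."""
        gc.collect()                  # collect reference cycles that may still pin GPU tensors
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # hand cached, unoccupied blocks back to the driver


    # Usage, mirroring the end of training_process():
    #     del trainer, tokenizer, train_ds
    #     release_gpu_memory()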