From cbc4c653975145098028fb875673f66c2cced994 Mon Sep 17 00:00:00 2001 From: "Julien B." <xm9q8f80@jlnbrtn.me> Date: Fri, 23 Aug 2024 18:55:14 +0200 Subject: [PATCH] fix(trainer): change trainer function --- api/internal_services/background_worker.py | 15 ++------------- api/internal_services/trainer.py | 19 +++++++++++++++++++ microservices/trainer/trainer.py | 15 +++++++++------ 3 files changed, 30 insertions(+), 19 deletions(-) create mode 100644 api/internal_services/trainer.py diff --git a/api/internal_services/background_worker.py b/api/internal_services/background_worker.py index 8ed9d8a..510fa70 100644 --- a/api/internal_services/background_worker.py +++ b/api/internal_services/background_worker.py @@ -4,6 +4,7 @@ from threading import Thread import grpc from api.internal_services.logger import logger +from api.internal_services.trainer import start_training from api.models.Job import JobType from api.protos.trainer import trainer_pb2_grpc, trainer_pb2 @@ -45,19 +46,7 @@ def process_trainer(): if job is None: break - with grpc.insecure_channel(job.job_data['server_url']) as channel: - stub = trainer_pb2_grpc.TrainerStub(channel) - request = trainer_pb2.TrainingInput( - training_data=[], - fondation_model_id=job.job_data['fondation_model_id'], - finetuned_repo_name=job.job_data['finetuned_repo_name'], - huggingface_token=job.job_data['huggingface_token'], - ) - logger.debug(request) - responses = stub.StartTraining(request) - for response in responses: - logger.debug(f"gRPC message : {response.status}") - logger.debug(f"fin de la connexion gRPC") + start_training(job) logger.info(f"Ending of the job {job.job_id}") trainer_queue.task_done() diff --git a/api/internal_services/trainer.py b/api/internal_services/trainer.py new file mode 100644 index 0000000..3a65500 --- /dev/null +++ b/api/internal_services/trainer.py @@ -0,0 +1,19 @@ +import grpc +from api.internal_services.logger import logger +from api.protos.trainer import trainer_pb2_grpc, trainer_pb2 + + 
+def start_training(job):
+    """Run the job on the trainer at job.job_data['server_url'], logging each status."""
+    with grpc.insecure_channel(job.job_data['server_url']) as channel:
+        request = trainer_pb2.TrainingInput(
+            training_data=[],
+            fondation_model_id=job.job_data['fondation_model_id'],
+            finetuned_repo_name=job.job_data['finetuned_repo_name'],
+            huggingface_token=job.job_data['huggingface_token'],
+        )
+        # The request embeds the Hugging Face token: log only non-secret fields.
+        logger.debug(f"gRPC request : {request.fondation_model_id} -> {request.finetuned_repo_name}")
+        for response in trainer_pb2_grpc.TrainerStub(channel).StartTraining(request):
+            logger.debug(f"gRPC message : {response.status}")
+        logger.debug("fin de la connexion gRPC")
diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py index 0031470..5a0c31e 100644 --- a/microservices/trainer/trainer.py +++ b/microservices/trainer/trainer.py @@ -30,12 +30,15 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer): if not is_busy: is_busy = True print(f"incoming request : {request}") - training_process( - training_data=request.training_data, - fondation_model_id=request.fondation_model_id, - finetuned_repo_name=request.finetuned_repo_name, - huggingface_token=request.huggingface_token, - ) + try: + training_process( + training_data=request.training_data, + fondation_model_id=request.fondation_model_id, + finetuned_repo_name=request.finetuned_repo_name, + huggingface_token=request.huggingface_token, + ) + except Exception as e: + print(f"Error : {e}") is_busy = False else: print(f"gRPC server is already busy") -- GitLab