From cbc4c653975145098028fb875673f66c2cced994 Mon Sep 17 00:00:00 2001
From: "Julien B." <xm9q8f80@jlnbrtn.me>
Date: Fri, 23 Aug 2024 18:55:14 +0200
Subject: [PATCH] fix(trainer): change trainer function

---
 api/internal_services/background_worker.py | 15 ++-------------
 api/internal_services/trainer.py           | 19 +++++++++++++++++++
 microservices/trainer/trainer.py           | 15 +++++++++------
 3 files changed, 30 insertions(+), 19 deletions(-)
 create mode 100644 api/internal_services/trainer.py

diff --git a/api/internal_services/background_worker.py b/api/internal_services/background_worker.py
index 8ed9d8a..510fa70 100644
--- a/api/internal_services/background_worker.py
+++ b/api/internal_services/background_worker.py
@@ -4,6 +4,7 @@ from threading import Thread
 import grpc
 
 from api.internal_services.logger import logger
+from api.internal_services.trainer import start_training
 from api.models.Job import JobType
 from api.protos.trainer import trainer_pb2_grpc, trainer_pb2
 
@@ -45,19 +46,7 @@ def process_trainer():
         if job is None:
             break
 
-        with grpc.insecure_channel(job.job_data['server_url']) as channel:
-            stub = trainer_pb2_grpc.TrainerStub(channel)
-            request = trainer_pb2.TrainingInput(
-                training_data=[],
-                fondation_model_id=job.job_data['fondation_model_id'],
-                finetuned_repo_name=job.job_data['finetuned_repo_name'],
-                huggingface_token=job.job_data['huggingface_token'],
-            )
-            logger.debug(request)
-            responses = stub.StartTraining(request)
-            for response in responses:
-                logger.debug(f"gRPC message : {response.status}")
-        logger.debug(f"fin de la connexion gRPC")
+        start_training(job)
 
         logger.info(f"Ending of the job {job.job_id}")
         trainer_queue.task_done()
diff --git a/api/internal_services/trainer.py b/api/internal_services/trainer.py
new file mode 100644
index 0000000..3a65500
--- /dev/null
+++ b/api/internal_services/trainer.py
@@ -0,0 +1,19 @@
+import grpc
+from api.internal_services.logger import logger
+from api.protos.trainer import trainer_pb2_grpc, trainer_pb2
+
+
+def start_training(job):
+    with grpc.insecure_channel(job.job_data['server_url']) as channel:
+        stub = trainer_pb2_grpc.TrainerStub(channel)
+        request = trainer_pb2.TrainingInput(
+            training_data=[],
+            fondation_model_id=job.job_data['fondation_model_id'],
+            finetuned_repo_name=job.job_data['finetuned_repo_name'],
+            huggingface_token=job.job_data['huggingface_token'],
+        )
+        logger.debug(request)
+        responses = stub.StartTraining(request)
+        for response in responses:
+            logger.debug(f"gRPC message : {response.status}")
+    logger.debug(f"fin de la connexion gRPC")
\ No newline at end of file
diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py
index 0031470..5a0c31e 100644
--- a/microservices/trainer/trainer.py
+++ b/microservices/trainer/trainer.py
@@ -30,12 +30,15 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
             if not is_busy:
                 is_busy = True
                 print(f"incoming request : {request}")
-                training_process(
-                    training_data=request.training_data,
-                    fondation_model_id=request.fondation_model_id,
-                    finetuned_repo_name=request.finetuned_repo_name,
-                    huggingface_token=request.huggingface_token,
-                )
+                try:
+                    training_process(
+                        training_data=request.training_data,
+                        fondation_model_id=request.fondation_model_id,
+                        finetuned_repo_name=request.finetuned_repo_name,
+                        huggingface_token=request.huggingface_token,
+                    )
+                except Exception as e:
+                    print(f"Error : {e}")
                 is_busy = False
             else:
                 print(f"gRPC server is already busy")
-- 
GitLab