From be18ab98ec4b5707dfc083da5c384f9948be0fe5 Mon Sep 17 00:00:00 2001
From: "Julien B." <xm9q8f80@jlnbrtn.me>
Date: Sun, 25 Aug 2024 01:12:23 +0200
Subject: [PATCH] fix(trainer): treat StartTraining as a unary RPC and free CUDA memory after training

---
 api/internal_services/trainer.py | 6 ++----
 microservices/trainer/trainer.py | 5 +++++
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/api/internal_services/trainer.py b/api/internal_services/trainer.py
index 575756a..9b4ab7e 100644
--- a/api/internal_services/trainer.py
+++ b/api/internal_services/trainer.py
@@ -1,5 +1,4 @@
 import json
-
 import grpc
 from api.internal_services import neo4j
 from api.internal_services.logger import logger
@@ -51,9 +50,8 @@ def start_training(job):
             finetuned_repo_name=job.job_data['finetuned_repo_name'],
             huggingface_token=job.job_data['huggingface_token'],
         )
-        responses = stub.StartTraining(request)
-        for response in responses:
-            logger.debug(f"Incoming gRPC message : {response.status}")
+        response = stub.StartTraining(request)
+        logger.debug(f"Incoming gRPC message : {response.status}")
     logger.debug(f"fin de la connexion gRPC")
 
     # passer toutes les origines des concepts en BERT
diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py
index c532800..1b4c051 100644
--- a/microservices/trainer/trainer.py
+++ b/microservices/trainer/trainer.py
@@ -40,6 +40,7 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
             except Exception as e:
                 print(f"Error : {e}")
             is_busy = False
+            torch.cuda.empty_cache()
         else:
             print(f"gRPC server is already busy")
         return trainer_pb2.TrainingEvent(status="Training ended successfully !")
@@ -241,6 +242,10 @@ def training_process(training_data, fondation_model_id, finetuned_repo_name, hug
     trainer.model.push_to_hub(finetuned_repo_name, token=huggingface_token)
     tokenizer.push_to_hub(finetuned_repo_name, token=huggingface_token)
 
+    torch.cuda.empty_cache()
+
+    del trainer, tokenizer, train_ds
+
 
 if __name__ == '__main__':
     serve()
-- 
GitLab