diff --git a/api/internal_services/trainer.py b/api/internal_services/trainer.py index 575756a02192d08f96a6c879b9adc17da224de6e..9b4ab7ec4310e27313b260ea0f3984ef7843c5e2 100644 --- a/api/internal_services/trainer.py +++ b/api/internal_services/trainer.py @@ -1,5 +1,4 @@ import json - import grpc from api.internal_services import neo4j from api.internal_services.logger import logger @@ -51,9 +50,8 @@ def start_training(job): finetuned_repo_name=job.job_data['finetuned_repo_name'], huggingface_token=job.job_data['huggingface_token'], ) - responses = stub.StartTraining(request) - for response in responses: - logger.debug(f"Incoming gRPC message : {response.status}") + response = stub.StartTraining(request) + logger.debug(f"Incoming gRPC message : {response.status}") logger.debug(f"fin de la connexion gRPC") # passer toutes les origines des concepts en BERT diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py index c532800c204e83e366e77c928cf279db590539dd..1b4c0510f450c403fdcadc309513fde0d34f72ba 100644 --- a/microservices/trainer/trainer.py +++ b/microservices/trainer/trainer.py @@ -40,6 +40,7 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer): except Exception as e: print(f"Error : {e}") is_busy = False + torch.cuda.empty_cache() else: print(f"gRPC server is already busy") return trainer_pb2.TrainingEvent(status="Training ended successfully !") @@ -241,6 +242,10 @@ def training_process(training_data, fondation_model_id, finetuned_repo_name, hug trainer.model.push_to_hub(finetuned_repo_name, token=huggingface_token) tokenizer.push_to_hub(finetuned_repo_name, token=huggingface_token) + torch.cuda.empty_cache() + + del trainer, tokenizer, train_ds + if __name__ == '__main__': serve()