Skip to content
Snippets Groups Projects
Commit be18ab98 authored by Julien B.'s avatar Julien B.
Browse files

fix(trainer): change trainer function

parent 250acb70
Branches
No related tags found
No related merge requests found
import json import json
import grpc import grpc
from api.internal_services import neo4j from api.internal_services import neo4j
from api.internal_services.logger import logger from api.internal_services.logger import logger
...@@ -51,9 +50,8 @@ def start_training(job): ...@@ -51,9 +50,8 @@ def start_training(job):
finetuned_repo_name=job.job_data['finetuned_repo_name'], finetuned_repo_name=job.job_data['finetuned_repo_name'],
huggingface_token=job.job_data['huggingface_token'], huggingface_token=job.job_data['huggingface_token'],
) )
responses = stub.StartTraining(request) response = stub.StartTraining(request)
for response in responses: logger.debug(f"Incoming gRPC message : {response.status}")
logger.debug(f"Incoming gRPC message : {response.status}")
logger.debug(f"fin de la connexion gRPC") logger.debug(f"fin de la connexion gRPC")
# passer toutes les origines des concepts en BERT # passer toutes les origines des concepts en BERT
...@@ -40,6 +40,7 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer): ...@@ -40,6 +40,7 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
except Exception as e: except Exception as e:
print(f"Error : {e}") print(f"Error : {e}")
is_busy = False is_busy = False
torch.cuda.empty_cache()
else: else:
print(f"gRPC server is already busy") print(f"gRPC server is already busy")
return trainer_pb2.TrainingEvent(status="Training ended successfully !") return trainer_pb2.TrainingEvent(status="Training ended successfully !")
...@@ -241,6 +242,10 @@ def training_process(training_data, fondation_model_id, finetuned_repo_name, hug ...@@ -241,6 +242,10 @@ def training_process(training_data, fondation_model_id, finetuned_repo_name, hug
trainer.model.push_to_hub(finetuned_repo_name, token=huggingface_token) trainer.model.push_to_hub(finetuned_repo_name, token=huggingface_token)
tokenizer.push_to_hub(finetuned_repo_name, token=huggingface_token) tokenizer.push_to_hub(finetuned_repo_name, token=huggingface_token)
torch.cuda.empty_cache()
del trainer, tokenizer, train_ds
if __name__ == '__main__': if __name__ == '__main__':
serve() serve()
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment