Skip to content
Snippets Groups Projects
Commit be18ab98 authored by Julien B.'s avatar Julien B.
Browse files

fix(trainer): change trainer function

parent 250acb70
No related branches found
No related tags found
No related merge requests found
import json
import grpc
from api.internal_services import neo4j
from api.internal_services.logger import logger
......@@ -51,9 +50,8 @@ def start_training(job):
finetuned_repo_name=job.job_data['finetuned_repo_name'],
huggingface_token=job.job_data['huggingface_token'],
)
responses = stub.StartTraining(request)
for response in responses:
logger.debug(f"Incoming gRPC message : {response.status}")
response = stub.StartTraining(request)
logger.debug(f"Incoming gRPC message : {response.status}")
logger.debug(f"fin de la connexion gRPC")
# passer toutes les origines des concepts en BERT
......@@ -40,6 +40,7 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
except Exception as e:
print(f"Error : {e}")
is_busy = False
torch.cuda.empty_cache()
else:
print(f"gRPC server is already busy")
return trainer_pb2.TrainingEvent(status="Training ended successfully !")
......@@ -241,6 +242,10 @@ def training_process(training_data, fondation_model_id, finetuned_repo_name, hug
trainer.model.push_to_hub(finetuned_repo_name, token=huggingface_token)
tokenizer.push_to_hub(finetuned_repo_name, token=huggingface_token)
torch.cuda.empty_cache()
del trainer, tokenizer, train_ds
if __name__ == '__main__':
serve()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment