Skip to content
Snippets Groups Projects
Commit cbc4c653 authored by Julien B.'s avatar Julien B.
Browse files

fix(trainer): change trainer function

parent 40579d40
Branches
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ from threading import Thread ...@@ -4,6 +4,7 @@ from threading import Thread
import grpc import grpc
from api.internal_services.logger import logger from api.internal_services.logger import logger
from api.internal_services.trainer import start_training
from api.models.Job import JobType from api.models.Job import JobType
from api.protos.trainer import trainer_pb2_grpc, trainer_pb2 from api.protos.trainer import trainer_pb2_grpc, trainer_pb2
...@@ -45,19 +46,7 @@ def process_trainer(): ...@@ -45,19 +46,7 @@ def process_trainer():
if job is None: if job is None:
break break
with grpc.insecure_channel(job.job_data['server_url']) as channel: start_training(job)
stub = trainer_pb2_grpc.TrainerStub(channel)
request = trainer_pb2.TrainingInput(
training_data=[],
fondation_model_id=job.job_data['fondation_model_id'],
finetuned_repo_name=job.job_data['finetuned_repo_name'],
huggingface_token=job.job_data['huggingface_token'],
)
logger.debug(request)
responses = stub.StartTraining(request)
for response in responses:
logger.debug(f"gRPC message : {response.status}")
logger.debug(f"fin de la connexion gRPC")
logger.info(f"Ending of the job {job.job_id}") logger.info(f"Ending of the job {job.job_id}")
trainer_queue.task_done() trainer_queue.task_done()
......
import grpc
from api.internal_services.logger import logger
from api.protos.trainer import trainer_pb2_grpc, trainer_pb2
def start_training(job):
with grpc.insecure_channel(job.job_data['server_url']) as channel:
stub = trainer_pb2_grpc.TrainerStub(channel)
request = trainer_pb2.TrainingInput(
training_data=[],
fondation_model_id=job.job_data['fondation_model_id'],
finetuned_repo_name=job.job_data['finetuned_repo_name'],
huggingface_token=job.job_data['huggingface_token'],
)
logger.debug(request)
responses = stub.StartTraining(request)
for response in responses:
logger.debug(f"gRPC message : {response.status}")
logger.debug(f"fin de la connexion gRPC")
\ No newline at end of file
...@@ -30,12 +30,15 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer): ...@@ -30,12 +30,15 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
if not is_busy: if not is_busy:
is_busy = True is_busy = True
print(f"incoming request : {request}") print(f"incoming request : {request}")
training_process( try:
training_data=request.training_data, training_process(
fondation_model_id=request.fondation_model_id, training_data=request.training_data,
finetuned_repo_name=request.finetuned_repo_name, fondation_model_id=request.fondation_model_id,
huggingface_token=request.huggingface_token, finetuned_repo_name=request.finetuned_repo_name,
) huggingface_token=request.huggingface_token,
)
except Exception as e:
print(f"Error : {e}")
is_busy = False is_busy = False
else: else:
print(f"gRPC server is already busy") print(f"gRPC server is already busy")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment