From 813bffc7c95b57526c8ef7d186838ae6ced7763b Mon Sep 17 00:00:00 2001 From: "Julien B." <xm9q8f80@jlnbrtn.me> Date: Sun, 25 Aug 2024 00:10:09 +0200 Subject: [PATCH] fix(trainer): change trainer function --- api/protos/trainer/trainer.proto | 2 +- api/protos/trainer/trainer_pb2.py | 4 +-- api/protos/trainer/trainer_pb2_grpc.py | 6 ++-- microservices/trainer/trainer.proto | 2 +- microservices/trainer/trainer.py | 42 ++++++++++------------- microservices/trainer/trainer_pb2.py | 4 +-- microservices/trainer/trainer_pb2_grpc.py | 9 ++--- 7 files changed, 33 insertions(+), 36 deletions(-) diff --git a/api/protos/trainer/trainer.proto b/api/protos/trainer/trainer.proto index c36c1f0..46664b1 100644 --- a/api/protos/trainer/trainer.proto +++ b/api/protos/trainer/trainer.proto @@ -1,7 +1,7 @@ syntax = "proto3"; service Trainer { - rpc StartTraining(TrainingInput) returns (stream TrainingEvent){} + rpc StartTraining(TrainingInput) returns (TrainingEvent){} } message TrainingInput { diff --git a/api/protos/trainer/trainer_pb2.py b/api/protos/trainer/trainer_pb2.py index 5a088b5..fcad9f1 100644 --- a/api/protos/trainer/trainer_pb2.py +++ b/api/protos/trainer/trainer_pb2.py @@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3') @@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False: _TRAININGEVENT._serialized_start=160 _TRAININGEVENT._serialized_end=191 _TRAINER._serialized_start=193 - _TRAINER._serialized_end=255 + _TRAINER._serialized_end=253 # @@protoc_insertion_point(module_scope) diff --git a/api/protos/trainer/trainer_pb2_grpc.py b/api/protos/trainer/trainer_pb2_grpc.py index 489f26d..5277516 100644 --- a/api/protos/trainer/trainer_pb2_grpc.py +++ b/api/protos/trainer/trainer_pb2_grpc.py @@ -14,7 +14,7 @@ class TrainerStub(object): Args: channel: A grpc.Channel. """ - self.StartTraining = channel.unary_stream( + self.StartTraining = channel.unary_unary( '/Trainer/StartTraining', request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString, response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString, @@ -33,7 +33,7 @@ class TrainerServicer(object): def add_TrainerServicer_to_server(servicer, server): rpc_method_handlers = { - 'StartTraining': grpc.unary_stream_rpc_method_handler( + 'StartTraining': grpc.unary_unary_rpc_method_handler( servicer.StartTraining, request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString, response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString, @@ -59,7 +59,7 @@ class Trainer(object): wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_stream(request, target, '/Trainer/StartTraining', + return grpc.experimental.unary_unary(request, target, '/Trainer/StartTraining', api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString, api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString, options, channel_credentials, diff --git a/microservices/trainer/trainer.proto b/microservices/trainer/trainer.proto index c36c1f0..46664b1 100644 --- a/microservices/trainer/trainer.proto +++ b/microservices/trainer/trainer.proto @@ -1,7 +1,7 @@ syntax = "proto3"; service Trainer { - rpc StartTraining(TrainingInput) returns (stream TrainingEvent){} + rpc StartTraining(TrainingInput) returns (TrainingEvent){} } message TrainingInput { diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py index 03f5725..c532800 100644 --- a/microservices/trainer/trainer.py +++ b/microservices/trainer/trainer.py @@ -24,29 +24,25 @@ is_busy = False class TrainerServicer(trainer_pb2_grpc.TrainerServicer): def StartTraining(self, request, context): - - def response_messages(): - print("event received") - global is_busy - - if not is_busy: - is_busy = True - print(f"incoming request : {request}") - try: - training_process( - training_data=request.training_data, - fondation_model_id=request.fondation_model_id, - finetuned_repo_name=request.finetuned_repo_name, - huggingface_token=request.huggingface_token, - ) - except Exception as e: - print(f"Error : {e}") - is_busy = False - else: - print(f"gRPC server is already busy") - return trainer_pb2.TrainingEvent(status="Training ended successfully !") - - return response_messages() + print("event received") + global is_busy + + if not is_busy: + is_busy = True + print(f"incoming request : {request}") + try: + training_process( + training_data=request.training_data, + fondation_model_id=request.fondation_model_id, + finetuned_repo_name=request.finetuned_repo_name, + huggingface_token=request.huggingface_token, + ) + except Exception as e: + print(f"Error : {e}") + is_busy = False + else: + print(f"gRPC server is already busy") + return trainer_pb2.TrainingEvent(status="Training ended successfully !") def serve(): diff --git a/microservices/trainer/trainer_pb2.py b/microservices/trainer/trainer_pb2.py index 5a088b5..fcad9f1 100644 --- a/microservices/trainer/trainer_pb2.py +++ b/microservices/trainer/trainer_pb2.py @@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3') @@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False: _TRAININGEVENT._serialized_start=160 _TRAININGEVENT._serialized_end=191 _TRAINER._serialized_start=193 - _TRAINER._serialized_end=255 + _TRAINER._serialized_end=253 # @@protoc_insertion_point(module_scope) diff --git a/microservices/trainer/trainer_pb2_grpc.py b/microservices/trainer/trainer_pb2_grpc.py index ca55b8c..5277516 100644 --- a/microservices/trainer/trainer_pb2_grpc.py +++ b/microservices/trainer/trainer_pb2_grpc.py @@ -1,7 +1,8 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc -import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2 + +from api.protos.trainer import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2 class TrainerStub(object): @@ -13,7 +14,7 @@ class TrainerStub(object): Args: channel: A grpc.Channel. """ - self.StartTraining = channel.unary_stream( + self.StartTraining = channel.unary_unary( '/Trainer/StartTraining', request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString, response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString, @@ -32,7 +33,7 @@ class TrainerServicer(object): def add_TrainerServicer_to_server(servicer, server): rpc_method_handlers = { - 'StartTraining': grpc.unary_stream_rpc_method_handler( + 'StartTraining': grpc.unary_unary_rpc_method_handler( servicer.StartTraining, request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString, response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString, @@ -58,7 +59,7 @@ class Trainer(object): wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_stream(request, target, '/Trainer/StartTraining', + return grpc.experimental.unary_unary(request, target, '/Trainer/StartTraining', api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString, api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString, options, channel_credentials, -- GitLab