Skip to content
Snippets Groups Projects
Commit 813bffc7 authored by Julien B.'s avatar Julien B.
Browse files

fix(trainer): change trainer function

parent 380f4f47
Branches
No related tags found
No related merge requests found
syntax = "proto3"; syntax = "proto3";
service Trainer { service Trainer {
rpc StartTraining(TrainingInput) returns (stream TrainingEvent){} rpc StartTraining(TrainingInput) returns (TrainingEvent){}
} }
message TrainingInput { message TrainingInput {
......
...@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() ...@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
...@@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False: ...@@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_TRAININGEVENT._serialized_start=160 _TRAININGEVENT._serialized_start=160
_TRAININGEVENT._serialized_end=191 _TRAININGEVENT._serialized_end=191
_TRAINER._serialized_start=193 _TRAINER._serialized_start=193
_TRAINER._serialized_end=255 _TRAINER._serialized_end=253
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)
...@@ -14,7 +14,7 @@ class TrainerStub(object): ...@@ -14,7 +14,7 @@ class TrainerStub(object):
Args: Args:
channel: A grpc.Channel. channel: A grpc.Channel.
""" """
self.StartTraining = channel.unary_stream( self.StartTraining = channel.unary_unary(
'/Trainer/StartTraining', '/Trainer/StartTraining',
request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString, request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString, response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
...@@ -33,7 +33,7 @@ class TrainerServicer(object): ...@@ -33,7 +33,7 @@ class TrainerServicer(object):
def add_TrainerServicer_to_server(servicer, server): def add_TrainerServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
'StartTraining': grpc.unary_stream_rpc_method_handler( 'StartTraining': grpc.unary_unary_rpc_method_handler(
servicer.StartTraining, servicer.StartTraining,
request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString, request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString,
response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString, response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString,
...@@ -59,7 +59,7 @@ class Trainer(object): ...@@ -59,7 +59,7 @@ class Trainer(object):
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None):
return grpc.experimental.unary_stream(request, target, '/Trainer/StartTraining', return grpc.experimental.unary_unary(request, target, '/Trainer/StartTraining',
api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString, api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString, api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
options, channel_credentials, options, channel_credentials,
......
syntax = "proto3"; syntax = "proto3";
service Trainer { service Trainer {
rpc StartTraining(TrainingInput) returns (stream TrainingEvent){} rpc StartTraining(TrainingInput) returns (TrainingEvent){}
} }
message TrainingInput { message TrainingInput {
......
...@@ -24,29 +24,25 @@ is_busy = False ...@@ -24,29 +24,25 @@ is_busy = False
class TrainerServicer(trainer_pb2_grpc.TrainerServicer): class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
def StartTraining(self, request, context): def StartTraining(self, request, context):
print("event received")
def response_messages(): global is_busy
print("event received")
global is_busy if not is_busy:
is_busy = True
if not is_busy: print(f"incoming request : {request}")
is_busy = True try:
print(f"incoming request : {request}") training_process(
try: training_data=request.training_data,
training_process( fondation_model_id=request.fondation_model_id,
training_data=request.training_data, finetuned_repo_name=request.finetuned_repo_name,
fondation_model_id=request.fondation_model_id, huggingface_token=request.huggingface_token,
finetuned_repo_name=request.finetuned_repo_name, )
huggingface_token=request.huggingface_token, except Exception as e:
) print(f"Error : {e}")
except Exception as e: is_busy = False
print(f"Error : {e}") else:
is_busy = False print(f"gRPC server is already busy")
else: return trainer_pb2.TrainingEvent(status="Training ended successfully !")
print(f"gRPC server is already busy")
return trainer_pb2.TrainingEvent(status="Training ended successfully !")
return response_messages()
def serve(): def serve():
......
...@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() ...@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
...@@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False: ...@@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_TRAININGEVENT._serialized_start=160 _TRAININGEVENT._serialized_start=160
_TRAININGEVENT._serialized_end=191 _TRAININGEVENT._serialized_end=191
_TRAINER._serialized_start=193 _TRAINER._serialized_start=193
_TRAINER._serialized_end=255 _TRAINER._serialized_end=253
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services.""" """Client and server classes corresponding to protobuf-defined services."""
import grpc import grpc
import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2
from api.protos.trainer import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2
class TrainerStub(object): class TrainerStub(object):
...@@ -13,7 +14,7 @@ class TrainerStub(object): ...@@ -13,7 +14,7 @@ class TrainerStub(object):
Args: Args:
channel: A grpc.Channel. channel: A grpc.Channel.
""" """
self.StartTraining = channel.unary_stream( self.StartTraining = channel.unary_unary(
'/Trainer/StartTraining', '/Trainer/StartTraining',
request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString, request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString, response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
...@@ -32,7 +33,7 @@ class TrainerServicer(object): ...@@ -32,7 +33,7 @@ class TrainerServicer(object):
def add_TrainerServicer_to_server(servicer, server): def add_TrainerServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
'StartTraining': grpc.unary_stream_rpc_method_handler( 'StartTraining': grpc.unary_unary_rpc_method_handler(
servicer.StartTraining, servicer.StartTraining,
request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString, request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString,
response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString, response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString,
...@@ -58,7 +59,7 @@ class Trainer(object): ...@@ -58,7 +59,7 @@ class Trainer(object):
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None):
return grpc.experimental.unary_stream(request, target, '/Trainer/StartTraining', return grpc.experimental.unary_unary(request, target, '/Trainer/StartTraining',
api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString, api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString, api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
options, channel_credentials, options, channel_credentials,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment