Skip to content
Snippets Groups Projects
Commit 813bffc7 authored by Julien B.'s avatar Julien B.
Browse files

fix(trainer): change trainer function

parent 380f4f47
Branches
No related tags found
No related merge requests found
syntax = "proto3";
service Trainer {
rpc StartTraining(TrainingInput) returns (stream TrainingEvent){}
rpc StartTraining(TrainingInput) returns (TrainingEvent){}
}
message TrainingInput {
......
......@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
......@@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_TRAININGEVENT._serialized_start=160
_TRAININGEVENT._serialized_end=191
_TRAINER._serialized_start=193
_TRAINER._serialized_end=255
_TRAINER._serialized_end=253
# @@protoc_insertion_point(module_scope)
......@@ -14,7 +14,7 @@ class TrainerStub(object):
Args:
channel: A grpc.Channel.
"""
self.StartTraining = channel.unary_stream(
self.StartTraining = channel.unary_unary(
'/Trainer/StartTraining',
request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
......@@ -33,7 +33,7 @@ class TrainerServicer(object):
def add_TrainerServicer_to_server(servicer, server):
rpc_method_handlers = {
'StartTraining': grpc.unary_stream_rpc_method_handler(
'StartTraining': grpc.unary_unary_rpc_method_handler(
servicer.StartTraining,
request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString,
response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString,
......@@ -59,7 +59,7 @@ class Trainer(object):
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(request, target, '/Trainer/StartTraining',
return grpc.experimental.unary_unary(request, target, '/Trainer/StartTraining',
api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
options, channel_credentials,
......
syntax = "proto3";
service Trainer {
rpc StartTraining(TrainingInput) returns (stream TrainingEvent){}
rpc StartTraining(TrainingInput) returns (TrainingEvent){}
}
message TrainingInput {
......
......@@ -24,29 +24,25 @@ is_busy = False
class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
def StartTraining(self, request, context):
def response_messages():
print("event received")
global is_busy
if not is_busy:
is_busy = True
print(f"incoming request : {request}")
try:
training_process(
training_data=request.training_data,
fondation_model_id=request.fondation_model_id,
finetuned_repo_name=request.finetuned_repo_name,
huggingface_token=request.huggingface_token,
)
except Exception as e:
print(f"Error : {e}")
is_busy = False
else:
print(f"gRPC server is already busy")
return trainer_pb2.TrainingEvent(status="Training ended successfully !")
return response_messages()
print("event received")
global is_busy
if not is_busy:
is_busy = True
print(f"incoming request : {request}")
try:
training_process(
training_data=request.training_data,
fondation_model_id=request.fondation_model_id,
finetuned_repo_name=request.finetuned_repo_name,
huggingface_token=request.huggingface_token,
)
except Exception as e:
print(f"Error : {e}")
is_busy = False
else:
print(f"gRPC server is already busy")
return trainer_pb2.TrainingEvent(status="Training ended successfully !")
def serve():
......
......@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
......@@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_TRAININGEVENT._serialized_start=160
_TRAININGEVENT._serialized_end=191
_TRAINER._serialized_start=193
_TRAINER._serialized_end=255
_TRAINER._serialized_end=253
# @@protoc_insertion_point(module_scope)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2
from api.protos.trainer import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2
class TrainerStub(object):
......@@ -13,7 +14,7 @@ class TrainerStub(object):
Args:
channel: A grpc.Channel.
"""
self.StartTraining = channel.unary_stream(
self.StartTraining = channel.unary_unary(
'/Trainer/StartTraining',
request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
......@@ -32,7 +33,7 @@ class TrainerServicer(object):
def add_TrainerServicer_to_server(servicer, server):
rpc_method_handlers = {
'StartTraining': grpc.unary_stream_rpc_method_handler(
'StartTraining': grpc.unary_unary_rpc_method_handler(
servicer.StartTraining,
request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString,
response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString,
......@@ -58,7 +59,7 @@ class Trainer(object):
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(request, target, '/Trainer/StartTraining',
return grpc.experimental.unary_unary(request, target, '/Trainer/StartTraining',
api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
options, channel_credentials,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment