From 813bffc7c95b57526c8ef7d186838ae6ced7763b Mon Sep 17 00:00:00 2001
From: "Julien B." <xm9q8f80@jlnbrtn.me>
Date: Sun, 25 Aug 2024 00:10:09 +0200
Subject: [PATCH] fix(trainer): change trainer function

---
 api/protos/trainer/trainer.proto          |  2 +-
 api/protos/trainer/trainer_pb2.py         |  4 +--
 api/protos/trainer/trainer_pb2_grpc.py    |  6 ++--
 microservices/trainer/trainer.proto       |  2 +-
 microservices/trainer/trainer.py          | 42 ++++++++++-------------
 microservices/trainer/trainer_pb2.py      |  4 +--
 microservices/trainer/trainer_pb2_grpc.py |  9 ++---
 7 files changed, 33 insertions(+), 36 deletions(-)

diff --git a/api/protos/trainer/trainer.proto b/api/protos/trainer/trainer.proto
index c36c1f0..46664b1 100644
--- a/api/protos/trainer/trainer.proto
+++ b/api/protos/trainer/trainer.proto
@@ -1,7 +1,7 @@
 syntax = "proto3";
 
 service Trainer {
-  rpc StartTraining(TrainingInput) returns (stream TrainingEvent){}
+  rpc StartTraining(TrainingInput) returns (TrainingEvent){}
 }
 
 message TrainingInput {
diff --git a/api/protos/trainer/trainer_pb2.py b/api/protos/trainer/trainer_pb2.py
index 5a088b5..fcad9f1 100644
--- a/api/protos/trainer/trainer_pb2.py
+++ b/api/protos/trainer/trainer_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
 
 
 
@@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _TRAININGEVENT._serialized_start=160
   _TRAININGEVENT._serialized_end=191
   _TRAINER._serialized_start=193
-  _TRAINER._serialized_end=255
+  _TRAINER._serialized_end=253
 # @@protoc_insertion_point(module_scope)
diff --git a/api/protos/trainer/trainer_pb2_grpc.py b/api/protos/trainer/trainer_pb2_grpc.py
index 489f26d..5277516 100644
--- a/api/protos/trainer/trainer_pb2_grpc.py
+++ b/api/protos/trainer/trainer_pb2_grpc.py
@@ -14,7 +14,7 @@ class TrainerStub(object):
         Args:
             channel: A grpc.Channel.
         """
-        self.StartTraining = channel.unary_stream(
+        self.StartTraining = channel.unary_unary(
                 '/Trainer/StartTraining',
                 request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
                 response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
@@ -33,7 +33,7 @@ class TrainerServicer(object):
 
 def add_TrainerServicer_to_server(servicer, server):
     rpc_method_handlers = {
-            'StartTraining': grpc.unary_stream_rpc_method_handler(
+            'StartTraining': grpc.unary_unary_rpc_method_handler(
                     servicer.StartTraining,
                     request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString,
                     response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString,
@@ -59,7 +59,7 @@ class Trainer(object):
             wait_for_ready=None,
             timeout=None,
             metadata=None):
-        return grpc.experimental.unary_stream(request, target, '/Trainer/StartTraining',
+        return grpc.experimental.unary_unary(request, target, '/Trainer/StartTraining',
             api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
             api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
             options, channel_credentials,
diff --git a/microservices/trainer/trainer.proto b/microservices/trainer/trainer.proto
index c36c1f0..46664b1 100644
--- a/microservices/trainer/trainer.proto
+++ b/microservices/trainer/trainer.proto
@@ -1,7 +1,7 @@
 syntax = "proto3";
 
 service Trainer {
-  rpc StartTraining(TrainingInput) returns (stream TrainingEvent){}
+  rpc StartTraining(TrainingInput) returns (TrainingEvent){}
 }
 
 message TrainingInput {
diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py
index 03f5725..c532800 100644
--- a/microservices/trainer/trainer.py
+++ b/microservices/trainer/trainer.py
@@ -24,29 +24,25 @@ is_busy = False
 
 class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
     def StartTraining(self, request, context):
-
-        def response_messages():
-            print("event received")
-            global is_busy
-
-            if not is_busy:
-                is_busy = True
-                print(f"incoming request : {request}")
-                try:
-                    training_process(
-                        training_data=request.training_data,
-                        fondation_model_id=request.fondation_model_id,
-                        finetuned_repo_name=request.finetuned_repo_name,
-                        huggingface_token=request.huggingface_token,
-                    )
-                except Exception as e:
-                    print(f"Error : {e}")
-                is_busy = False
-            else:
-                print(f"gRPC server is already busy")
-            return trainer_pb2.TrainingEvent(status="Training ended successfully !")
-
-        return response_messages()
+        print("event received")
+        global is_busy
+
+        if not is_busy:
+            is_busy = True
+            print(f"incoming request : {request}")
+            try:
+                training_process(
+                    training_data=request.training_data,
+                    fondation_model_id=request.fondation_model_id,
+                    finetuned_repo_name=request.finetuned_repo_name,
+                    huggingface_token=request.huggingface_token,
+                )
+            except Exception as e:
+                print(f"Error : {e}")
+            is_busy = False
+        else:
+            print(f"gRPC server is already busy")
+        return trainer_pb2.TrainingEvent(status="Training ended successfully !")
 
 
 def serve():
diff --git a/microservices/trainer/trainer_pb2.py b/microservices/trainer/trainer_pb2.py
index 5a088b5..fcad9f1 100644
--- a/microservices/trainer/trainer_pb2.py
+++ b/microservices/trainer/trainer_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
 
 
 
@@ -43,5 +43,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _TRAININGEVENT._serialized_start=160
   _TRAININGEVENT._serialized_end=191
   _TRAINER._serialized_start=193
-  _TRAINER._serialized_end=255
+  _TRAINER._serialized_end=253
 # @@protoc_insertion_point(module_scope)
diff --git a/microservices/trainer/trainer_pb2_grpc.py b/microservices/trainer/trainer_pb2_grpc.py
index ca55b8c..5277516 100644
--- a/microservices/trainer/trainer_pb2_grpc.py
+++ b/microservices/trainer/trainer_pb2_grpc.py
@@ -1,7 +1,8 @@
 # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
 """Client and server classes corresponding to protobuf-defined services."""
 import grpc
-import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2
+
+from api.protos.trainer import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2
 
 
 class TrainerStub(object):
@@ -13,7 +14,7 @@ class TrainerStub(object):
         Args:
             channel: A grpc.Channel.
         """
-        self.StartTraining = channel.unary_stream(
+        self.StartTraining = channel.unary_unary(
                 '/Trainer/StartTraining',
                 request_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
                 response_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
@@ -32,7 +33,7 @@ class TrainerServicer(object):
 
 def add_TrainerServicer_to_server(servicer, server):
     rpc_method_handlers = {
-            'StartTraining': grpc.unary_stream_rpc_method_handler(
+            'StartTraining': grpc.unary_unary_rpc_method_handler(
                     servicer.StartTraining,
                     request_deserializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.FromString,
                     response_serializer=api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.SerializeToString,
@@ -58,7 +59,7 @@ class Trainer(object):
             wait_for_ready=None,
             timeout=None,
             metadata=None):
-        return grpc.experimental.unary_stream(request, target, '/Trainer/StartTraining',
+        return grpc.experimental.unary_unary(request, target, '/Trainer/StartTraining',
             api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingInput.SerializeToString,
             api_dot_protos_dot_trainer_dot_trainer__pb2.TrainingEvent.FromString,
             options, channel_credentials,
-- 
GitLab