diff --git a/api/internal_services/annotator.py b/api/internal_services/annotator.py
index ec9eaf4e0bbd446cdfd864112cfd901d34cb535f..35e1eea8f48a812e2508cef67ff5890b3c919fc7 100644
--- a/api/internal_services/annotator.py
+++ b/api/internal_services/annotator.py
@@ -1,38 +1,77 @@
-from api.internal_services import neo4j
+import json
+import grpc
+from api.internal_services import neo4j, database
 from api.internal_services.database import update_last_concept_index, get_last_concept_index
 from api.internal_services.gpt import gpt_process
 from api.internal_services.logger import logger
+from api.protos.inferer import inferer_pb2_grpc, inferer_pb2
 
 is_using_GPT = True
 
 
 def annotation_process(job):
-    if is_using_GPT:
-        from api.internal_services.neo4j import driver, create_concept_node, create_concept_relation
-        with (driver.session() as session):
-            gpt_annotation = gpt_process(job.job_data['sentence'].text)
-            for concept, annotations in gpt_annotation.items():
-                filtered_annotation = set()
-                for annotation in annotations:
-                    filtered_annotation = filtered_annotation | neo4j.get_filtered_annotation(job.job_data['sentence_id'], concept, annotation)
-                if filtered_annotation:
-                    for interval in separate_intervals(filtered_annotation):
-                        session.execute_write(
-                            create_concept_node,
-                            concept,
-                            update_last_concept_index(get_last_concept_index() + 1),
-                            "GPT"
-                        )
-
-                        for word_id in interval:
-                            session.execute_write(
-                                create_concept_relation,
-                                get_last_concept_index(),
-                                word_id
-                            )
-    else:
-        gpt_annotation = {}
+    annotator_config = database.get_annotator_config()
+    match annotator_config['provider']:
+        case 'GPT':
+            from api.internal_services.neo4j import driver, create_concept_node, create_concept_relation
+            with driver.session() as session:
+                gpt_annotation = gpt_process(job.job_data['sentence'].text)
+                for concept, annotations in gpt_annotation.items():
+                    filtered_annotation = set()
+                    for annotation in annotations:
+                        filtered_annotation = filtered_annotation | neo4j.get_filtered_annotation(
+                            job.job_data['sentence_id'], concept, annotation)
+                    if filtered_annotation:
+                        for interval in separate_intervals(filtered_annotation):
+                            session.execute_write(
+                                create_concept_node,
+                                concept,
+                                update_last_concept_index(get_last_concept_index() + 1),
+                                False
+                            )
+
+                            for word_id in interval:
+                                session.execute_write(
+                                    create_concept_relation,
+                                    get_last_concept_index(),
+                                    word_id
+                                )
+
+        case 'BERT':
+            response: inferer_pb2.InferenceResult = None
+            with grpc.insecure_channel(annotator_config['server_url']) as channel:
+                stub = inferer_pb2_grpc.InfererStub(channel)
+                request = inferer_pb2.InferenceInput(
+                    inference_data=job.job_data['sentence'].text,
+                    model_id=annotator_config['model_id'],
+                )
+                response = stub.StartInference(request)
+                logger.debug(f"Incoming gRPC message : {response.status}")
+                logger.debug("end of gRPC connection")
+
+            # inference_result is a JSON string, so parse it with json.loads
+            response = json.loads(response.inference_result)
+            from api.internal_services.neo4j import driver, get_id_multi_tokens, get_id_single_tokens, create_concept_node, create_concept_relation
+            with driver.session() as session:
+                for span in response:
+                    if " " in span['text']:
+                        word_ids = get_id_multi_tokens(span['text'], job.job_data['sentence_id'])
+                    else:
+                        word_ids = get_id_single_tokens(span['text'], job.job_data['sentence_id'])
+
+                    session.execute_write(
+                        create_concept_node,
+                        span['tag'],
+                        update_last_concept_index(get_last_concept_index() + 1),
+                        False
+                    )
+
+                    for word_id in word_ids:
+                        session.execute_write(
+                            create_concept_relation,
+                            get_last_concept_index(),
+                            word_id
+                        )
+
+
 def separate_intervals(data):
     sorted_data = sorted(list(data), key=lambda x: int(x.split('.')[-1]))
@@ -52,4 +91,4 @@ def separate_intervals(data):
     separated_intervals.append(current_interval)
 
-    return separated_intervals
\ No newline at end of file
+    return separated_intervals
diff --git a/api/internal_services/database.py b/api/internal_services/database.py
index 7332a26fe50c7bcc2b55f6cf2d4004174b7124c9..b8a287b60e1aa50704973e7aea995a026eb94ba9 100644
--- a/api/internal_services/database.py
+++ b/api/internal_services/database.py
@@ -26,4 +26,17 @@ def get_last_concept_index():
 
 def update_last_concept_index(value):
     db.update({'value': value}, where('key') == 'last_concept_index')
+    return value
+
+
+def get_annotator_config():
+    result = db.search(where('key') == 'annotator_config')
+    if not result:
+        created_object = {'key': 'annotator_config', 'value': {'provider': 'GPT'}}
+        db.insert(created_object)
+        return created_object['value']
+    else:
+        return result[0]['value']
+
+
+def update_annotator_config(value):
+    db.update({'value': value}, where('key') == 'annotator_config')
     return value
\ No newline at end of file
diff --git a/api/internal_services/neo4j.py b/api/internal_services/neo4j.py
index df4179524a86de63851a8611916b31502d4174b3..f2823b356b3a974bf30b146a7995b8abff132e99 100644
--- a/api/internal_services/neo4j.py
+++ b/api/internal_services/neo4j.py
@@ -36,16 +36,16 @@ def create_constituent_node(tx, id, type):
                id=id, type=type)
 
 
-def create_concept_node(tx, concept, id, origin):
+def create_concept_node(tx, concept, id, used_for_training):
     tx.run('''
         CREATE (
             n:Concept {
                 type: $concept,
                 id: $id,
-                origin: $origin
+                used_for_training: $used_for_training
             }
         )''',
-           concept=concept, id=id, origin=origin)
+           concept=concept, id=id, used_for_training=used_for_training)
 
 
 def create_sentence_node(tx, id):
@@ -105,6 +105,7 @@ def create_constituent_relation(tx, idFrom, idTo):
         idFrom=idFrom, idTo=idTo
     )
 
+
 def create_sub_relation(tx, idFrom, idTo):
     tx.run('''
         MATCH
@@ -340,15 +341,17 @@ def reference(words_ids):
     )
     return set([record[0] for record in nodes.records])
 
+
 def get_sentences_for_train():
     nodes = driver.execute_query(
         '''
         match (s:Sentence)--(:Word)--(c:Concept)
-        where c.origin <> "BERT"
+        where c.used_for_training = False
        return s.id as sentence, collect( distinct c.id) as concepts
         '''
     )
-    return [(record[0],record[1]) for record in nodes.records]
+    return [(record[0], record[1]) for record in nodes.records]
+
 
 def get_full_sentence_by_id(sentence_id):
     nodes = driver.execute_query(
@@ -373,6 +376,7 @@ def get_full_sentence_by_id(sentence_id):
     )
     return [record[0] for record in nodes.records]
 
+
 def get_concept_text_by_id(concept_id):
     nodes = driver.execute_query('''
         match (c:Concept{id:$concept_id})--(w:Word)
@@ -397,15 +401,25 @@ def get_concept_text_by_id(concept_id):
         RETURN concatenated_text
         LIMIT 1
         ''',
-        concept_id=concept_id
-    )
+                                 concept_id=concept_id
+                                 )
     return [record[0] for record in nodes.records]
 
+
 def get_concept_type_by_id(concept_id):
     nodes = driver.execute_query('''
         match (c:Concept{id:$concept_id})
         return toLower(c.type)
         ''',
-        concept_id=concept_id
-    )
-    return [record[0] for record in nodes.records]
\ No newline at end of file
+                                 concept_id=concept_id
+                                 )
+    return [record[0] for record in nodes.records]
+
+
+def set_concept_after_training(concept_ids):
+    driver.execute_query('''
+        MATCH (c:Concept)
+        WHERE c.id in $concept_ids
+        SET c.used_for_training = True
+        RETURN c
+        ''', concept_ids=concept_ids)
diff --git a/api/internal_services/trainer.py b/api/internal_services/trainer.py
index 9b4ab7ec4310e27313b260ea0f3984ef7843c5e2..a833cbc8f26836c0030bbad5fd80d7e78d8650ed 100644
--- a/api/internal_services/trainer.py
+++ b/api/internal_services/trainer.py
@@ -16,15 +16,17 @@ def find_all_occurrences(text, phrase):
 
 def start_training(job):
     training_data = []
+    list_all_concept_ids = []
     training_sentences = neo4j.get_sentences_for_train()
 
     for tuple in training_sentences:
         sentence_id = tuple[0]
-        concepts_id = tuple[1]
+        concept_ids = tuple[1]
         sentence = neo4j.get_full_sentence_by_id(sentence_id)[0].strip()
 
         tags = []
-        for concept_id in concepts_id:
+        for concept_id in concept_ids:
+            list_all_concept_ids.append(concept_id)
             span = neo4j.get_concept_text_by_id(concept_id)[0].strip()
             type = neo4j.get_concept_type_by_id(concept_id)[0]
             all_occurrences = list(find_all_occurrences(sentence, span))
@@ -55,3 +57,4 @@ def start_training(job):
         logger.debug(f"fin de la connexion gRPC")
 
-    # passer toutes les origines des concepts en BERT
+    # mark every concept used in this run as consumed for training
+    neo4j.set_concept_after_training(list_all_concept_ids)
diff --git a/api/protos/inferer/inferer.proto b/api/protos/inferer/inferer.proto
new file mode 100644
index 0000000000000000000000000000000000000000..91bde6b8a941daf2db79f178d991f2a9c45686ab
--- /dev/null
+++ b/api/protos/inferer/inferer.proto
@@ -0,0 +1,16 @@
+syntax = "proto3";
+
+service Inferer {
+  rpc StartInference(InferenceInput) returns (InferenceResult){}
+}
+
+message InferenceInput {
+  string inference_data = 1;
+  string model_id = 2;
+}
+
+message InferenceResult {
+  int32 exit_code = 1;
+  string status = 2;
+  string inference_result = 3;
+}
\ No newline at end of file
diff --git a/api/protos/inferer/inferer_pb2.py b/api/protos/inferer/inferer_pb2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1b457b010c357a192e6b1ca6e67e8d68ff70abe
--- /dev/null
+++ b/api/protos/inferer/inferer_pb2.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: api/protos/inferer/inferer.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/inferer/inferer.proto\":\n\x0eInferenceInput\x12\x16\n\x0einference_data\x18\x01 \x01(\t\x12\x10\n\x08model_id\x18\x02 \x01(\t\"N\n\x0fInferenceResult\x12\x11\n\texit_code\x18\x01 \x01(\x05\x12\x0e\n\x06status\x18\x02 \x01(\t\x12\x18\n\x10inference_result\x18\x03 \x01(\t2@\n\x07Inferer\x12\x35\n\x0eStartInference\x12\x0f.InferenceInput\x1a\x10.InferenceResult\"\x00\x62\x06proto3')
+
+
+
+_INFERENCEINPUT = DESCRIPTOR.message_types_by_name['InferenceInput']
+_INFERENCERESULT = DESCRIPTOR.message_types_by_name['InferenceResult']
+InferenceInput = _reflection.GeneratedProtocolMessageType('InferenceInput', (_message.Message,), {
+  'DESCRIPTOR' : _INFERENCEINPUT,
+  '__module__' : 'api.protos.inferer.inferer_pb2'
+  # @@protoc_insertion_point(class_scope:InferenceInput)
+  })
+_sym_db.RegisterMessage(InferenceInput)
+
+InferenceResult = _reflection.GeneratedProtocolMessageType('InferenceResult', (_message.Message,), {
+  'DESCRIPTOR' : _INFERENCERESULT,
+  '__module__' : 'api.protos.inferer.inferer_pb2'
+  # @@protoc_insertion_point(class_scope:InferenceResult)
+  })
+_sym_db.RegisterMessage(InferenceResult)
+
+_INFERER = DESCRIPTOR.services_by_name['Inferer']
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+  DESCRIPTOR._options = None
+  _INFERENCEINPUT._serialized_start=36
+  _INFERENCEINPUT._serialized_end=94
+  _INFERENCERESULT._serialized_start=96
+  _INFERENCERESULT._serialized_end=174
+  _INFERER._serialized_start=176
+  _INFERER._serialized_end=240
+# @@protoc_insertion_point(module_scope)
diff --git a/api/protos/inferer/inferer_pb2_grpc.py b/api/protos/inferer/inferer_pb2_grpc.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bcfa83c37a5d9fefe766e3d9bb9ec55b83c1b5e
--- /dev/null
+++ b/api/protos/inferer/inferer_pb2_grpc.py
@@ -0,0 +1,66 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+from api.protos.inferer import inferer_pb2 as api_dot_protos_dot_inferer_dot_inferer__pb2
+
+
+class InfererStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.StartInference = channel.unary_unary(
+                '/Inferer/StartInference',
+                request_serializer=api_dot_protos_dot_inferer_dot_inferer__pb2.InferenceInput.SerializeToString,
+                response_deserializer=api_dot_protos_dot_inferer_dot_inferer__pb2.InferenceResult.FromString,
+                )
+
+
+class InfererServicer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def StartInference(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+
+def add_InfererServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+            'StartInference': grpc.unary_unary_rpc_method_handler(
+                    servicer.StartInference,
+                    request_deserializer=api_dot_protos_dot_inferer_dot_inferer__pb2.InferenceInput.FromString,
+                    response_serializer=api_dot_protos_dot_inferer_dot_inferer__pb2.InferenceResult.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'Inferer', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class Inferer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def StartInference(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/Inferer/StartInference',
+            api_dot_protos_dot_inferer_dot_inferer__pb2.InferenceInput.SerializeToString,
+            api_dot_protos_dot_inferer_dot_inferer__pb2.InferenceResult.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/api/protos/trainer/trainer.proto b/api/protos/trainer/trainer.proto
index 46664b1a9dc26b8c93530e9b3796049a4eb84eed..7cf49ae443389ae7562154fb6786a18503569268 100644
--- a/api/protos/trainer/trainer.proto
+++ b/api/protos/trainer/trainer.proto
@@ -12,5 +12,6 @@ message TrainingInput {
 }
 
 message TrainingEvent {
-  string status = 1;
+  int32 exit_code = 1;
+  string status = 2;
 }
\ No newline at end of file
diff --git a/api/protos/trainer/trainer_pb2.py b/api/protos/trainer/trainer_pb2.py
index fcad9f190f8e8f113c6dd4d65e44fc3b0613dad6..23a43fdf3b28665d895ef52a4acfbf3530bc6a2e 100644
--- a/api/protos/trainer/trainer_pb2.py
+++ b/api/protos/trainer/trainer_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"2\n\rTrainingEvent\x12\x11\n\texit_code\x18\x01 \x01(\x05\x12\x0e\n\x06status\x18\x02 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
 
 
 
@@ -41,7 +41,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _TRAININGINPUT._serialized_start=36
   _TRAININGINPUT._serialized_end=158
   _TRAININGEVENT._serialized_start=160
-  _TRAININGEVENT._serialized_end=191
-  _TRAINER._serialized_start=193
-  _TRAINER._serialized_end=253
+  _TRAININGEVENT._serialized_end=210
+  _TRAINER._serialized_start=212
+  _TRAINER._serialized_end=272
 # @@protoc_insertion_point(module_scope)
diff --git a/microservices/inferer/Dockerfile b/microservices/inferer/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..6569d46fbaddcd2f73ba9a1062e35332a0230172
--- /dev/null
+++ b/microservices/inferer/Dockerfile
@@ -0,0 +1,6 @@
+FROM python:3.10
+WORKDIR /microservices/inferer
+COPY ./ /microservices/inferer
+RUN pip install setuptools
+RUN cd /microservices/inferer && pip install .
+CMD ["python", "/microservices/inferer/inferer.py"]
\ No newline at end of file
diff --git a/microservices/inferer/inferer.proto b/microservices/inferer/inferer.proto
new file mode 100644
index 0000000000000000000000000000000000000000..91bde6b8a941daf2db79f178d991f2a9c45686ab
--- /dev/null
+++ b/microservices/inferer/inferer.proto
@@ -0,0 +1,16 @@
+syntax = "proto3";
+
+service Inferer {
+  rpc StartInference(InferenceInput) returns (InferenceResult){}
+}
+
+message InferenceInput {
+  string inference_data = 1;
+  string model_id = 2;
+}
+
+message InferenceResult {
+  int32 exit_code = 1;
+  string status = 2;
+  string inference_result = 3;
+}
\ No newline at end of file
diff --git a/microservices/inferer/inferer.py b/microservices/inferer/inferer.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..25c7d38b2af4707efff0c3e1c3cbc863cc893273 100644
--- a/microservices/inferer/inferer.py
+++ b/microservices/inferer/inferer.py
@@ -0,0 +1,165 @@
+import json
+from concurrent import futures
+import grpc
+import torch
+from transformers import RobertaPreTrainedModel, RobertaModel, AutoTokenizer
+from transformers.modeling_outputs import TokenClassifierOutput
+from transformers.models.roberta.modeling_roberta import (
+    ROBERTA_INPUTS_DOCSTRING,
+    ROBERTA_START_DOCSTRING,
+    RobertaEmbeddings,
+)
+from transformers.utils import add_start_docstrings_to_model_forward
+from torch import nn
+from typing import Optional, Union, Tuple
+import inferer_pb2_grpc
+import inferer_pb2
+
+is_busy = False
+
+class InfererServicer(inferer_pb2_grpc.InfererServicer):
+    def StartInference(self, request, context):
+        print("event received")
+        global is_busy
+
+        if not is_busy:
+            is_busy = True
+            print(f"incoming request : {request}")
+            try:
+                result = inference_process(request.inference_data, request.model_id)
+                torch.cuda.empty_cache()
+                is_busy = False
+                return inferer_pb2.InferenceResult(
+                    exit_code=0,
+                    status="Inference ended successfully !",
+                    inference_result=json.dumps(result)
+                )
+            except Exception as e:
+                print(f"Error : {e}")
+                is_busy = False
+        else:
+            print("gRPC server is already busy")
+
+        return inferer_pb2.InferenceResult(
+            exit_code=1,
+            status="Inference failed !",
+            inference_result=""
+        )
+
+
+def serve():
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
+    # register the local servicer; the generated base class only raises UNIMPLEMENTED
+    inferer_pb2_grpc.add_InfererServicer_to_server(InfererServicer(), server)
+    server.add_insecure_port('[::]:80')
+    server.start()
+    server.wait_for_termination()
+
+def inference_process(inference_data, model_id):
+
+    class RobertaForSpanCategorization(RobertaPreTrainedModel):
+        _keys_to_ignore_on_load_unexpected = [r"pooler"]
+        _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+        def __init__(self, config):
+            super().__init__(config)
+            self.num_labels = config.num_labels
+            self.roberta = RobertaModel(config, add_pooling_layer=False)
+            classifier_dropout = (
+                config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+            )
+            self.dropout = nn.Dropout(classifier_dropout)
+            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+            self.post_init()
+
+        @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+        def forward(
+            self,
+            input_ids: Optional[torch.LongTensor] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+        ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+            outputs = self.roberta(
+                input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                position_ids=position_ids,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            sequence_output = outputs[0]
+            sequence_output = self.dropout(sequence_output)
+            logits = self.classifier(sequence_output)
+
+            loss = None
+            if labels is not None:
+                loss_fct = nn.BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels.float())
+            if not return_dict:
+                output = (logits,) + outputs[2:]
+                return ((loss,) + output) if loss is not None else output
+            return TokenClassifierOutput(
+                loss=loss,
+                logits=logits,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+
+
+    model = RobertaForSpanCategorization.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    # the label maps are read from the fine-tuned model's config
+    id2label = model.config.id2label
+    label2id = model.config.label2id
+    tag2id = label2id
+
+    def get_offsets_and_predicted_tags(example: str, model, tokenizer, threshold=0):
+        raw_encoded_example = tokenizer(example, return_offsets_mapping=True)
+        encoded_example = tokenizer(example, return_tensors="pt")
+        out = model(**encoded_example)["logits"][0]
+        predicted_tags = [[i for i, l in enumerate(logit) if l > threshold] for logit in out]
+
+        return [{"token": token, "tags": tag, "offset": offset} for (token, tag, offset)
+                in zip(tokenizer.batch_decode(raw_encoded_example["input_ids"]),
+                       predicted_tags,
+                       raw_encoded_example["offset_mapping"])]
+
+    def get_tagged_groups(sentence: str):
+        offsets_and_tags = get_offsets_and_predicted_tags(sentence, model, tokenizer)
+        predicted_offsets = {l: [] for l in tag2id}
+        last_token_tags = []
+        for item in offsets_and_tags:
+            (start, end), tags = item["offset"], item["tags"]
+
+            for label_id in tags:
+                tag = id2label[label_id]
+                if label_id not in last_token_tags and label2id[f"{tag}"] not in last_token_tags:
+                    predicted_offsets[tag].append({"start": start, "end": end})
+                else:
+                    predicted_offsets[tag][-1]["end"] = end
+
+            last_token_tags = tags
+
+        flatten_predicted_offsets = [{**v, "tag": k, "text": sentence[v["start"]:v["end"]]}
+                                     for k, v_list in predicted_offsets.items() for v in v_list if v["end"] - v["start"] >= 3]
+        flatten_predicted_offsets = sorted(flatten_predicted_offsets,
+                                           key=lambda row: (row["start"], row["end"], row["tag"]))
+        return flatten_predicted_offsets
+
+    return get_tagged_groups(inference_data)
+
+
+if __name__ == '__main__':
+    serve()
\ No newline at end of file
diff --git a/microservices/inferer/inferer_pb2.py b/microservices/inferer/inferer_pb2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1b457b010c357a192e6b1ca6e67e8d68ff70abe
--- /dev/null
+++ b/microservices/inferer/inferer_pb2.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: api/protos/inferer/inferer.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/inferer/inferer.proto\":\n\x0eInferenceInput\x12\x16\n\x0einference_data\x18\x01 \x01(\t\x12\x10\n\x08model_id\x18\x02 \x01(\t\"N\n\x0fInferenceResult\x12\x11\n\texit_code\x18\x01 \x01(\x05\x12\x0e\n\x06status\x18\x02 \x01(\t\x12\x18\n\x10inference_result\x18\x03 \x01(\t2@\n\x07Inferer\x12\x35\n\x0eStartInference\x12\x0f.InferenceInput\x1a\x10.InferenceResult\"\x00\x62\x06proto3')
+
+
+
+_INFERENCEINPUT = DESCRIPTOR.message_types_by_name['InferenceInput']
+_INFERENCERESULT = DESCRIPTOR.message_types_by_name['InferenceResult']
+InferenceInput = _reflection.GeneratedProtocolMessageType('InferenceInput', (_message.Message,), {
+  'DESCRIPTOR' : _INFERENCEINPUT,
+  '__module__' : 'api.protos.inferer.inferer_pb2'
+  # @@protoc_insertion_point(class_scope:InferenceInput)
+  })
+_sym_db.RegisterMessage(InferenceInput)
+
+InferenceResult = _reflection.GeneratedProtocolMessageType('InferenceResult', (_message.Message,), {
+  'DESCRIPTOR' : _INFERENCERESULT,
+  '__module__' : 'api.protos.inferer.inferer_pb2'
+  # @@protoc_insertion_point(class_scope:InferenceResult)
+  })
+_sym_db.RegisterMessage(InferenceResult)
+
+_INFERER = DESCRIPTOR.services_by_name['Inferer']
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+  DESCRIPTOR._options = None
+  _INFERENCEINPUT._serialized_start=36
+  _INFERENCEINPUT._serialized_end=94
+  _INFERENCERESULT._serialized_start=96
+  _INFERENCERESULT._serialized_end=174
+  _INFERER._serialized_start=176
+  _INFERER._serialized_end=240
+# @@protoc_insertion_point(module_scope)
diff --git a/microservices/inferer/setup.py b/microservices/inferer/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..35f4ac0d89a5da7c3cf0db30af65e0995f3d2809
--- /dev/null
+++ b/microservices/inferer/setup.py
@@ -0,0 +1,17 @@
+from setuptools import setup, find_packages
+
+setup(
+    name='ALA-Plateform-inferer',
+    version='0.1',
+    packages=find_packages(),
+    install_requires=[
+        'grpcio==1.48.2',
+        'grpcio-tools==1.48.2',
+        'torch==2.4.0',
+        'datasets==2.21.0',
+        'transformers[torch]==4.44.1',
+        'numpy==2.1.0',
+        'scikit-learn==1.5.1',
+        'matplotlib==3.9.2'
+    ],
+)
diff --git a/microservices/trainer/trainer.proto b/microservices/trainer/trainer.proto
index 46664b1a9dc26b8c93530e9b3796049a4eb84eed..7cf49ae443389ae7562154fb6786a18503569268 100644
--- a/microservices/trainer/trainer.proto
+++ b/microservices/trainer/trainer.proto
@@ -12,5 +12,6 @@ message TrainingInput {
 }
 
 message TrainingEvent {
-  string status = 1;
+  int32 exit_code = 1;
+  string status = 2;
 }
\ No newline at end of file
diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py
index 1b4c0510f450c403fdcadc309513fde0d34f72ba..a9cb85f98946986687a4c49ab375e63c95873c78 100644
--- a/microservices/trainer/trainer.py
+++ b/microservices/trainer/trainer.py
@@ -39,11 +39,24 @@ class TrainerServicer(trainer_pb2_grpc.TrainerServicer):
                 )
             except Exception as e:
                 print(f"Error : {e}")
+                is_busy = False
+                return trainer_pb2.TrainingEvent(
+                    exit_code=1,
+                    status="Error during the training process !"
+                )
             is_busy = False
             torch.cuda.empty_cache()
         else:
             print(f"gRPC server is already busy")
+            # report busy instead of falling through to the success event
+            return trainer_pb2.TrainingEvent(
+                exit_code=1,
+                status="Training server is already busy !"
+            )
 
-        return trainer_pb2.TrainingEvent(status="Training ended successfully !")
+        return trainer_pb2.TrainingEvent(
+            exit_code=0,
+            status="Training ended successfully !"
+        )
diff --git a/microservices/trainer/trainer_pb2.py b/microservices/trainer/trainer_pb2.py
index fcad9f190f8e8f113c6dd4d65e44fc3b0613dad6..23a43fdf3b28665d895ef52a4acfbf3530bc6a2e 100644
--- a/microservices/trainer/trainer_pb2.py
+++ b/microservices/trainer/trainer_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"2\n\rTrainingEvent\x12\x11\n\texit_code\x18\x01 \x01(\x05\x12\x0e\n\x06status\x18\x02 \x01(\t2<\n\x07Trainer\x12\x31\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x62\x06proto3')
@@ -41,7 +41,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _TRAININGINPUT._serialized_start=36
   _TRAININGINPUT._serialized_end=158
   _TRAININGEVENT._serialized_start=160
-  _TRAININGEVENT._serialized_end=191
-  _TRAINER._serialized_start=193
-  _TRAINER._serialized_end=253
+  _TRAININGEVENT._serialized_end=210
+  _TRAINER._serialized_start=212
+  _TRAINER._serialized_end=272
 # @@protoc_insertion_point(module_scope)
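
Reviewer note: the inferer_pb2.py / inferer_pb2_grpc.py stubs are checked in rather than generated at build time. If inferer.proto changes, they can be regenerated with the grpcio-tools version pinned in setup.py; a sketch, assuming the repository root as working directory (adjust -I and output paths for the microservices copy):

    python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. api/protos/inferer/inferer.proto

For the new BERT branch in annotation_process, the annotator_config record stored in TinyDB is expected to carry the gRPC endpoint and model id alongside the provider, matching the keys the code reads (server_url, model_id); the values below are hypothetical examples, with port 80 taken from the inferer's add_insecure_port call:

    {'provider': 'BERT', 'server_url': 'inferer-host:80', 'model_id': 'some-org/finetuned-roberta'}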