diff --git a/api/internal_services/neo4j.py b/api/internal_services/neo4j.py
index f7b4c86b461b0690050473af9c5f25428b8329ea..df4179524a86de63851a8611916b31502d4174b3 100644
--- a/api/internal_services/neo4j.py
+++ b/api/internal_services/neo4j.py
@@ -3,8 +3,7 @@ import os
 from neo4j import GraphDatabase
 
 from api.internal_services.logger import logger
 
-
-uri = os.getenv('NEO4J_URL', "bolt://ala-neo4j:7687")
+uri = os.getenv('NEO4J_URL', "bolt://localhost:7687")
 username = os.getenv('NEO4J_USERNAME', "neo4j")
 password = os.getenv('NEO4J_PASSWORD', "password")
@@ -49,6 +48,16 @@ def create_concept_node(tx, concept, id, origin):
         concept=concept, id=id, origin=origin)
 
 
+def create_sentence_node(tx, id):
+    tx.run('''
+        CREATE (
+            n:Sentence {
+                id: $id
+            }
+        )''',
+        id=id)
+
+
 def create_next_word_relation(tx, idFrom, idTo):
     tx.run('''
         MATCH
@@ -96,16 +105,15 @@ def create_constituent_relation(tx, idFrom, idTo):
         idFrom=idFrom, idTo=idTo
     )
 
-
-def create_relation(tx, idFrom, idTo, id, type):
+def create_sub_relation(tx, idFrom, idTo):
     tx.run('''
         MATCH
-            (a:Concept),
-            (b:Concept)
-        WHERE a.id = $idFrom AND b.id = $idTo
-        CREATE (a)-[r:RELATION {id: $id, type: $type}]->(b)
+            (w:Word),
+            (s:Sentence)
+        WHERE w.id = $idFrom AND s.id = $idTo
+        CREATE (w)-[r:SUB]->(s)
         ''',
-        idFrom=idFrom, idTo=idTo, id=id, type=type
+        idFrom=idFrom, idTo=idTo
     )
@@ -120,7 +128,8 @@ def get_filtered_annotation(sentence_id, concept, annotation):
         words_ids = get_id_single_tokens(annotation, sentence_id)
 
     if len(words_ids) == 0:
-        logger.warn(f"Cannot find the following annotation '{annotation}' in the sentence id {sentence_id}. This error is a hallucination of large language models.")
+        logger.warn(
+            f"Cannot find the following annotation '{annotation}' in the sentence id {sentence_id}. This error is a hallucination of large language models.")
         return set()
 
     filtered_annotation = set()
@@ -150,8 +159,9 @@ def get_id_multi_tokens(annotation, sentence_id):
         '''
         WITH $array AS words
         MATCH path = (start:Word)-[:NEXT*]->(end:Word)
+        MATCH (start)--(s:Sentence)
         where size(words) - 1 = size(relationships(path))
-        and start.id starts with $sentence_id
+        and s.id = $sentence_id
         and all(
             idx IN range(0, size(words)-2)
             WHERE (toLower(words[idx]) = toLower((nodes(path)[idx]).text)
@@ -164,7 +174,7 @@ def get_id_multi_tokens(annotation, sentence_id):
         return collect(results.id) as liste
         ''',
         array=annotation,
-        sentence_id=f"{sentence_id}."
+        sentence_id=sentence_id
     )
     return list_multi_token.records[0][0]
@@ -172,14 +182,14 @@ def get_id_single_tokens(annotation, sentence_id):
     list_single_token = driver.execute_query(
         '''
-        match (w:Word)
+        match (s:Sentence)--(w:Word)
         where toLower(w.text) = $annotation
-        and w.id starts with $sentence_id
+        and s.id = $sentence_id
         with distinct w as results
         return collect(results.id) as liste
         ''',
         annotation=annotation,
-        sentence_id=f"{sentence_id}."
+        sentence_id=sentence_id
     )
     return list_single_token.records[0][0]
@@ -329,3 +339,73 @@ def reference(words_ids):
         array=list(words_ids)
     )
     return set([record[0] for record in nodes.records])
+
+def get_sentences_for_train():
+    nodes = driver.execute_query(
+        '''
+        match (s:Sentence)--(:Word)--(c:Concept)
+        where c.origin <> "BERT"
+        return s.id as sentence, collect(distinct c.id) as concepts
+        '''
+    )
+    return [(record[0], record[1]) for record in nodes.records]
+
+def get_full_sentence_by_id(sentence_id):
+    nodes = driver.execute_query(
+        '''
+        MATCH (s:Sentence{id:$sentence_id})--(w1:Word)
+        MATCH (s:Sentence{id:$sentence_id})--(w2:Word)
+        WHERE NOT EXISTS {
+            MATCH (:Word)-[:NEXT]->(w1)
+        }
+        and not exists {
+            match (w2)-[:NEXT]->(:Word)
+        }
+        WITH w1 AS startNode, w2 as endNode
+        MATCH path = (startNode)-[:NEXT*]->(endNode)
+        WITH nodes(path) AS nodelist
+        WITH [node IN nodelist | node.text] AS texts
+        WITH reduce(acc = "", t IN texts | acc + " " + t) AS concatenated_text
+        RETURN concatenated_text
+        LIMIT 1
+        ''',
+        sentence_id=sentence_id
+    )
+    return [record[0] for record in nodes.records]
+
+def get_concept_text_by_id(concept_id):
+    nodes = driver.execute_query('''
+        match (c:Concept{id:$concept_id})--(w:Word)
+        with collect(w) as word_list
+        match (startNode:Word)
+        match (endNode:Word)
+        where startNode in word_list
+        and endNode in word_list
+        and not exists {
+            match (w1:Word)-[:NEXT]->(startNode)
+            where w1 in word_list
+        }
+        and not exists {
+            match (endNode)-[:NEXT]->(w2:Word)
+            where w2 in word_list
+        }
+        with startNode, endNode
+        MATCH path = (startNode)-[:NEXT*]->(endNode)
+        WITH nodes(path) AS nodelist
+        WITH [node IN nodelist | node.text] AS texts
+        WITH reduce(acc = "", t IN texts | acc + " " + t) AS concatenated_text
+        RETURN concatenated_text
+        LIMIT 1
+        ''',
+        concept_id=concept_id
+    )
+    return [record[0] for record in nodes.records]
+
+def get_concept_type_by_id(concept_id):
+    nodes = driver.execute_query('''
+        match (c:Concept{id:$concept_id})
+        return toLower(c.type)
+        ''',
+        concept_id=concept_id
+    )
+    return [record[0] for record in nodes.records]
\ No newline at end of file
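Note: with this change a Word is tied to its Sentence through an explicit SUB relation instead of a shared id prefix, which is what the rewritten queries above rely on. A minimal standalone sketch of that access pattern (hypothetical helper, not part of the module; credentials are the defaults from the top of the file):

    from neo4j import GraphDatabase

    # Hypothetical check of the new graph shape; assumes a running Neo4j instance.
    driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

    def get_word_ids_for_sentence(sentence_id):
        # SUB points from the Word to its Sentence, as in create_sub_relation().
        result = driver.execute_query(
            '''
            MATCH (w:Word)-[:SUB]->(s:Sentence {id: $sentence_id})
            RETURN collect(w.id) AS ids
            ''',
            sentence_id=sentence_id,
        )
        return result.records[0][0]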
diff --git a/api/internal_services/spacy.py b/api/internal_services/spacy.py
index 43f9b649be3cf9e5f32cf2d78699ea0f7350105d..c308c9f3286eff50dda29f45e5dcdc0d78e7b41b 100644
--- a/api/internal_services/spacy.py
+++ b/api/internal_services/spacy.py
@@ -1,3 +1,5 @@
+import uuid
+
 import benepar, spacy
 import warnings
 
@@ -8,7 +10,7 @@ warnings.filterwarnings("ignore")
 
 from api.internal_services.database import get_last_sentence_index, update_last_sentence_index
 from api.internal_services.neo4j import create_constituent_node, driver, create_constituent_relation, create_word_node, \
-    create_next_word_relation, create_deprel_relation
+    create_next_word_relation, create_deprel_relation, create_sentence_node, create_sub_relation
 
 benepar.download('benepar_fr2')
 nlp = spacy.load('fr_dep_news_trf')
@@ -17,14 +19,21 @@ nlp.add_pipe('benepar', config={'model': 'benepar_fr2'})
 
 
 def parsing_and_load_in_neo4j(job):
     sentence = job.job_data['sentence']
-    last_index = get_last_sentence_index()
-    last_index = update_last_sentence_index(last_index + 1)
+    last_index = update_last_sentence_index(get_last_sentence_index() + 1)
 
     with (driver.session() as session):
         doc = nlp(sentence)
         for i, sentence in enumerate(doc.sents):
+            sentence_uuid = str(uuid.uuid4())
+
+            # Create the Sentence node
+            session.execute_write(
+                create_sentence_node,
+                sentence_uuid
+            )
+
             for constituent in sentence._.constituents:
                 constituent_id = f"{last_index}.{i}.{constituent.start}-{constituent.end}"
 
                 if constituent._.labels and constituent.root.text != constituent.text:
@@ -44,6 +53,16 @@ def parsing_and_load_in_neo4j(job):
                     )
 
                 else:
+                    # Create the word as a Neo4j node
+                    session.execute_write(
+                        create_word_node,
+                        '.'.join(map(str, [last_index, i, constituent.root.i])),
+                        constituent.text,
+                        None if not hasattr(constituent, 'lemma_') else constituent.lemma_,
+                        constituent.root.pos_,
+                        True if constituent.root.dep_ == "root" else False
+                    )
+
                     # Create the word and the lone constituent if needed
                     if constituent._.labels:
                         # Create the constituent
@@ -53,16 +72,6 @@ def parsing_and_load_in_neo4j(job):
                             constituent._.labels[0]
                         )
 
-                        # Create the word as a Neo4j node
-                        session.execute_write(
-                            create_word_node,
-                            '.'.join(map(str, [last_index, i, constituent.root.i])),
-                            constituent.text,
-                            None if not hasattr(constituent, 'lemma_') else constituent.lemma_,
-                            constituent.root.pos_,
-                            True if constituent.root.dep_ == "root" else False
-                        )
-
                         session.execute_write(
                             create_constituent_relation,
                             f"{last_index}.{i}.{constituent.start}-{constituent.end}",
@@ -76,16 +85,6 @@ def parsing_and_load_in_neo4j(job):
                         )
                     else:
-                        # Create the word as a Neo4j node
-                        session.execute_write(
-                            create_word_node,
-                            '.'.join(map(str, [last_index, i, constituent.root.i])),
-                            constituent.text,
-                            None if not hasattr(constituent, 'lemma_') else constituent.lemma_,
-                            constituent.root.pos_,
-                            True if constituent.root.dep_ == "root" else False
-                        )
-
                         # The parent exists, so create the link
                         session.execute_write(
                             create_constituent_relation,
                             f"{last_index}.{i}.{constituent.start}-{constituent.end}",
@@ -93,6 +92,12 @@ def parsing_and_load_in_neo4j(job):
                             '.'.join(map(str, [last_index, i, constituent.root.i])),
                         )
 
+                session.execute_write(
+                    create_sub_relation,
+                    '.'.join(map(str, [last_index, i, constituent.root.i])),
+                    sentence_uuid,
+                )
+
             for token in sentence:
                 # Create a succession link between consecutive words
                 if token.i != 0:
@@ -108,7 +113,7 @@ def parsing_and_load_in_neo4j(job):
 
             new_job = Job()
             new_job.job_id = job.job_id
            new_job.job_type = JobType.ANNOTATION
-            new_job.job_data = {'sentence': sentence, 'sentence_id': last_index}
+            new_job.job_data = {'sentence': sentence, 'sentence_id': sentence_uuid}
             add_job_to_queue(new_job)
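Note: word ids keep the `last_index.sentence_index.token_index` scheme, but a sentence is now identified by an opaque UUID rather than by a prefix of its word ids, which is why the queries in neo4j.py above stopped using `starts with`. A quick sketch of the two id kinds:

    import uuid

    last_index, sentence_index, token_index = 12, 0, 3
    word_id = '.'.join(map(str, [last_index, sentence_index, token_index]))  # "12.0.3"
    sentence_id = str(uuid.uuid4())  # opaque; tied to its words via the SUB relation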
diff --git a/api/internal_services/trainer.py b/api/internal_services/trainer.py
index 3a65500c0695e130f0e6aa96a7cfd45b857cb5aa..1f870b9fe62525105fd5ee010aeedd268f08da07 100644
--- a/api/internal_services/trainer.py
+++ b/api/internal_services/trainer.py
@@ -1,13 +1,53 @@
+import json
+
 import grpc
+from api.internal_services import neo4j
 from api.internal_services.logger import logger
 from api.protos.trainer import trainer_pb2_grpc, trainer_pb2
 
 
+def find_all_occurrences(text, phrase):
+    start = 0
+    while True:
+        start = text.find(phrase, start)
+        if start == -1: return
+        yield start
+        start += len(phrase)
+
+
 def start_training(job):
+    training_data = []
+    training_sentences = neo4j.get_sentences_for_train()
+
+    for tuple in training_sentences:
+        sentence_id = tuple[0]
+        concepts_id = tuple[1]
+        sentence = neo4j.get_full_sentence_by_id(sentence_id)[0].strip()
+        tags = []
+
+        for concept_id in concepts_id:
+            span = neo4j.get_concept_text_by_id(concept_id)[0].strip()
+            type = neo4j.get_concept_type_by_id(concept_id)[0]
+            all_occurrences = list(find_all_occurrences(sentence, span))
+            for start_index in all_occurrences:
+                end_index = start_index + len(span)
+                tags.append({
+                    "start": start_index if start_index == 0 else start_index - 1,
+                    "end": end_index,
+                    "tag": type,
+                })
+
+        training_data.append({
+            'id': sentence_id,
+            'text': sentence,
+            'tags': tags
+        })
+
+    logger.debug(training_data)
+
     with grpc.insecure_channel(job.job_data['server_url']) as channel:
         stub = trainer_pb2_grpc.TrainerStub(channel)
         request = trainer_pb2.TrainingInput(
-            training_data=[],
+            training_data=json.dumps(training_data),
             fondation_model_id=job.job_data['fondation_model_id'],
             finetuned_repo_name=job.job_data['finetuned_repo_name'],
             huggingface_token=job.job_data['huggingface_token'],
@@ -16,4 +56,6 @@ def start_training(job):
         responses = stub.StartTraining(request)
         for response in responses:
             logger.debug(f"gRPC message : {response.status}")
-        logger.debug("end of the gRPC connection")
\ No newline at end of file
+        logger.debug("end of the gRPC connection")
+
+    # TODO: set all concept origins to BERT
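Note: the JSON string produced by `json.dumps(training_data)` replaces the TrainingData/Tag protobuf messages removed below. A hypothetical example of the structure built above (ids, text, and offsets are made up; the `start_index - 1` adjustment widens every non-initial span one character to the left):

    training_data = [
        {
            "id": "9b2f0c7e-uuid",              # Sentence node id (hypothetical)
            "text": "Le juge rend la décision",
            "tags": [
                # "juge" starts at offset 3, recorded as 2 because of the -1 shift
                {"start": 2, "end": 7, "tag": "actor"},
            ],
        },
    ]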
diff --git a/api/protos/trainer/trainer.proto b/api/protos/trainer/trainer.proto
index 6510cc54d6bf98d980eaa7931d88aac05c3970db..c36c1f016771eedf304ad29a63b24dacdae81545 100644
--- a/api/protos/trainer/trainer.proto
+++ b/api/protos/trainer/trainer.proto
@@ -5,24 +5,12 @@ service Trainer {
 }
 
 message TrainingInput {
-  repeated TrainingData training_data = 1;
+  string training_data = 1;
   string fondation_model_id = 2;
   string finetuned_repo_name = 3;
   string huggingface_token = 4;
 }
 
-message TrainingData {
-  string id = 1;
-  string text = 2;
-  repeated Tag tags = 3;
-}
-
-message Tag {
-  int32 start = 1;
-  int32 end = 2;
-  string tag = 3;
-}
-
 message TrainingEvent {
   string status = 1;
 }
\ No newline at end of file
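Note: because TrainingData and Tag disappear from the schema, the generated trainer_pb2.py modules below should be regenerated rather than edited by hand; one way to do that (output paths assumed) is via grpc_tools:

    from grpc_tools import protoc

    # Regenerate trainer_pb2.py (and trainer_pb2_grpc.py) from the updated schema.
    protoc.main([
        'grpc_tools.protoc',
        '-I.',
        '--python_out=.',
        '--grpc_python_out=.',
        'api/protos/trainer/trainer.proto',
    ])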
diff --git a/api/protos/trainer/trainer_pb2.py b/api/protos/trainer/trainer_pb2.py
index 4e4c61cf7a37e411e1e14d66d59f1f0fab5ae210..5a088b5d1f816f47c73b05ebb5590b447c7f818f 100644
--- a/api/protos/trainer/trainer_pb2.py
+++ b/api/protos/trainer/trainer_pb2.py
@@ -14,13 +14,11 @@ _sym_db = _symbol_database.Default()
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"\x89\x01\n\rTrainingInput\x12$\n\rtraining_data\x18\x01 \x03(\x0b\x32\r.TrainingData\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"<\n\x0cTrainingData\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x12\n\x04tags\x18\x03 \x03(\x0b\x32\x04.Tag\".\n\x03Tag\x12\r\n\x05start\x18\x01 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x02 \x01(\x05\x12\x0b\n\x03tag\x18\x03 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
 
 
 _TRAININGINPUT = DESCRIPTOR.message_types_by_name['TrainingInput']
-_TRAININGDATA = DESCRIPTOR.message_types_by_name['TrainingData']
-_TAG = DESCRIPTOR.message_types_by_name['Tag']
 _TRAININGEVENT = DESCRIPTOR.message_types_by_name['TrainingEvent']
 TrainingInput = _reflection.GeneratedProtocolMessageType('TrainingInput', (_message.Message,), {
   'DESCRIPTOR' : _TRAININGINPUT,
@@ -29,20 +27,6 @@ TrainingInput = _reflection.GeneratedProtocolMessageType('TrainingInput', (_mess
   })
 _sym_db.RegisterMessage(TrainingInput)
 
-TrainingData = _reflection.GeneratedProtocolMessageType('TrainingData', (_message.Message,), {
-  'DESCRIPTOR' : _TRAININGDATA,
-  '__module__' : 'api.protos.trainer.trainer_pb2'
-  # @@protoc_insertion_point(class_scope:TrainingData)
-  })
-_sym_db.RegisterMessage(TrainingData)
-
-Tag = _reflection.GeneratedProtocolMessageType('Tag', (_message.Message,), {
-  'DESCRIPTOR' : _TAG,
-  '__module__' : 'api.protos.trainer.trainer_pb2'
-  # @@protoc_insertion_point(class_scope:Tag)
-  })
-_sym_db.RegisterMessage(Tag)
-
 TrainingEvent = _reflection.GeneratedProtocolMessageType('TrainingEvent', (_message.Message,), {
   'DESCRIPTOR' : _TRAININGEVENT,
   '__module__' : 'api.protos.trainer.trainer_pb2'
@@ -54,14 +38,10 @@ _TRAINER = DESCRIPTOR.services_by_name['Trainer']
 if _descriptor._USE_C_DESCRIPTORS == False:
 
   DESCRIPTOR._options = None
-  _TRAININGINPUT._serialized_start=37
-  _TRAININGINPUT._serialized_end=174
-  _TRAININGDATA._serialized_start=176
-  _TRAININGDATA._serialized_end=236
-  _TAG._serialized_start=238
-  _TAG._serialized_end=284
-  _TRAININGEVENT._serialized_start=286
-  _TRAININGEVENT._serialized_end=317
-  _TRAINER._serialized_start=319
-  _TRAINER._serialized_end=381
+  _TRAININGINPUT._serialized_start=36
+  _TRAININGINPUT._serialized_end=158
+  _TRAININGEVENT._serialized_start=160
+  _TRAININGEVENT._serialized_end=191
+  _TRAINER._serialized_start=193
+  _TRAINER._serialized_end=255
 # @@protoc_insertion_point(module_scope)
diff --git a/microservices/trainer/trainer.proto b/microservices/trainer/trainer.proto
index 6510cc54d6bf98d980eaa7931d88aac05c3970db..c36c1f016771eedf304ad29a63b24dacdae81545 100644
--- a/microservices/trainer/trainer.proto
+++ b/microservices/trainer/trainer.proto
@@ -5,24 +5,12 @@ service Trainer {
 }
 
 message TrainingInput {
-  repeated TrainingData training_data = 1;
+  string training_data = 1;
   string fondation_model_id = 2;
   string finetuned_repo_name = 3;
   string huggingface_token = 4;
 }
 
-message TrainingData {
-  string id = 1;
-  string text = 2;
-  repeated Tag tags = 3;
-}
-
-message Tag {
-  int32 start = 1;
-  int32 end = 2;
-  string tag = 3;
-}
-
 message TrainingEvent {
   string status = 1;
 }
\ No newline at end of file
diff --git a/microservices/trainer/trainer.py b/microservices/trainer/trainer.py
index 5a0c31e987e7517dd258956c5184a06ca49f77e7..ad423312cac9e92ac7083418fad68eed9d828b1b 100644
--- a/microservices/trainer/trainer.py
+++ b/microservices/trainer/trainer.py
@@ -1,3 +1,4 @@
+import json
 from concurrent import futures
 from typing import Optional, Union, Tuple
 import grpc
@@ -55,6 +56,8 @@ def serve():
 
 
 def training_process(training_data, fondation_model_id, finetuned_repo_name, huggingface_token):
+    training_data = json.loads(training_data)
+    print(training_data)
     MAX_LENGTH = 256
     tag2id = {'action': 1, 'actor': 2, 'artifact': 3, 'condition': 4, 'location': 5, 'modality': 6, 'reference': 7,
               'time': 8}
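Note: the `json.loads` here mirrors the `json.dumps` on the API side, so `training_process` receives the same list-of-dicts structure; the lowercased concept types sent by the API line up with the lowercase `tag2id` keys above. A minimal round-trip sketch with made-up values:

    import json

    payload = json.dumps([{"id": "s1", "text": "example", "tags": []}])
    decoded = json.loads(payload)  # what training_process() now receives
    assert decoded[0]["text"] == "example"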
diff --git a/microservices/trainer/trainer_pb2.py b/microservices/trainer/trainer_pb2.py
index 4e4c61cf7a37e411e1e14d66d59f1f0fab5ae210..5a088b5d1f816f47c73b05ebb5590b447c7f818f 100644
--- a/microservices/trainer/trainer_pb2.py
+++ b/microservices/trainer/trainer_pb2.py
@@ -14,13 +14,11 @@ _sym_db = _symbol_database.Default()
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"\x89\x01\n\rTrainingInput\x12$\n\rtraining_data\x18\x01 \x03(\x0b\x32\r.TrainingData\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"<\n\x0cTrainingData\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x12\n\x04tags\x18\x03 \x03(\x0b\x32\x04.Tag\".\n\x03Tag\x12\r\n\x05start\x18\x01 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x02 \x01(\x05\x12\x0b\n\x03tag\x18\x03 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
 
 
 _TRAININGINPUT = DESCRIPTOR.message_types_by_name['TrainingInput']
-_TRAININGDATA = DESCRIPTOR.message_types_by_name['TrainingData']
-_TAG = DESCRIPTOR.message_types_by_name['Tag']
 _TRAININGEVENT = DESCRIPTOR.message_types_by_name['TrainingEvent']
 TrainingInput = _reflection.GeneratedProtocolMessageType('TrainingInput', (_message.Message,), {
   'DESCRIPTOR' : _TRAININGINPUT,
@@ -29,20 +27,6 @@ TrainingInput = _reflection.GeneratedProtocolMessageType('TrainingInput', (_mess
   })
 _sym_db.RegisterMessage(TrainingInput)
 
-TrainingData = _reflection.GeneratedProtocolMessageType('TrainingData', (_message.Message,), {
-  'DESCRIPTOR' : _TRAININGDATA,
-  '__module__' : 'api.protos.trainer.trainer_pb2'
-  # @@protoc_insertion_point(class_scope:TrainingData)
-  })
-_sym_db.RegisterMessage(TrainingData)
-
-Tag = _reflection.GeneratedProtocolMessageType('Tag', (_message.Message,), {
-  'DESCRIPTOR' : _TAG,
-  '__module__' : 'api.protos.trainer.trainer_pb2'
-  # @@protoc_insertion_point(class_scope:Tag)
-  })
-_sym_db.RegisterMessage(Tag)
-
 TrainingEvent = _reflection.GeneratedProtocolMessageType('TrainingEvent', (_message.Message,), {
   'DESCRIPTOR' : _TRAININGEVENT,
   '__module__' : 'api.protos.trainer.trainer_pb2'
@@ -54,14 +38,10 @@ _TRAINER = DESCRIPTOR.services_by_name['Trainer']
 if _descriptor._USE_C_DESCRIPTORS == False:
 
   DESCRIPTOR._options = None
-  _TRAININGINPUT._serialized_start=37
-  _TRAININGINPUT._serialized_end=174
-  _TRAININGDATA._serialized_start=176
-  _TRAININGDATA._serialized_end=236
-  _TAG._serialized_start=238
-  _TAG._serialized_end=284
-  _TRAININGEVENT._serialized_start=286
-  _TRAININGEVENT._serialized_end=317
-  _TRAINER._serialized_start=319
-  _TRAINER._serialized_end=381
+  _TRAININGINPUT._serialized_start=36
+  _TRAININGINPUT._serialized_end=158
+  _TRAININGEVENT._serialized_start=160
+  _TRAININGEVENT._serialized_end=191
+  _TRAINER._serialized_start=193
+  _TRAINER._serialized_end=255
 # @@protoc_insertion_point(module_scope)
diff --git a/microservices/trainer/trainer_pb2_grpc.py b/microservices/trainer/trainer_pb2_grpc.py
index ede71aa15bee98983bbbbfdf21125462b25faef2..ca55b8cc9bbbb93e818c21d3e797f954b3d546ee 100644
--- a/microservices/trainer/trainer_pb2_grpc.py
+++ b/microservices/trainer/trainer_pb2_grpc.py
@@ -1,7 +1,6 @@
 # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
 """Client and server classes corresponding to protobuf-defined services."""
 import grpc
-
 import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2