Commit 24bb7315 authored by Julien B.

fix(trainer): change trainer function

parent cbc4c653
@@ -3,8 +3,7 @@ import os
 from neo4j import GraphDatabase
 from api.internal_services.logger import logger
 
-uri = os.getenv('NEO4J_URL', "bolt://ala-neo4j:7687")
+uri = os.getenv('NEO4J_URL', "bolt://localhost:7687")
 username = os.getenv('NEO4J_USERNAME', "neo4j")
 password = os.getenv('NEO4J_PASSWORD', "password")
@@ -49,6 +48,16 @@ def create_concept_node(tx, concept, id, origin):
         concept=concept, id=id, origin=origin)
 
+def create_sentence_node(tx, id):
+    tx.run('''
+        CREATE (
+            n:Sentence {
+                id: $id
+            }
+        )''',
+        id=id)
+
 def create_next_word_relation(tx, idFrom, idTo):
     tx.run('''
         MATCH
@@ -96,16 +105,15 @@ def create_constituent_relation(tx, idFrom, idTo):
         idFrom=idFrom, idTo=idTo
     )
 
-def create_relation(tx, idFrom, idTo, id, type):
+def create_sub_relation(tx, idFrom, idTo):
     tx.run('''
         MATCH
-            (a:Concept),
-            (b:Concept)
-        WHERE a.id = $idFrom AND b.id = $idTo
-        CREATE (a)-[r:RELATION {id: $id, type: $type}]->(b)
+            (w:Word),
+            (s:Sentence)
+        WHERE w.id = $idFrom AND s.id = $idTo
+        CREATE (w)-[r:SUB]->(s)
         ''',
-        idFrom=idFrom, idTo=idTo, id=id, type=type
+        idFrom=idFrom, idTo=idTo
    )
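For orientation, these tx-based helpers are meant to be passed to session.execute_write, as the parsing service later in this commit does. A minimal usage sketch, assuming a Word node whose id follows the "<sentence_index>.<sent_i>.<token_i>" scheme used elsewhere in this commit (both ids below are hypothetical):

    from api.internal_services.neo4j import driver, create_sentence_node, create_sub_relation

    with driver.session() as session:
        # Create the Sentence node first, then attach an existing Word to it via [:SUB].
        session.execute_write(create_sentence_node, "6f1c0e2a-hypothetical-uuid")
        session.execute_write(create_sub_relation, "3.0.1", "6f1c0e2a-hypothetical-uuid")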
@@ -120,7 +128,8 @@ def get_filtered_annotation(sentence_id, concept, annotation):
     words_ids = get_id_single_tokens(annotation, sentence_id)
     if len(words_ids) == 0:
-        logger.warn(f"Cannot find the following annotation '{annotation}' in the sentence id {sentence_id}. This error is a hallucination of large language models.")
+        logger.warn(
+            f"Cannot find the following annotation '{annotation}' in the sentence id {sentence_id}. This error is a hallucination of large language models.")
         return set()
 
     filtered_annotation = set()
@@ -150,8 +159,9 @@ def get_id_multi_tokens(annotation, sentence_id):
         '''
         WITH $array AS words
         MATCH path = (start:Word)-[:NEXT*]->(end:Word)
+        MATCH (start)--(s:Sentence)
         where size(words) - 1 = size(relationships(path))
-        and start.id starts with $sentence_id
+        and s.id = $sentence_id
         and all(
             idx IN range(0, size(words)-2)
             WHERE (toLower(words[idx]) = toLower((nodes(path)[idx]).text)
@@ -164,7 +174,7 @@ def get_id_multi_tokens(annotation, sentence_id):
         return collect(results.id) as liste
         ''',
         array=annotation,
-        sentence_id=f"{sentence_id}."
+        sentence_id=sentence_id
     )
     return list_multi_token.records[0][0]
@@ -172,14 +182,14 @@ def get_id_multi_tokens(annotation, sentence_id):
 def get_id_single_tokens(annotation, sentence_id):
     list_single_token = driver.execute_query(
         '''
-        match (w:Word)
+        match (s:Sentence)--(w:Word)
         where toLower(w.text) = $annotation
-        and w.id starts with $sentence_id
+        and s.id = $sentence_id
         with distinct w as results
         return collect(results.id) as liste
         ''',
         annotation=annotation,
-        sentence_id=f"{sentence_id}."
+        sentence_id=sentence_id
     )
     return list_single_token.records[0][0]
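The net effect of these two hunks: single- and multi-token lookups are now scoped by the Sentence node's uuid instead of a "starts with" prefix on Word.id. A hedged sketch of the call shape (the sentence uuid is hypothetical):

    # Single token: case-insensitive match on Word.text within one sentence.
    ids = get_id_single_tokens("chat", "6f1c0e2a-hypothetical-uuid")
    # -> e.g. ["3.0.1"], still in the "<sentence_index>.<sent_i>.<token_i>" format

    # Multi token: consecutive words joined by [:NEXT] relationships.
    ids = get_id_multi_tokens(["le", "chat"], "6f1c0e2a-hypothetical-uuid")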
@@ -329,3 +339,73 @@ def reference(words_ids):
         array=list(words_ids)
     )
     return set([record[0] for record in nodes.records])
+
+def get_sentences_for_train():
+    nodes = driver.execute_query(
+        '''
+        match (s:Sentence)--(:Word)--(c:Concept)
+        where c.origin <> "BERT"
+        return s.id as sentence, collect( distinct c.id) as concepts
+        '''
+    )
+    return [(record[0], record[1]) for record in nodes.records]
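A sketch of the shape this returns: one tuple per sentence that carries at least one concept whose origin is not "BERT" (all ids below are illustrative):

    training_sentences = get_sentences_for_train()
    # [("6f1c0e2a-hypothetical-uuid", ["concept-1", "concept-7"]),
    #  ("9b0d4c11-hypothetical-uuid", ["concept-2"])]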
+
+def get_full_sentence_by_id(sentence_id):
+    nodes = driver.execute_query(
+        '''
+        MATCH (s:Sentence{id:$sentence_id})--(w1:Word)
+        MATCH (s:Sentence{id:$sentence_id})--(w2:Word)
+        WHERE NOT EXISTS {
+            MATCH (:Word)-[:NEXT]->(w1)
+        }
+        and not exists {
+            match (w2)-[:NEXT]->(:Word)
+        }
+        WITH w1 AS startNode, w2 as endNode
+        MATCH path = (startNode)-[:NEXT*]->(endNode)
+        WITH nodes(path) AS nodelist
+        WITH [node IN nodelist | node.text] AS texts
+        WITH reduce(acc = "", t IN texts | acc + " " + t) AS concatenated_text
+        RETURN concatenated_text
+        LIMIT 1
+        ''',
+        sentence_id=sentence_id
+    )
+    return [record[0] for record in nodes.records]
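Note that the Cypher reduce seeds the accumulator with "" and prepends " " before every token, so the returned string carries a leading space; start_training below compensates with .strip(). The equivalent Python, as a worked example:

    texts = ["Le", "chat", "dort"]
    concatenated_text = ""
    for t in texts:
        concatenated_text = concatenated_text + " " + t
    # concatenated_text == " Le chat dort"  -> hence .strip() at the call site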
+
+def get_concept_text_by_id(concept_id):
+    nodes = driver.execute_query('''
+        match (c:Concept{id:$concept_id})--(w:Word)
+        with collect(w) as word_list
+        match (startNode:Word)
+        match (endNode:Word)
+        where startNode in word_list
+        and endNode in word_list
+        and not exists {
+            match (w1:Word)-[:NEXT]->(startNode)
+            where w1 in word_list
+        }
+        and not exists {
+            match (endNode)-[:NEXT]->(w2:Word)
+            where w2 in word_list
+        }
+        with startNode, endNode
+        MATCH path = (startNode)-[:NEXT*]->(endNode)
+        WITH nodes(path) AS nodelist
+        WITH [node IN nodelist | node.text] AS texts
+        WITH reduce(acc = "", t IN texts | acc + " " + t) AS concatenated_text
+        RETURN concatenated_text
+        LIMIT 1
+        ''',
+        concept_id=concept_id
+    )
+    return [record[0] for record in nodes.records]
+
+def get_concept_type_by_id(concept_id):
+    nodes = driver.execute_query('''
+        match (c:Concept{id:$concept_id})
+        return toLower(c.type)
+        ''',
+        concept_id=concept_id
+    )
+    return [record[0] for record in nodes.records]
\ No newline at end of file
+import uuid
 import benepar, spacy
 import warnings
@@ -8,7 +10,7 @@ warnings.filterwarnings("ignore")
 from api.internal_services.database import get_last_sentence_index, update_last_sentence_index
 from api.internal_services.neo4j import create_constituent_node, driver, create_constituent_relation, create_word_node, \
-    create_next_word_relation, create_deprel_relation
+    create_next_word_relation, create_deprel_relation, create_sentence_node, create_sub_relation
 
 benepar.download('benepar_fr2')
 nlp = spacy.load('fr_dep_news_trf')
@@ -17,14 +19,21 @@ nlp.add_pipe('benepar', config={'model': 'benepar_fr2'})
 def parsing_and_load_in_neo4j(job):
     sentence = job.job_data['sentence']
-    last_index = get_last_sentence_index()
-    last_index = update_last_sentence_index(last_index + 1)
+    last_index = update_last_sentence_index(get_last_sentence_index() + 1)
 
     with (driver.session() as session):
         doc = nlp(sentence)
         for i, sentence in enumerate(doc.sents):
+            sentence_uuid = str(uuid.uuid4())
+            # Create the sentence node
+            session.execute_write(
+                create_sentence_node,
+                sentence_uuid
+            )
             for constituent in sentence._.constituents:
                 constituent_id = f"{last_index}.{i}.{constituent.start}-{constituent.end}"
                 if constituent._.labels and constituent.root.text != constituent.text:
@@ -44,6 +53,16 @@ def parsing_and_load_in_neo4j(job):
                     )
                 else:
+                    # Create the word as a neo4j node
+                    session.execute_write(
+                        create_word_node,
+                        '.'.join(map(str, [last_index, i, constituent.root.i])),
+                        constituent.text,
+                        None if not hasattr(constituent, 'lemma_') else constituent.lemma_,
+                        constituent.root.pos_,
+                        True if constituent.root.dep_ == "root" else False
+                    )
                     # Create the word and the solitary constituent if needed
                     if constituent._.labels:
                         # Create the constituent
@@ -53,16 +72,6 @@ def parsing_and_load_in_neo4j(job):
                             constituent._.labels[0]
                         )
 
-                        # Create the word as a neo4j node
-                        session.execute_write(
-                            create_word_node,
-                            '.'.join(map(str, [last_index, i, constituent.root.i])),
-                            constituent.text,
-                            None if not hasattr(constituent, 'lemma_') else constituent.lemma_,
-                            constituent.root.pos_,
-                            True if constituent.root.dep_ == "root" else False
-                        )
-
                         session.execute_write(
                             create_constituent_relation,
                             f"{last_index}.{i}.{constituent.start}-{constituent.end}",
@@ -76,16 +85,6 @@ def parsing_and_load_in_neo4j(job):
                         )
                     else:
-                        # Create the word as a neo4j node
-                        session.execute_write(
-                            create_word_node,
-                            '.'.join(map(str, [last_index, i, constituent.root.i])),
-                            constituent.text,
-                            None if not hasattr(constituent, 'lemma_') else constituent.lemma_,
-                            constituent.root.pos_,
-                            True if constituent.root.dep_ == "root" else False
-                        )
-
                         # The parent exists, so create the link
                         session.execute_write(
                             create_constituent_relation,
@@ -93,6 +92,12 @@ def parsing_and_load_in_neo4j(job):
                             '.'.join(map(str, [last_index, i, constituent.root.i])),
                         )
 
+                    session.execute_write(
+                        create_sub_relation,
+                        '.'.join(map(str, [last_index, i, constituent.root.i])),
+                        sentence_uuid,
+                    )
+
             for token in sentence:
                 # Create a succession link
                 if token.i != 0:
@@ -108,7 +113,7 @@ def parsing_and_load_in_neo4j(job):
             new_job = Job()
             new_job.job_id = job.job_id
             new_job.job_type = JobType.ANNOTATION
-            new_job.job_data = {'sentence': sentence, 'sentence_id': last_index}
+            new_job.job_data = {'sentence': sentence, 'sentence_id': sentence_uuid}
             add_job_to_queue(new_job)
......
+import json
 import grpc
+
+from api.internal_services import neo4j
 from api.internal_services.logger import logger
 from api.protos.trainer import trainer_pb2_grpc, trainer_pb2
+
+def find_all_occurrences(text, phrase):
+    start = 0
+    while True:
+        start = text.find(phrase, start)
+        if start == -1: return
+        yield start
+        start += len(phrase)
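find_all_occurrences yields the start offset of every non-overlapping occurrence of phrase in text, advancing past each match; for example:

    list(find_all_occurrences("the cat and the dog", "the"))  # -> [0, 12]
    list(find_all_occurrences("aaaa", "aa"))                  # -> [0, 2] (non-overlapping)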
+
 def start_training(job):
+    training_data = []
+    training_sentences = neo4j.get_sentences_for_train()
+    for tuple in training_sentences:
+        sentence_id = tuple[0]
+        concepts_id = tuple[1]
+        sentence = neo4j.get_full_sentence_by_id(sentence_id)[0].strip()
+        tags = []
+        for concept_id in concepts_id:
+            span = neo4j.get_concept_text_by_id(concept_id)[0].strip()
+            type = neo4j.get_concept_type_by_id(concept_id)[0]
+            all_occurrences = list(find_all_occurrences(sentence, span))
+            for start_index in all_occurrences:
+                end_index = start_index + len(span)
+                tags.append({
+                    "start": start_index if start_index == 0 else start_index-1,
+                    "end": end_index,
+                    "tag": type,
+                })
+        training_data.append({
+            'id': sentence_id,
+            'text': sentence,
+            'tags': tags
+        })
+    logger.debug(training_data)
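For reference, one element of training_data as built above (values illustrative). Note the quirk in the "start" offset: every occurrence except one at position 0 is stored shifted back by one character:

    # For sentence "Le chat dort" with concept span "chat" of type "actor":
    # find() returns 3, so "start" is stored as 3 - 1 = 2.
    {
        'id': '6f1c0e2a-hypothetical-uuid',
        'text': 'Le chat dort',
        'tags': [{'start': 2, 'end': 7, 'tag': 'actor'}]
    }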
+
     with grpc.insecure_channel(job.job_data['server_url']) as channel:
         stub = trainer_pb2_grpc.TrainerStub(channel)
         request = trainer_pb2.TrainingInput(
-            training_data=[],
+            training_data=json.dumps(training_data),
             fondation_model_id=job.job_data['fondation_model_id'],
             finetuned_repo_name=job.job_data['finetuned_repo_name'],
             huggingface_token=job.job_data['huggingface_token'],
@@ -16,4 +56,6 @@ def start_training(job):
         responses = stub.StartTraining(request)
         for response in responses:
             logger.debug(f"gRPC message : {response.status}")
-        logger.debug(f"end of gRPC connection")
\ No newline at end of file
logger.debug(f"fin de la connexion gRPC")
# passer toutes les origines des concepts en BERT
@@ -5,24 +5,12 @@ service Trainer {
 }
 
 message TrainingInput {
-  repeated TrainingData training_data = 1;
+  string training_data = 1;
   string fondation_model_id = 2;
   string finetuned_repo_name = 3;
   string huggingface_token = 4;
 }
 
-message TrainingData {
-  string id = 1;
-  string text = 2;
-  repeated Tag tags = 3;
-}
-
-message Tag {
-  int32 start = 1;
-  int32 end = 2;
-  string tag = 3;
-}
-
 message TrainingEvent {
   string status = 1;
 }
\ No newline at end of file
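The proto change above is the heart of this commit: the structured repeated TrainingData field (with its nested Tag message) is replaced by a single JSON string, encoded with json.dumps on the API side and decoded with json.loads in training_process on the trainer side. A minimal round-trip sketch, assuming regenerated stubs (the model id and repo name are hypothetical):

    import json
    from api.protos.trainer import trainer_pb2

    payload = [{'id': 's1', 'text': 'Le chat dort',
                'tags': [{'start': 2, 'end': 7, 'tag': 'actor'}]}]
    request = trainer_pb2.TrainingInput(
        training_data=json.dumps(payload),       # now a plain string field
        fondation_model_id='camembert-base',     # hypothetical foundation model
        finetuned_repo_name='my-org/my-model',   # hypothetical HF repo
        huggingface_token='hf_...',
    )
    assert json.loads(request.training_data) == payload  # server-side decode recovers it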
@@ -14,13 +14,11 @@ _sym_db = _symbol_database.Default()
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"\x89\x01\n\rTrainingInput\x12$\n\rtraining_data\x18\x01 \x03(\x0b\x32\r.TrainingData\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"<\n\x0cTrainingData\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x12\n\x04tags\x18\x03 \x03(\x0b\x32\x04.Tag\".\n\x03Tag\x12\r\n\x05start\x18\x01 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x02 \x01(\x05\x12\x0b\n\x03tag\x18\x03 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
 
 _TRAININGINPUT = DESCRIPTOR.message_types_by_name['TrainingInput']
-_TRAININGDATA = DESCRIPTOR.message_types_by_name['TrainingData']
-_TAG = DESCRIPTOR.message_types_by_name['Tag']
 _TRAININGEVENT = DESCRIPTOR.message_types_by_name['TrainingEvent']
 TrainingInput = _reflection.GeneratedProtocolMessageType('TrainingInput', (_message.Message,), {
     'DESCRIPTOR' : _TRAININGINPUT,
@@ -29,20 +27,6 @@ TrainingInput = _reflection.GeneratedProtocolMessageType('TrainingInput', (_mess
 })
 _sym_db.RegisterMessage(TrainingInput)
 
-TrainingData = _reflection.GeneratedProtocolMessageType('TrainingData', (_message.Message,), {
-    'DESCRIPTOR' : _TRAININGDATA,
-    '__module__' : 'api.protos.trainer.trainer_pb2'
-    # @@protoc_insertion_point(class_scope:TrainingData)
-})
-_sym_db.RegisterMessage(TrainingData)
-
-Tag = _reflection.GeneratedProtocolMessageType('Tag', (_message.Message,), {
-    'DESCRIPTOR' : _TAG,
-    '__module__' : 'api.protos.trainer.trainer_pb2'
-    # @@protoc_insertion_point(class_scope:Tag)
-})
-_sym_db.RegisterMessage(Tag)
-
 TrainingEvent = _reflection.GeneratedProtocolMessageType('TrainingEvent', (_message.Message,), {
     'DESCRIPTOR' : _TRAININGEVENT,
     '__module__' : 'api.protos.trainer.trainer_pb2'
@@ -54,14 +38,10 @@ _TRAINER = DESCRIPTOR.services_by_name['Trainer']
 if _descriptor._USE_C_DESCRIPTORS == False:
 
   DESCRIPTOR._options = None
-  _TRAININGINPUT._serialized_start=37
-  _TRAININGINPUT._serialized_end=174
-  _TRAININGDATA._serialized_start=176
-  _TRAININGDATA._serialized_end=236
-  _TAG._serialized_start=238
-  _TAG._serialized_end=284
-  _TRAININGEVENT._serialized_start=286
-  _TRAININGEVENT._serialized_end=317
-  _TRAINER._serialized_start=319
-  _TRAINER._serialized_end=381
+  _TRAININGINPUT._serialized_start=36
+  _TRAININGINPUT._serialized_end=158
+  _TRAININGEVENT._serialized_start=160
+  _TRAININGEVENT._serialized_end=191
+  _TRAINER._serialized_start=193
+  _TRAINER._serialized_end=255
 # @@protoc_insertion_point(module_scope)
@@ -5,24 +5,12 @@ service Trainer {
 }
 
 message TrainingInput {
-  repeated TrainingData training_data = 1;
+  string training_data = 1;
  string fondation_model_id = 2;
   string finetuned_repo_name = 3;
   string huggingface_token = 4;
 }
 
-message TrainingData {
-  string id = 1;
-  string text = 2;
-  repeated Tag tags = 3;
-}
-
-message Tag {
-  int32 start = 1;
-  int32 end = 2;
-  string tag = 3;
-}
-
 message TrainingEvent {
   string status = 1;
 }
\ No newline at end of file
+import json
 from concurrent import futures
 from typing import Optional, Union, Tuple
 import grpc
@@ -55,6 +56,8 @@ def serve():
 def training_process(training_data, fondation_model_id, finetuned_repo_name, huggingface_token):
+    training_data = json.loads(training_data)
+    print(training_data)
     MAX_LENGTH = 256
     tag2id = {'action': 1, 'actor': 2, 'artifact': 3, 'condition': 4, 'location': 5, 'modality': 6, 'reference': 7,
               'time': 8}
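The remainder of training_process is truncated here; as a hedged sketch only, a decoded record could line up with tag2id like this (record values illustrative):

    record = {'text': 'Le chat dort', 'tags': [{'start': 2, 'end': 7, 'tag': 'actor'}]}
    label_ids = [tag2id[t['tag']] for t in record['tags']]  # -> [2]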
......
@@ -14,13 +14,11 @@ _sym_db = _symbol_database.Default()
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"\x89\x01\n\rTrainingInput\x12$\n\rtraining_data\x18\x01 \x03(\x0b\x32\r.TrainingData\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"<\n\x0cTrainingData\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x12\n\x04tags\x18\x03 \x03(\x0b\x32\x04.Tag\".\n\x03Tag\x12\r\n\x05start\x18\x01 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x02 \x01(\x05\x12\x0b\n\x03tag\x18\x03 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n api/protos/trainer/trainer.proto\"z\n\rTrainingInput\x12\x15\n\rtraining_data\x18\x01 \x01(\t\x12\x1a\n\x12\x66ondation_model_id\x18\x02 \x01(\t\x12\x1b\n\x13\x66inetuned_repo_name\x18\x03 \x01(\t\x12\x19\n\x11huggingface_token\x18\x04 \x01(\t\"\x1f\n\rTrainingEvent\x12\x0e\n\x06status\x18\x01 \x01(\t2>\n\x07Trainer\x12\x33\n\rStartTraining\x12\x0e.TrainingInput\x1a\x0e.TrainingEvent\"\x00\x30\x01\x62\x06proto3')
 
 _TRAININGINPUT = DESCRIPTOR.message_types_by_name['TrainingInput']
-_TRAININGDATA = DESCRIPTOR.message_types_by_name['TrainingData']
-_TAG = DESCRIPTOR.message_types_by_name['Tag']
 _TRAININGEVENT = DESCRIPTOR.message_types_by_name['TrainingEvent']
 TrainingInput = _reflection.GeneratedProtocolMessageType('TrainingInput', (_message.Message,), {
     'DESCRIPTOR' : _TRAININGINPUT,
@@ -29,20 +27,6 @@ TrainingInput = _reflection.GeneratedProtocolMessageType('TrainingInput', (_mess
 })
 _sym_db.RegisterMessage(TrainingInput)
 
-TrainingData = _reflection.GeneratedProtocolMessageType('TrainingData', (_message.Message,), {
-    'DESCRIPTOR' : _TRAININGDATA,
-    '__module__' : 'api.protos.trainer.trainer_pb2'
-    # @@protoc_insertion_point(class_scope:TrainingData)
-})
-_sym_db.RegisterMessage(TrainingData)
-
-Tag = _reflection.GeneratedProtocolMessageType('Tag', (_message.Message,), {
-    'DESCRIPTOR' : _TAG,
-    '__module__' : 'api.protos.trainer.trainer_pb2'
-    # @@protoc_insertion_point(class_scope:Tag)
-})
-_sym_db.RegisterMessage(Tag)
-
 TrainingEvent = _reflection.GeneratedProtocolMessageType('TrainingEvent', (_message.Message,), {
     'DESCRIPTOR' : _TRAININGEVENT,
     '__module__' : 'api.protos.trainer.trainer_pb2'
@@ -54,14 +38,10 @@ _TRAINER = DESCRIPTOR.services_by_name['Trainer']
 if _descriptor._USE_C_DESCRIPTORS == False:
 
   DESCRIPTOR._options = None
-  _TRAININGINPUT._serialized_start=37
-  _TRAININGINPUT._serialized_end=174
-  _TRAININGDATA._serialized_start=176
-  _TRAININGDATA._serialized_end=236
-  _TAG._serialized_start=238
-  _TAG._serialized_end=284
-  _TRAININGEVENT._serialized_start=286
-  _TRAININGEVENT._serialized_end=317
-  _TRAINER._serialized_start=319
-  _TRAINER._serialized_end=381
+  _TRAININGINPUT._serialized_start=36
+  _TRAININGINPUT._serialized_end=158
+  _TRAININGEVENT._serialized_start=160
+  _TRAININGEVENT._serialized_end=191
+  _TRAINER._serialized_start=193
+  _TRAINER._serialized_end=255
 # @@protoc_insertion_point(module_scope)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import trainer_pb2 as api_dot_protos_dot_trainer_dot_trainer__pb2
......