Skip to content
Snippets Groups Projects
annotator.py 4.16 KiB
import json
import grpc
from api.internal_services import neo4j, database
from api.internal_services.database import update_last_concept_index, get_last_concept_index
from api.internal_services.gpt import gpt_process
from api.internal_services.logger import logger
from api.protos.inferer import inferer_pb2_grpc, inferer_pb2

# Module-level flag. NOTE(review): nothing in this section reads it —
# provider selection in annotation_process comes from
# database.get_annotator_config()['provider'] instead. It may be legacy or
# read by other modules; confirm before relying on it or removing it.
is_using_GPT: bool = True


def annotation_process(job):
    annotator_config = database.get_annotator_config()
    match annotator_config['provider']:
        case 'GPT':
            from api.internal_services.neo4j import driver, create_concept_node, create_concept_relation
            with (driver.session() as session):
                gpt_annotation = gpt_process(job.job_data['sentence'].text)
                for concept, annotations in gpt_annotation.items():
                    filtered_annotation = set()
                    for annotation in annotations:
                        filtered_annotation = filtered_annotation | neo4j.get_filtered_annotation(
                            job.job_data['sentence_id'], concept, annotation)
                    if filtered_annotation:
                        for interval in separate_intervals(filtered_annotation):
                            session.execute_write(
                                create_concept_node,
                                concept,
                                update_last_concept_index(get_last_concept_index() + 1),
                                False
                            )

                            for word_id in interval:
                                session.execute_write(
                                    create_concept_relation,
                                    get_last_concept_index(),
                                    word_id
                                )

        case 'BERT':
            response: inferer_pb2.InferenceResult = None
            with grpc.insecure_channel(annotator_config['server_url']) as channel:
                stub = inferer_pb2_grpc.InfererStub(channel)
                request = inferer_pb2.InferenceInput(
                    inference_data=job.job_data['sentence'].text,
                    model_id=annotator_config['model_id'],
                )
                response = stub.StartInference(request)
                logger.debug(f"Incoming gRPC message : {response.status}")
            logger.debug(f"fin de la connexion gRPC")

            response = json.loads(response.inference_result)
            from api.internal_services.neo4j import driver, get_id_multi_tokens, get_id_single_tokens, create_concept_node, create_concept_relation
            with (driver.session() as session):
                for span in response:
                    if " " in span:
                        word_ids = get_id_multi_tokens(span['text'], job.job_data['sentence_id'])
                    else:
                        word_ids = get_id_single_tokens(span['text'], job.job_data['sentence_id'])

                    session.execute_write(
                        create_concept_node,
                        span['tag'],
                        update_last_concept_index(get_last_concept_index() + 1),
                        False
                    )

                    for word_id in word_ids:
                        session.execute_write(
                            create_concept_relation,
                            get_last_concept_index(),
                            word_id
                        )
        case 'NONE':
            logger.error("No annotator configured, please provide an annotator before trying to add sentences.")


def separate_intervals(data):
    """Group ids into runs of consecutive numeric suffixes.

    Each element is expected to be a string whose last ``.``-separated
    component is an integer (e.g. ``"<prefix>.12"``). Elements are ordered
    numerically by that suffix. They are then split wherever two neighbours
    are not exactly one apart.

    Args:
        data: Iterable of id strings; may be empty.

    Returns:
        A list of lists of the original strings, ordered by numeric suffix.
        Returns ``[]`` for empty input (previously this raised ``IndexError``).

    Raises:
        ValueError: if a suffix is not an integer.
    """
    def _position(item):
        return int(item.rsplit('.', 1)[-1])

    ordered = sorted(data, key=_position)
    if not ordered:
        return []

    intervals = [[ordered[0]]]
    for previous, current in zip(ordered, ordered[1:]):
        if _position(current) == _position(previous) + 1:
            intervals[-1].append(current)
        else:
            intervals.append([current])
    return intervals