# inferer.py
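"""gRPC inference server: exposes a StartInference RPC that loads a
fine-tuned RoBERTa span-categorization model and returns the tagged
spans found in the request text as JSON."""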
import json
from concurrent import futures
import grpc
import torch
from transformers import RobertaPreTrainedModel, RobertaModel, AutoTokenizer
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import ROBERTA_INPUTS_DOCSTRING
from transformers.utils import add_start_docstrings_to_model_forward
from torch import nn
from typing import Optional, Union, Tuple
import inferer_pb2_grpc
import inferer_pb2

# Global busy flag used to reject concurrent requests. This is only safe
# because serve() runs the server with a single worker thread.
is_busy = False

class InfererServicer(inferer_pb2_grpc.InfererServicer):
    def StartInference(self, request, context):
        print("event received")
        global is_busy

        if not is_busy:
            is_busy = True
            print(f"incoming request: {request}")
            try:
                result = inference_process(request.inference_data, request.model_id)
                return inferer_pb2.InferenceResult(
                    exit_code=0,
                    status="Inference ended successfully!",
                    inference_result=json.dumps(result)
                )
            except Exception as e:
                print(f"Error: {e}")
            finally:
                # Always release cached GPU memory and the busy flag, even
                # when inference raises; otherwise a single failure would
                # leave the server stuck in the busy state forever.
                torch.cuda.empty_cache()
                is_busy = False
        else:
            print("gRPC server is already busy")

        return inferer_pb2.InferenceResult(
            exit_code=1,
            status="Inference failed!",
            inference_result=""
        )


def serve():
    # A single worker thread, so requests are handled strictly one at a
    # time; this matches the assumption behind the is_busy flag above.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    inferer_pb2_grpc.add_InfererServicer_to_server(InfererServicer(), server)
    # Plaintext gRPC on port 80; no TLS or authentication.
    server.add_insecure_port('[::]:80')
    server.start()
    server.wait_for_termination()
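
# Minimal client sketch. Assumptions not confirmed by this file: the request
# message in inferer.proto is named InferenceRequest, and the generated stub
# is InfererStub (the usual grpc codegen counterpart of InfererServicer).
#
#     channel = grpc.insecure_channel("localhost:80")
#     stub = inferer_pb2_grpc.InfererStub(channel)
#     reply = stub.StartInference(inferer_pb2.InferenceRequest(
#         inference_data="Some text to tag.",
#         model_id="path-or-hub-id-of-the-finetuned-model"))
#     print(reply.exit_code, reply.status, json.loads(reply.inference_result))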

def inference_process(inference_data, model_id):

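    # Token-classification head on top of RoBERTa, fine-tuned for
    # multi-label span categorization: forward() scores every label
    # independently per token, so overlapping tags are possible.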
    class RobertaForSpanCategorization(RobertaPreTrainedModel):
        _keys_to_ignore_on_load_unexpected = [r"pooler"]
        _keys_to_ignore_on_load_missing = [r"position_ids"]

        def __init__(self, config):
            super().__init__(config)
            self.num_labels = config.num_labels
            self.roberta = RobertaModel(config, add_pooling_layer=False)
            classifier_dropout = (
                config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
            )
            self.dropout = nn.Dropout(classifier_dropout)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
            self.post_init()

        @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
        def forward(
                self,
                input_ids: Optional[torch.LongTensor] = None,
                attention_mask: Optional[torch.FloatTensor] = None,
                token_type_ids: Optional[torch.LongTensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                head_mask: Optional[torch.FloatTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
        ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
            outputs = self.roberta(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            sequence_output = outputs[0]
            sequence_output = self.dropout(sequence_output)
            logits = self.classifier(sequence_output)

            loss = None
            if labels is not None:
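                # Multi-label objective: one independent sigmoid per label,
                # so a single token may carry several tags at once.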
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels.float())
            if not return_dict:
                output = (logits,) + outputs[2:]
                return ((loss,) + output) if loss is not None else output
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )


    # Load the fine-tuned checkpoint and its tokenizer. from_pretrained
    # returns the model in eval mode, so dropout is already disabled.
    model = RobertaForSpanCategorization.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    def get_offsets_and_predicted_tags(example: str, model, tokenizer, threshold=0):
        # Encode twice: once keeping character offsets (to map tokens back
        # to the input string) and once as tensors for the model.
        raw_encoded_example = tokenizer(example, return_offsets_mapping=True)
        encoded_example = tokenizer(example, return_tensors="pt")
        # No gradients are needed at inference time.
        with torch.no_grad():
            out = model(**encoded_example)["logits"][0]
        # A logit above 0 is a sigmoid probability above 0.5, so the default
        # threshold keeps every label the model rates more likely than not;
        # several labels may fire on the same token.
        predicted_tags = [[i for i, l in enumerate(logit) if l > threshold] for logit in out]

        return [{"token": token, "tags": tag, "offset": offset} for (token, tag, offset)
                in zip(tokenizer.batch_decode(raw_encoded_example["input_ids"]),
                       predicted_tags,
                       raw_encoded_example["offset_mapping"])]
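
    # Illustrative shape of the records returned above (made-up values, not
    # real model output):
    #   [{"token": "<s>", "tags": [], "offset": (0, 0)},
    #    {"token": "Paris", "tags": [2], "offset": (0, 5)}, ...]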

    def get_tagged_groups(sentence: str):
        offsets_and_tags = get_offsets_and_predicted_tags(sentence, model, tokenizer)
        # Label maps come from the fine-tuned model's config.
        id2label = model.config.id2label
        predicted_offsets = {label: [] for label in model.config.label2id}
        last_token_tags = []
        for item in offsets_and_tags:
            (start, end), tags = item["offset"], item["tags"]

            for label_id in tags:
                tag = id2label[label_id]
                if label_id not in last_token_tags:
                    # Tag absent on the previous token: open a new span.
                    predicted_offsets[tag].append({"start": start, "end": end})
                else:
                    # Tag continues from the previous token: extend the span.
                    predicted_offsets[tag][-1]["end"] = end

            last_token_tags = tags

        # Flatten to a list of spans, drop spans shorter than 3 characters,
        # and sort by position in the sentence.
        flatten_predicted_offsets = [{**v, "tag": k, "text": sentence[v["start"]:v["end"]]}
                                     for k, v_list in predicted_offsets.items()
                                     for v in v_list if v["end"] - v["start"] >= 3]
        flatten_predicted_offsets = sorted(flatten_predicted_offsets,
                                           key=lambda row: (row["start"], row["end"], row["tag"]))
        return flatten_predicted_offsets
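
    # Illustrative result for input like "Paris is lovely", assuming the
    # checkpoint was trained with a LOC label (label names depend entirely
    # on the fine-tuned model's config):
    #   [{"start": 0, "end": 5, "tag": "LOC", "text": "Paris"}]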

    return get_tagged_groups(inference_data)


if __name__ == '__main__':
    serve()