Skip to content
Snippets Groups Projects
Commit da9db883 authored by Julien B.'s avatar Julien B.
Browse files

fix(inferer): fix missing var

parent b6ffc7b9
No related branches found
No related tags found
No related merge requests found
...@@ -125,9 +125,9 @@ def inference_process(inference_data, model_id): ...@@ -125,9 +125,9 @@ def inference_process(inference_data, model_id):
model = RobertaForSpanCategorization.from_pretrained(model_id) model = RobertaForSpanCategorization.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id)
def get_offsets_and_predicted_tags(example: str, model, tokenizer, threshold=0): def get_offsets_and_predicted_tags(sentence: str, model, tokenizer, threshold=0):
raw_encoded_example = tokenizer(example, return_offsets_mapping=True) raw_encoded_example = tokenizer(sentence, return_offsets_mapping=True)
encoded_example = tokenizer(example, return_tensors="pt") encoded_example = tokenizer(sentence, return_tensors="pt")
out = model(**encoded_example)["logits"][0] out = model(**encoded_example)["logits"][0]
predicted_tags = [[i for i, l in enumerate(logit) if l > threshold] for logit in out] predicted_tags = [[i for i, l in enumerate(logit) if l > threshold] for logit in out]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment