Skip to content
Snippets Groups Projects
Commit 23e27a4c authored by Julien Rabault's avatar Julien Rabault
Browse files

remove argmax output

parent 36720c72
No related branches found
No related tags found
1 merge request!1Draft: Master
...@@ -161,7 +161,7 @@ class SuperTagger: ...@@ -161,7 +161,7 @@ class SuperTagger:
preds, hidden = self.model.predict((sents_tokenized_t, sents_mask_t)) preds, hidden = self.model.predict((sents_tokenized_t, sents_mask_t))
return preds, self.tags_tokenizer.convert_ids_to_tags(preds.detach()), hidden return preds, self.tags_tokenizer.convert_ids_to_tags(torch.argmax(preds, dim=2).detach()), hidden
def train(self, sentences: list[str], tags: list[list[str]], validation_rate=0.1, epochs=20, batch_size=32, def train(self, sentences: list[str], tags: list[list[str]], validation_rate=0.1, epochs=20, batch_size=32,
tensorboard=False, tensorboard=False,
......
...@@ -38,4 +38,4 @@ class Tagging_bert_model(Module): ...@@ -38,4 +38,4 @@ class Tagging_bert_model(Module):
output = self.bert( output = self.bert(
input_ids=b_input_ids, attention_mask=b_input_mask) input_ids=b_input_ids, attention_mask=b_input_mask)
return torch.argmax(output[0], dim=2), output[1][0] return output[0], output[1][0]
...@@ -27,5 +27,18 @@ def load_obj(name): ...@@ -27,5 +27,18 @@ def load_obj(name):
import pickle import pickle
return pickle.load(f) return pickle.load(f)
def categorical_accuracy_str(preds : list[list[float]], truth: list[list[float]]) -> float:
nb_label = 0
good_label = 0
for i in range(len(truth)):
sublist_truth = truth[i]
sublist_preds = preds[i]
for j in range(min(len(sublist_truth), len(sublist_preds))):
if str(sublist_truth[j]) == str(sublist_preds[j]):
good_label += 1
nb_label += 1
return good_label / nb_label
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment