From 4963a6b33a361caf3dd8c458fead6b8352b034cd Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <caroline.de-pourtales@irit.fr>
Date: Fri, 15 Jul 2022 11:49:55 +0200
Subject: [PATCH] deleting typing

---
 SuperTagger/SuperTagger.py   | 20 ++++++++++----------
 SuperTagger/Utils/helpers.py |  2 +-
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
index e160ecb..99fd786 100644
--- a/SuperTagger/SuperTagger.py
+++ b/SuperTagger/SuperTagger.py
@@ -35,7 +35,7 @@ def output_create_dir():
     return training_dir, writer
 
 
-def categorical_accuracy(preds: list[list[int]], truth: list[list[int]]) -> float:
+def categorical_accuracy(preds, truth):
     """
     Calculates how often predictions match argmax labels.
     @param preds: batch of prediction. (argmax)
@@ -97,7 +97,7 @@ class SuperTagger:
 
     # region Instanciation
 
-    def load_weights(self, model_file: str):
+    def load_weights(self, model_file):
         """
         Loads an SupperTagger saved with SupperTagger.__checkpoint_save() (during a train) from a file.
 
@@ -133,7 +133,7 @@ class SuperTagger:
         self.model_load = True
         self.trainable = True
 
-    def create_new_model(self, num_label: int, bert_name: str, index_to_tags: dict):
+    def create_new_model(self, num_label, bert_name, index_to_tags):
        """
         Instantiation and parameterization of a new bert model
 
@@ -163,7 +163,7 @@ class SuperTagger:
 
     # region Usage
 
-    def predict(self, sentences) -> (list[list[list[float]]], list[list[str]], Tensor):
+    def predict(self, sentences):
         """
         Predict and convert sentences in tags (depends on the dictation given when the model was created)
 
@@ -186,7 +186,7 @@ class SuperTagger:
         return output['logit'], self.tags_tokenizer.convert_ids_to_tags(torch.argmax(output['logit'], dim=2).detach())
 
-    def forward(self, b_sents_tokenized: Tensor, b_sents_mask: Tensor) -> (Tensor, Tensor):
+    def forward(self, b_sents_tokenized, b_sents_mask):
         """
         Function used for the linker (same of predict)
         """
 
@@ -194,7 +194,7 @@ class SuperTagger:
         output = self.model.predict((b_sents_tokenized, b_sents_mask))
         return output
 
-    def train(self, sentences: list[str], tags: list[list[str]], validation_rate=0.1, epochs=20, batch_size=16,
+    def train(self, sentences, tags, validation_rate=0.1, epochs=20, batch_size=16,
               tensorboard=False, checkpoint=False):
         """
 
@@ -261,8 +261,8 @@ class SuperTagger:
 
     # region Private
 
-    def __preprocess_data(self, batch_size: int, sentences: list[str], tags: list[list[str]],
-                          validation_rate: float) -> (DataLoader, DataLoader):
+    def __preprocess_data(self, batch_size, sentences, tags,
+                          validation_rate):
         """
         Create torch dataloader for training
 
@@ -291,7 +291,7 @@ class SuperTagger:
         training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
         return training_dataloader, validation_dataloader
 
-    def __train_epoch(self, training_dataloader: DataLoader) -> (float, float, str):
+    def __train_epoch(self, training_dataloader):
         """
         Train on epoch
 
@@ -335,7 +335,7 @@ class SuperTagger:
         return epoch_acc, epoch_loss, training_time
 
 
-    def __eval_epoch(self, validation_dataloader: DataLoader) -> (float, float, int):
+    def __eval_epoch(self, validation_dataloader):
         """
         Validation on epoch
 
diff --git a/SuperTagger/Utils/helpers.py b/SuperTagger/Utils/helpers.py
index 7c4a2a0..132e55b 100644
--- a/SuperTagger/Utils/helpers.py
+++ b/SuperTagger/Utils/helpers.py
@@ -29,7 +29,7 @@ def load_obj(name):
         return pickle.load(f)
 
 
-def categorical_accuracy_str(preds: list[list[float]], truth: list[list[float]]) -> float:
+def categorical_accuracy_str(preds, truth):
     nb_label = 0
     good_label = 0
     for i in range(len(truth)):
--
GitLab
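
The hints removed in this patch use PEP 585 builtin generics such as list[list[int]], which are only subscriptable on Python 3.9 and later; unless from __future__ import annotations is in effect, a module carrying them fails at import time on older interpreters. The sketch below illustrates that difference on one of the touched signatures. It is a hypothetical example under the assumption that the project needs to run on Python older than 3.9, not code from this repository.

# Minimal illustration (hypothetical, not part of the patch) of the
# compatibility difference behind the removed annotations.

# Before the patch: the builtin generic list[list[int]] is evaluated at
# function definition time, so importing a module containing this def
# raises TypeError on Python 3.8 and earlier.
def categorical_accuracy_typed(preds: list[list[int]], truth: list[list[int]]) -> float:
    ...

# After the patch: no annotations are evaluated, so the definition
# imports on any supported Python version.
def categorical_accuracy_untyped(preds, truth):
    ...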