diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py index e160ecb274bd0103189dcf97903ebc40d8bdd6ce..99fd78694a2365f650804ba0b05c54772531fecc 100644 --- a/SuperTagger/SuperTagger.py +++ b/SuperTagger/SuperTagger.py @@ -35,7 +35,7 @@ def output_create_dir(): return training_dir, writer -def categorical_accuracy(preds: list[list[int]], truth: list[list[int]]) -> float: +def categorical_accuracy(preds, truth): """ Calculates how often predictions match argmax labels. @param preds: batch of prediction. (argmax) @@ -97,7 +97,7 @@ class SuperTagger: # region Instanciation - def load_weights(self, model_file: str): + def load_weights(self, model_file): """ Loads an SupperTagger saved with SupperTagger.__checkpoint_save() (during a train) from a file. @@ -133,7 +133,7 @@ class SuperTagger: self.model_load = True self.trainable = True - def create_new_model(self, num_label: int, bert_name: str, index_to_tags: dict): + def create_new_model(self, num_label, bert_name, index_to_tags): """ Instantiation and parameterization of a new bert model @@ -163,7 +163,7 @@ class SuperTagger: # region Usage - def predict(self, sentences) -> (list[list[list[float]]], list[list[str]], Tensor): + def predict(self, sentences): """ Predict and convert sentences in tags (depends on the dictation given when the model was created) @@ -186,7 +186,7 @@ class SuperTagger: return output['logit'], self.tags_tokenizer.convert_ids_to_tags(torch.argmax(output['logit'], dim=2).detach()) - def forward(self, b_sents_tokenized: Tensor, b_sents_mask: Tensor) -> (Tensor, Tensor): + def forward(self, b_sents_tokenized, b_sents_mask): """ Function used for the linker (same of predict) """ @@ -194,7 +194,7 @@ class SuperTagger: output = self.model.predict((b_sents_tokenized, b_sents_mask)) return output - def train(self, sentences: list[str], tags: list[list[str]], validation_rate=0.1, epochs=20, batch_size=16, + def train(self, sentences, tags, validation_rate=0.1, epochs=20, batch_size=16, tensorboard=False, checkpoint=False): """ @@ -261,8 +261,8 @@ class SuperTagger: # region Private - def __preprocess_data(self, batch_size: int, sentences: list[str], tags: list[list[str]], - validation_rate: float) -> (DataLoader, DataLoader): + def __preprocess_data(self, batch_size, sentences, tags, + validation_rate): """ Create torch dataloader for training @@ -291,7 +291,7 @@ class SuperTagger: training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) return training_dataloader, validation_dataloader - def __train_epoch(self, training_dataloader: DataLoader) -> (float, float, str): + def __train_epoch(self, training_dataloader): """ Train on epoch @@ -335,7 +335,7 @@ class SuperTagger: return epoch_acc, epoch_loss, training_time - def __eval_epoch(self, validation_dataloader: DataLoader) -> (float, float, int): + def __eval_epoch(self, validation_dataloader): """ Validation on epoch diff --git a/SuperTagger/Utils/helpers.py b/SuperTagger/Utils/helpers.py index 7c4a2a010627656972c531f8b57118693ac85150..132e55b8882dab5f388afbd4959c89e848852ef5 100644 --- a/SuperTagger/Utils/helpers.py +++ b/SuperTagger/Utils/helpers.py @@ -29,7 +29,7 @@ def load_obj(name): return pickle.load(f) -def categorical_accuracy_str(preds: list[list[float]], truth: list[list[float]]) -> float: +def categorical_accuracy_str(preds, truth): nb_label = 0 good_label = 0 for i in range(len(truth)):