From 4963a6b33a361caf3dd8c458fead6b8352b034cd Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <caroline.de-pourtales@irit.fr>
Date: Fri, 15 Jul 2022 11:49:55 +0200
Subject: [PATCH] deleting typing

---
 SuperTagger/SuperTagger.py   | 20 ++++++++++----------
 SuperTagger/Utils/helpers.py |  2 +-
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/SuperTagger/SuperTagger.py b/SuperTagger/SuperTagger.py
index e160ecb..99fd786 100644
--- a/SuperTagger/SuperTagger.py
+++ b/SuperTagger/SuperTagger.py
@@ -35,7 +35,7 @@ def output_create_dir():
     return training_dir, writer
 
 
-def categorical_accuracy(preds: list[list[int]], truth: list[list[int]]) -> float:
+def categorical_accuracy(preds, truth):
     """
     Calculates how often predictions match argmax labels.
     @param preds: batch of prediction. (argmax)
@@ -97,7 +97,7 @@ class SuperTagger:
 
     # region Instanciation
 
-    def load_weights(self, model_file: str):
+    def load_weights(self, model_file):
         """
         Loads an SupperTagger saved with SupperTagger.__checkpoint_save() (during a train) from a file.
 
@@ -133,7 +133,7 @@ class SuperTagger:
         self.model_load = True
         self.trainable = True
 
-    def create_new_model(self, num_label: int, bert_name: str, index_to_tags: dict):
+    def create_new_model(self, num_label, bert_name, index_to_tags):
         """
         Instantiation and parameterization of a new bert model
 
@@ -163,7 +163,7 @@ class SuperTagger:
 
     # region Usage
 
-    def predict(self, sentences) -> (list[list[list[float]]], list[list[str]], Tensor):
+    def predict(self, sentences):
         """
         Predict and convert sentences in tags (depends on the dictation given when the model was created)
 
@@ -186,7 +186,7 @@ class SuperTagger:
 
             return output['logit'], self.tags_tokenizer.convert_ids_to_tags(torch.argmax(output['logit'], dim=2).detach())
 
-    def forward(self, b_sents_tokenized: Tensor, b_sents_mask: Tensor) -> (Tensor, Tensor):
+    def forward(self, b_sents_tokenized, b_sents_mask):
         """
         Function used for the linker (same of predict)
         """
@@ -194,7 +194,7 @@ class SuperTagger:
             output = self.model.predict((b_sents_tokenized, b_sents_mask))
             return output
 
-    def train(self, sentences: list[str], tags: list[list[str]], validation_rate=0.1, epochs=20, batch_size=16,
+    def train(self, sentences, tags, validation_rate=0.1, epochs=20, batch_size=16,
               tensorboard=False,
               checkpoint=False):
         """
@@ -261,8 +261,8 @@ class SuperTagger:
 
     # region Private
 
-    def __preprocess_data(self, batch_size: int, sentences: list[str], tags: list[list[str]],
-                          validation_rate: float) -> (DataLoader, DataLoader):
+    def __preprocess_data(self, batch_size, sentences, tags,
+                          validation_rate):
         """
         Create torch dataloader for training
 
@@ -291,7 +291,7 @@ class SuperTagger:
         training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
         return training_dataloader, validation_dataloader
 
-    def __train_epoch(self, training_dataloader: DataLoader) -> (float, float, str):
+    def __train_epoch(self, training_dataloader):
         """
         Train on epoch
 
@@ -335,7 +335,7 @@ class SuperTagger:
 
         return epoch_acc, epoch_loss, training_time
 
-    def __eval_epoch(self, validation_dataloader: DataLoader) -> (float, float, int):
+    def __eval_epoch(self, validation_dataloader):
         """
         Validation on epoch
 
diff --git a/SuperTagger/Utils/helpers.py b/SuperTagger/Utils/helpers.py
index 7c4a2a0..132e55b 100644
--- a/SuperTagger/Utils/helpers.py
+++ b/SuperTagger/Utils/helpers.py
@@ -29,7 +29,7 @@ def load_obj(name):
         return pickle.load(f)
 
 
-def categorical_accuracy_str(preds: list[list[float]], truth: list[list[float]]) -> float:
+def categorical_accuracy_str(preds, truth):
     nb_label = 0
     good_label = 0
     for i in range(len(truth)):
-- 
GitLab