Skip to content
Snippets Groups Projects
Commit 4963a6b3 authored by Caroline de Pourtalès's avatar Caroline de Pourtalès :speech_balloon:
Browse files

deleting typing

parent 5e0235c6
Branches
No related tags found
No related merge requests found
......@@ -35,7 +35,7 @@ def output_create_dir():
return training_dir, writer
def categorical_accuracy(preds: list[list[int]], truth: list[list[int]]) -> float:
def categorical_accuracy(preds, truth):
"""
Calculates how often predictions match argmax labels.
@param preds: batch of prediction. (argmax)
......@@ -97,7 +97,7 @@ class SuperTagger:
# region Instanciation
def load_weights(self, model_file: str):
def load_weights(self, model_file):
"""
Loads an SupperTagger saved with SupperTagger.__checkpoint_save() (during a train) from a file.
......@@ -133,7 +133,7 @@ class SuperTagger:
self.model_load = True
self.trainable = True
def create_new_model(self, num_label: int, bert_name: str, index_to_tags: dict):
def create_new_model(self, num_label, bert_name, index_to_tags):
"""
Instantiation and parameterization of a new bert model
......@@ -163,7 +163,7 @@ class SuperTagger:
# region Usage
def predict(self, sentences) -> (list[list[list[float]]], list[list[str]], Tensor):
def predict(self, sentences):
"""
Predict and convert sentences in tags (depends on the dictation given when the model was created)
......@@ -186,7 +186,7 @@ class SuperTagger:
return output['logit'], self.tags_tokenizer.convert_ids_to_tags(torch.argmax(output['logit'], dim=2).detach())
def forward(self, b_sents_tokenized: Tensor, b_sents_mask: Tensor) -> (Tensor, Tensor):
def forward(self, b_sents_tokenized, b_sents_mask):
"""
Function used for the linker (same of predict)
"""
......@@ -194,7 +194,7 @@ class SuperTagger:
output = self.model.predict((b_sents_tokenized, b_sents_mask))
return output
def train(self, sentences: list[str], tags: list[list[str]], validation_rate=0.1, epochs=20, batch_size=16,
def train(self, sentences, tags, validation_rate=0.1, epochs=20, batch_size=16,
tensorboard=False,
checkpoint=False):
"""
......@@ -261,8 +261,8 @@ class SuperTagger:
# region Private
def __preprocess_data(self, batch_size: int, sentences: list[str], tags: list[list[str]],
validation_rate: float) -> (DataLoader, DataLoader):
def __preprocess_data(self, batch_size, sentences, tags,
validation_rate):
"""
Create torch dataloader for training
......@@ -291,7 +291,7 @@ class SuperTagger:
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
return training_dataloader, validation_dataloader
def __train_epoch(self, training_dataloader: DataLoader) -> (float, float, str):
def __train_epoch(self, training_dataloader):
"""
Train on epoch
......@@ -335,7 +335,7 @@ class SuperTagger:
return epoch_acc, epoch_loss, training_time
def __eval_epoch(self, validation_dataloader: DataLoader) -> (float, float, int):
def __eval_epoch(self, validation_dataloader):
"""
Validation on epoch
......
......@@ -29,7 +29,7 @@ def load_obj(name):
return pickle.load(f)
def categorical_accuracy_str(preds: list[list[float]], truth: list[list[float]]) -> float:
def categorical_accuracy_str(preds, truth):
nb_label = 0
good_label = 0
for i in range(len(truth)):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment