diff --git a/.gitignore b/.gitignore
index afa9b70f86aff397f5ddd3837824540602e8642e..557207b998a02ae4e19ee855a069b4f1d97c3a6e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,5 @@ backup_results/*
 *.sh
 tests/__pycache__/*
 datasets/*
-pub/*
\ No newline at end of file
+pub/*
+data/*
diff --git a/driver.py b/driver.py
index e28c4227edf8afa08454adab8dd8d5ac34b15c02..9332a4cb0671b019de3fd7c9a1857785b6e6d42a 100644
--- a/driver.py
+++ b/driver.py
@@ -28,7 +28,7 @@ def main_driver(exp_type, dataset, nn_model, heterogeneity_type, num_clients, nu
 
     output_name = row_exp.to_string(header=False, index=False, name=False).replace(' ', "").replace('\n','_')
-    
+
     hash_outputname = get_uid(output_name)
 
     pathlist = Path("results").rglob('*.json')
@@ -47,7 +47,6 @@ def main_driver(exp_type, dataset, nn_model, heterogeneity_type, num_clients, nu
     except Exception as e:
 
         print(f"Could not run experiment with parameters {output_name}. Exception {e}")
-
         return
 
     launch_experiment(model_server, list_clients, row_exp, output_name)
@@ -94,7 +93,5 @@ def launch_experiment(model_server, list_clients, row_exp, output_name, save_res
 
     return
 
-
-
 if __name__ == "__main__":
     main_driver()
diff --git a/src/models.py b/src/models.py
index 3acd1e15c7474661de12c4efb8b68abe8e336422..78387938fac081a9759143cae170198c354f5c47 100644
--- a/src/models.py
+++ b/src/models.py
@@ -3,86 +3,89 @@ import torch.nn as nn
 import torch.nn.functional as F
 
-class SimpleLinear(nn.Module):
-    """ Fully connected neural network with a single hidden layer of default size 200 and ReLU activations"""
+def accuracy(outputs, labels):
+    _, preds = torch.max(outputs, dim=1)
+    return torch.tensor(torch.sum(preds == labels).item() / len(preds))
+
+
+class ImageClassificationBase(nn.Module):
+    def training_step(self, batch):
+        images, labels = batch
+        out = self(images)
+        loss = F.cross_entropy(out, labels)  # Calculate loss
+        return loss
+
+    def validation_step(self, batch):
+        images, labels = batch
+        out = self(images)
+        loss = F.cross_entropy(out, labels)  # Calculate loss
+        acc = accuracy(out, labels)
+        return {'val_loss': loss.detach(), 'val_acc': acc}
+
+    def validation_epoch_end(self, outputs):
+        batch_losses = [x['val_loss'] for x in outputs]
+        epoch_loss = torch.stack(batch_losses).mean()
+        batch_accs = [x['val_acc'] for x in outputs]
+        epoch_acc = torch.stack(batch_accs).mean()
+        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
+
+    def epoch_end(self, epoch, result):
+        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
+            epoch, result['train_loss'], result['val_loss'], result['val_acc']))
+
+
+class GenericLinearModel(ImageClassificationBase):
 
     def __init__(self, in_size, n_channels):
-        """ Initialization function
-        Arguments:
-            h1: int
-                Desired size of the hidden layer
-        """
         super().__init__()
-        self.fc1 = nn.Linear(in_size*in_size,200)
-        self.fc2 = nn.Linear(200, 10)
+        self.in_size = in_size
 
-    def forward(self, x: torch.Tensor):
+        self.network = nn.Sequential(
+            nn.Linear(in_size*in_size, 200),
+            nn.ReLU(),  # keep the hidden-layer activation SimpleLinear had
+            nn.Linear(200, 10))
 
-        """ Forward pass function through the network
-
-        Arguments:
-            x : torch.Tensor
-                input image of size in_size x in_size
-
-        Returns:
-            log_softmax probabilities of the output layer
-        """
+    def forward(self, xb):
+        xb = xb.view(-1, self.in_size * self.in_size)
+        return self.network(xb)
 
-        x = x.view(-1, self.in_size * self.in_size)
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
 
-class SimpleConv(nn.Module):
-    """ Convolutional neural network with 3 convolutional layers and one fully connected layer
-    """
-
-    def __init__(self, in_size, n_channels):
-        """ Initialization function
-        """
-        super(SimpleConv, self).__init__()
-
-        self.conv1 = nn.Conv2d(n_channels, 16, 3, padding=1)
-        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
-        self.conv3 = nn.Conv2d(32, 16, 3, padding=1)
-
-        self.img_final_size = int(in_size / 8)
+class GenericConvModel(ImageClassificationBase):
+    def __init__(self, in_size, n_channels):
+        super().__init__()
 
-        self.fc1 = nn.Linear(16 * self.img_final_size * self.img_final_size, 10)
-
-        self.pool = nn.MaxPool2d(2, 2)
+        self.img_final_size = int(in_size / (2**3))
 
-        self.dropout = nn.Dropout(p=0.2)
+        self.network = nn.Sequential(
+            nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),  # output: 64 x 16 x 16 (for a 32 x 32 input)
 
-    def flatten(self, x : torch.Tensor):
-
-        """Function to flatten a layer
-
-        Arguments:
-            x : torch.Tensor
+            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),  # output: 128 x 8 x 8
 
-        Returns:
-            flattened Tensor
-        """
-
-        return x.reshape(x.size()[0], -1)
-
-    def forward(self, x : torch.Tensor):
-        """ Forward pass through the network which returns the softmax probabilities of the output layer
+            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),  # output: 256 x 4 x 4
 
-        Arguments:
-            x : torch.Tensor
-                input image to use for training
-        """
+            nn.Flatten(),
+            nn.Linear(256 * self.img_final_size * self.img_final_size, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, 512),
+            nn.ReLU(),
+            nn.Linear(512, 10))
 
-        x = self.dropout(self.pool(F.relu(self.conv1(x))))
-        x = self.dropout(self.pool(F.relu(self.conv2(x))))
-        x = self.dropout(self.pool(F.relu(self.conv3(x))))
-        x = self.flatten(x)
-        x = self.fc1(x)
-
-        return F.log_softmax(x, dim=1)
\ No newline at end of file
+    def forward(self, xb):
+        return self.network(xb)
\ No newline at end of file
diff --git a/src/utils_data.py b/src/utils_data.py
index 8cee10f92ea6cafb0c6f778fa0e3058832c5da4d..eea61c59be29003b17ae65f354064dd394628076 100644
--- a/src/utils_data.py
+++ b/src/utils_data.py
@@ -47,31 +47,37 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
     import numpy as np
     import torchvision
     from extra_keras_datasets import kmnist
-
+    import torchvision.transforms as transforms
+
+    transform = transforms.Compose(
+        [transforms.ToTensor(),
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # 3-channel stats, only meaningful for cifar10
+
     if dataset == "fashion-mnist":
-        fashion_mnist = torchvision.datasets.MNIST("datasets", download=True)
-        (x_train, y_train) = fashion_mnist.data, fashion_mnist.targets
+        fashion_mnist = torchvision.datasets.FashionMNIST("datasets", download=True, transform=transform)
+        (x_data, y_data) = fashion_mnist.data, fashion_mnist.targets
 
         if nn_model == "convolutional":
-            x_train = x_train.unsqueeze(1)
+            x_data = x_data.unsqueeze(1)
 
     elif dataset == 'mnist':
         mnist = torchvision.datasets.MNIST("datasets", download=True)
-        (x_train, y_train) = mnist.data, mnist.targets
+        (x_data, y_data) = mnist.data, mnist.targets
 
         if nn_model == "convolutional":
-            x_train = x_train.unsqueeze(1)
+            x_data = x_data.unsqueeze(1)
 
     elif dataset == "cifar10":
-        cifar10 = torchvision.datasets.CIFAR10("datasets", download=True)
-        (x_train, y_train) = cifar10.data, cifar10.targets
-        x_train = np.transpose(x_train, (0, 3, 1, 2))
-
+        cifar10 = torchvision.datasets.CIFAR10("datasets", download=True, transform=transform)
+        (x_data, y_data) = cifar10.data, cifar10.targets
+        x_data = np.transpose(x_data, (0, 3, 1, 2))
+
     elif dataset == 'kmnist':
-        (x_train, y_train), _ = kmnist.load_data()
+        kmnist = torchvision.datasets.KMNIST("datasets", download=True, transform=transform)
+        (x_data, y_data) = kmnist.data, kmnist.targets
 
         if nn_model == "convolutional":
-            x_train = x_train.unsqueeze(1)
+            x_data = x_data.unsqueeze(1)
 
     else:
         sys.exit("Unrecognized dataset. Please make sure you are using one of the following ['mnist', 'fashion-mnist', 'kmnist', 'cifar10']")
@@ -80,8 +86,8 @@ def create_label_dict(dataset : str, nn_model : str) -> dict:
 
     for label in range(10):
 
-        label_indices = np.where(np.array(y_train) == label)[0]
-        label_samples_x = x_train[label_indices]
+        label_indices = np.where(np.array(y_data) == label)[0]
+        label_samples_x = x_data[label_indices]
 
         label_dict[label] = label_samples_x
 
     return label_dict
@@ -138,7 +144,6 @@ def rotate_images(client: Client, rotation: int) -> None:
     """
 
     import numpy as np
-    from math import prod
 
     images = client.data['x']
 
@@ -149,11 +154,8 @@
         for img in images:
 
             orig_shape = img.shape
-            img_flatten = img.flatten()
-
-            rotated_img = np.rot90(img, k=rotation//90) # Rotate image by specified angle
+            rotated_img = np.rot90(img, k=rotation//90)  # Rotate image by specified angle
             rotated_img = rotated_img.reshape(*orig_shape)
-
             rotated_images.append(rotated_img)
 
     client.data['x'] = np.array(rotated_images)
@@ -171,37 +173,36 @@ def data_preparation(client : Client, row_exp : dict) -> None:
         row_exp : The current experiment's global parameters
     """
 
+    def to_device_tensor(data, device, data_dtype):
+
+        data = torch.tensor(data, dtype=data_dtype)
+        data = data.to(device)  # .to() returns a copy; reassign so the move actually takes effect
+        return data
+
     import torch
     from sklearn.model_selection import train_test_split
     from torch.utils.data import DataLoader, TensorDataset
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-    x_train, x_test, y_train, y_test = train_test_split(client.data['x'], client.data['y'], test_size=0.3, random_state=row_exp['seed'],stratify=client.data['y'])
-
-    x_train, x_test = x_train/255.0 , x_test/255.0
-
-    x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
-    x_train_tensor.to(device)
-
+    x_data, x_test, y_data, y_test = train_test_split(client.data['x'], client.data['y'], test_size=0.3, random_state=row_exp['seed'], stratify=client.data['y'])
+    x_train, x_val, y_train, y_val = train_test_split(x_data, y_data, test_size=0.25, random_state=row_exp['seed'])
 
-    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
-    y_train_tensor.to(device)
+    x_train_tensor = to_device_tensor(x_train, device, torch.float32)
+    y_train_tensor = to_device_tensor(y_train, device, torch.long)
 
-    x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
-    x_test_tensor.to(device)
-    y_test_tensor = torch.tensor(y_test, dtype=torch.long)
-    y_test_tensor.to(device)
+    x_val_tensor = to_device_tensor(x_val, device, torch.float32)
+    y_val_tensor = to_device_tensor(y_val, device, torch.long)
+    x_test_tensor = to_device_tensor(x_test, device, torch.float32)
+    y_test_tensor = to_device_tensor(y_test, device, torch.long)
 
-    train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
-    train_loader = DataLoader(train_dataset, batch_size=32)
-
-    test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
-    test_loader = DataLoader(test_dataset, batch_size=32)
+    train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=128, shuffle=True)
+    validation_loader = DataLoader(TensorDataset(x_val_tensor, y_val_tensor), batch_size=128, shuffle=True)
+    test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=128, shuffle=True)
 
-    setattr(client, 'data_loader', {'train' : train_loader,'test': test_loader})
-    setattr(client,'train_test', {'x_train': x_train,'x_test': x_test, 'y_train': y_train, 'y_test': y_test})
+    setattr(client, 'data_loader', {'train': train_loader, 'val': validation_loader, 'test': test_loader})
+    setattr(client, 'train_test', {'x_train': x_train, 'x_val': x_val, 'x_test': x_test, 'y_train': y_train, 'y_val': y_val, 'y_test': y_test})
 
     return
@@ -245,7 +246,7 @@
 
     """
 
-    from src.models import SimpleLinear, SimpleConv
+    from src.models import GenericLinearModel, GenericConvModel
     from src.utils_fed import init_server_cluster
     import torch
 
@@ -257,11 +258,11 @@ def setup_experiment(row_exp: dict) -> Tuple[Server, list]:
 
     if row_exp['nn_model'] == "linear":
 
-        model_server = Server(SimpleLinear(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
+        model_server = Server(GenericLinearModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
 
     elif row_exp['nn_model'] == "convolutional":
 
-        model_server = Server(SimpleConv(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
+        model_server = Server(GenericConvModel(in_size=imgs_params[row_exp['dataset']][0], n_channels=imgs_params[row_exp['dataset']][1]))
 
     dict_clients = get_clients_data(row_exp['num_clients'],
                                     row_exp['num_samples_by_label'],
@@ -559,24 +560,25 @@ def centralize_data(list_clients : list) -> Tuple[DataLoader, DataLoader]:
     import numpy as np
 
     x_train = np.concatenate([list_clients[id].train_test['x_train'] for id in range(len(list_clients))],axis = 0)
-    x_test = np.concatenate([list_clients[id].train_test['x_test'] for id in range(len(list_clients))],axis = 0)
-
     y_train = np.concatenate([list_clients[id].train_test['y_train'] for id in range(len(list_clients))],axis = 0)
-    y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))],axis = 0)
-
     x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
     y_train_tensor = torch.tensor(y_train, dtype=torch.long)
-
+
+    x_val = np.concatenate([list_clients[id].train_test['x_val'] for id in range(len(list_clients))], axis=0)
+    y_val = np.concatenate([list_clients[id].train_test['y_val'] for id in range(len(list_clients))], axis=0)
+    x_val_tensor = torch.tensor(x_val, dtype=torch.float32)
+    y_val_tensor = torch.tensor(y_val, dtype=torch.long)
+
+    x_test = np.concatenate([list_clients[id].train_test['x_test'] for id in range(len(list_clients))], axis=0)
+    y_test = np.concatenate([list_clients[id].train_test['y_test'] for id in range(len(list_clients))], axis=0)
     x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
     y_test_tensor = torch.tensor(y_test, dtype=torch.long)
 
-    train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
-    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
-
-    test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
-    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
+    train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=64, shuffle=True)
+    val_loader = DataLoader(TensorDataset(x_val_tensor, y_val_tensor), batch_size=64, shuffle=True)
+    test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=64, shuffle=True)
 
-    return train_loader, test_loader
+    return train_loader, val_loader, test_loader
diff --git a/src/utils_fed.py b/src/utils_fed.py
index 7544e235ae1785d13d1e6d6e318648dbea79ecf6..b1ee0091680c3381b6007a0d0faed0caeb4a771a 100644
--- a/src/utils_fed.py
+++ b/src/utils_fed.py
@@ -196,7 +196,7 @@ def init_server_cluster(my_server : Server, list_clients : list, row_exp : dict,
         p_expert_opintion : Parameter to avoid a completely random assignment if needed (default to 0)
     """
 
-    from src.models import SimpleLinear
+    from src.models import GenericLinearModel, GenericConvModel
     import numpy as np
     import copy
 
@@ -212,7 +212,7 @@
 
     my_server.num_clusters = row_exp['num_clusters']
 
-    my_server.clusters_models = {cluster_id: SimpleLinear(in_size=imgs_params[0], n_channels=imgs_params[1]) for cluster_id in range(row_exp['num_clusters'])}
+    my_server.clusters_models = {cluster_id: (GenericLinearModel if row_exp['nn_model'] == "linear" else GenericConvModel)(in_size=imgs_params[0], n_channels=imgs_params[1]) for cluster_id in range(row_exp['num_clusters'])}  # match the model family to row_exp['nn_model']
 
     for client in list_clients:
 
@@ -291,8 +291,6 @@ def set_client_cluster(my_server : Server, list_clients : list, row_exp : dict)
 
         index_of_min_loss = np.argmin(cluster_losses)
 
-        #print(f"client {client.id} with heterogeneity {client.heterogeneity_class} cluster losses:", cluster_losses)
-
         client.model = copy.deepcopy(my_server.clusters_models[index_of_min_loss])
 
         client.cluster_id = index_of_min_loss
diff --git a/src/utils_training.py b/src/utils_training.py
index 1abd81d1feae0a0b52f2b6111d0a5c186cac53c3..4ae455d81ed1c0161a618bd442b507d28565c6f3 100644
--- a/src/utils_training.py
+++ b/src/utils_training.py
@@ -1,14 +1,13 @@
 import torch
 import torch.nn as nn
-import torch.optim as optim
+
 from torch.utils.data import DataLoader
 import pandas as pd
 
-from src.models import SimpleLinear
+from src.models import ImageClassificationBase
 from src.fedclass import Server
 
-lr = 0.01
 
 def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : dict) -> pd.DataFrame:
 
@@ -21,7 +20,12 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di
         main_model : Type of Server model needed
         list_clients : A list of Client Objects used as nodes in the FL protocol
         row_exp : The current experiment's global parameters
+
+    Returns:
+
+        df_results : DataFrame with the experiment results
     """
+
     from src.utils_fed import k_means_clustering
     import copy
     import torch
 
@@ -29,24 +33,20 @@ def run_cfl_server_side(model_server : Server, list_clients : list, row_exp : di
     torch.manual_seed(row_exp['seed'])
 
     model_server = train_federated(model_server, list_clients, row_exp, use_cluster_models = False)
-
-    model_server.clusters_models= {cluster_id: copy.deepcopy(model_server.model) for cluster_id in range(row_exp['num_clusters'])}
-
+    model_server.clusters_models = {cluster_id: copy.deepcopy(model_server.model) for cluster_id in range(row_exp['num_clusters'])}
     setattr(model_server, 'num_clusters', row_exp['num_clusters'])
 
     k_means_clustering(list_clients, row_exp['num_clusters'], row_exp['seed'])
-
+
     model_server = train_federated(model_server, list_clients, row_exp, use_cluster_models = True)
 
     for client in list_clients :
 
-        acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test'])
-
+        acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test'])
         setattr(client, 'accuracy', acc)
 
     df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients])
-
+
     return df_results
 
@@ -72,16 +72,15 @@
 
         for client in list_clients:
 
-            client.model, _ = train_central(client.model, client.data_loader['train'], row_exp)
+            client.model, _ = train_central(client.model, client.data_loader['train'], client.data_loader['val'], row_exp)
 
         fedavg(model_server, list_clients)
-
+
         set_client_cluster(model_server, list_clients, row_exp)
 
     for client in list_clients :
 
         acc = test_model(model_server.clusters_models[client.cluster_id], client.data_loader['test'])
-
         setattr(client, 'accuracy', acc)
 
     df_results = pd.DataFrame.from_records([c.to_dict() for c in list_clients])
 
@@ -120,10 +119,8 @@ def run_benchmark(main_model : nn.Module, list_clients : list, row_exp : dict) -
             for heterogeneity_class in list_heterogeneities:
 
                 list_clients_filtered = [client for client in list_clients if client.heterogeneity_class == heterogeneity_class]
-
-                train_loader, test_loader = centralize_data(list_clients_filtered)
-
-                model_trained, _ = train_central(curr_model, train_loader, row_exp)
+                train_loader, val_loader, test_loader = centralize_data(list_clients_filtered)
+                model_trained, _ = train_central(curr_model, train_loader, val_loader, row_exp)
 
                 global_acc = test_model(model_trained, test_loader)
 
@@ -134,11 +131,9 @@
         case 'global-federated':
 
             model_server = copy.deepcopy(curr_model)
-
             model_trained = train_federated(model_server, list_clients, row_exp, use_cluster_models = False)
-
-            _, test_loader = centralize_data(list_clients)
+            _, _, test_loader = centralize_data(list_clients)  # centralize_data now also returns a validation loader
 
             global_acc = test_model(model_trained.model, test_loader)
 
             for client in list_clients :
@@ -161,9 +156,8 @@
         row_exp: The current experiment's global parameters
         use_cluster_models: Boolean to determine whether to use personalization by clustering
     """
-
-    from src.utils_fed import send_server_model_to_client, send_cluster_models_to_clients, fedavg
+    from src.utils_fed import send_server_model_to_client, send_cluster_models_to_clients, fedavg
 
     for i in range(0, row_exp['federated_rounds']):
 
@@ -179,8 +173,7 @@ def train_federated(main_model, list_clients, row_exp, use_cluster_models = Fals
 
         for client in list_clients:
 
-            client.model, curr_acc = train_central(client.model, client.data_loader['train'], row_exp)
-
+            client.model, curr_acc = train_central(client.model, client.data_loader['train'], client.data_loader['val'], row_exp)
             accs.append(curr_acc)
 
         fedavg(main_model, list_clients)
@@ -188,54 +181,60 @@ def train_federated(main_model, list_clients, row_exp, use_cluster_models = Fals
 
     return main_model
 
-def train_central(main_model, train_loader, row_exp):
-    """ Main training function for centralized learning
+@torch.no_grad()
+def evaluate(model : nn.Module, val_loader : DataLoader) -> dict:
 
-    Arguments:
+    """ Returns a dict with loss and accuracy information"""
 
-        main_model : Server model used in our experiment
-        train_loader : DataLoader with the dataset to use for training
-        row_exp : The current experiment's global parameters
+    model.eval()
+    outputs = [model.validation_step(batch) for batch in val_loader]
+    return model.validation_epoch_end(outputs)
 
-    """
-    criterion = nn.CrossEntropyLoss()
-
-    optimizer=optim.SGD if row_exp['nn_model'] == "linear" else optim.Adam
"linear" else optim.Adam - optimizer = optimizer(main_model.parameters(), lr=0.01) - - main_model.train() - - for epoch in range(row_exp['centralized_epochs']): - - running_loss = total = correct = 0 - for inputs, labels in train_loader: - optimizer.zero_grad() +def train_central(model : ImageClassificationBase, train_loader : DataLoader, val_loader : DataLoader, row_exp : dict): - outputs = main_model(inputs) + """ Main training function for centralized learning + + Arguments: + model : Server model used in our experiment + train_loader : DataLoader with the training dataset + val_loader : Dataloader with the validation dataset + row_exp : The current experiment's global parameters - _, predicted = torch.max(outputs, 1) + Returns: + (model, history) : base model with trained weights / results at each training step + """ - loss = criterion(outputs, labels) + opt_func=torch.optim.SGD #if row_exp['nn_model'] == "linear" else torch.optim.Adam + lr = 0.001 + history = [] + optimizer = opt_func(model.parameters(), lr) + + for epoch in range(row_exp['centralized_epochs']): + + model.train() + train_losses = [] + + for batch in train_loader: - loss.backward() + loss = model.training_step(batch) + train_losses.append(loss) + loss.backward() optimizer.step() - - running_loss += loss.item() * inputs.size(0) - - total += labels.size(0) - - correct += (predicted == labels).sum().item() - - accuracy = correct / total - - main_model.eval() - - return main_model, accuracy - + optimizer.zero_grad() + + result = evaluate(model, val_loader) + result['train_loss'] = torch.stack(train_losses).mean().item() + + model.epoch_end(epoch, result) + + history.append(result) + + return model, history + def test_model(model : nn.Module, test_loader : DataLoader) -> float: @@ -254,7 +253,6 @@ def test_model(model : nn.Module, test_loader : DataLoader) -> float: total = 0 test_loss = 0.0 - with torch.no_grad(): for inputs, labels in test_loader: @@ -262,13 +260,11 @@ def test_model(model : nn.Module, test_loader : DataLoader) -> float: outputs = model(inputs) loss = criterion(outputs, labels) - test_loss += loss.item() * inputs.size(0) _, predicted = torch.max(outputs, 1) total += labels.size(0) - correct += (predicted == labels).sum().item() test_loss = test_loss / len(test_loader.dataset)