# python3 client_1.py <dataset> <partition> <client_num> <IP:PORT>
# python3 client_1.py cifar10 1 3 172.16.66.55:8080
# Updated version: a fixed random seed is set for reproducible experiments,
# and GPU availability is checked explicitly at startup.

import random
import sys

import flwr as fl
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

# List the physical GPUs visible to TensorFlow.
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    # Only use the 1st GPU.
    tf.config.set_visible_devices(physical_devices[0], 'GPU')
    # Grow GPU memory on demand instead of reserving it all up front.
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
    print("GPU detected and memory growth enabled. Using GPU:", physical_devices[0].name)
else:
    print("No GPU detected, using CPU.")


def set_random_seed(seed):
    """Seed the Python, NumPy and TensorFlow random number generators.

    NOTE(review): seeding alone may not make every GPU kernel deterministic;
    see ``tf.config.experimental.enable_op_determinism`` if bit-exact
    reruns are required.
    """
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)


# Set a specific seed value
set_random_seed(42)


def load_partitioned_data(dataset_name, num_clients, client_id):
    """Load ``dataset_name`` and return the shard belonging to ``client_id``.

    The training set is cut into ``num_clients`` contiguous shards of
    ``total // num_clients`` samples; the last shard also takes the
    remainder. The selected shard is split 80/20 into train/validation with
    a fixed ``random_state`` so the split is reproducible. The test set is
    returned whole (not partitioned).

    Args:
        dataset_name: One of ``"cifar10"``, ``"mnist"``, ``"cifar100"``.
        num_clients: Total number of shards; must be >= 1.
        client_id: Zero-based shard index, in ``[0, num_clients)``.

    Returns:
        ``((x_train, y_train), (x_val, y_val), (x_test, y_test),
        num_classes, input_shape)`` with images as float32 in ``[0, 1]``.

    Raises:
        ValueError: If ``num_clients``/``client_id`` are out of range or the
            dataset is not supported.
    """
    # Validate before downloading anything; an out-of-range index would
    # otherwise silently produce a tiny/empty shard.
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}.")
    if not 0 <= client_id < num_clients:
        raise ValueError(
            f"client_id must be in [0, {num_clients}), got {client_id}."
        )

    if dataset_name == "cifar10":
        # total samples train: 50000, test: 10000
        # sample type: animals, vehicles, etc.
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
        x_train = x_train.astype("float32") / 255.0
        x_test = x_test.astype("float32") / 255.0
        num_classes = 10
        input_shape = (32, 32, 3)
    elif dataset_name == "mnist":
        # total samples train: 60000, test: 10000
        # sample type: handwritten digits
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        # MobileNetV2 rejects inputs smaller than 32x32, so zero-pad the
        # 28x28 images by 2 pixels on every side.
        pad = ((0, 0), (2, 2), (2, 2), (0, 0))
        x_train = np.pad(x_train.reshape(-1, 28, 28, 1), pad).astype("float32") / 255.0
        x_test = np.pad(x_test.reshape(-1, 28, 28, 1), pad).astype("float32") / 255.0
        num_classes = 10
        input_shape = (32, 32, 1)
    elif dataset_name == "cifar100":
        # total samples train: 50000, test: 10000
        # sample type: animals, objects
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
        x_train = x_train.astype("float32") / 255.0
        x_test = x_test.astype("float32") / 255.0
        num_classes = 100
        input_shape = (32, 32, 3)
    else:
        raise ValueError("Dataset not supported. Use 'cifar10', 'mnist', or 'cifar100'.")

    # Partition the dataset among the clients (same data distribution with
    # the same number_clients setup).
    total_samples = x_train.shape[0]
    samples_per_client = total_samples // num_clients
    start = client_id * samples_per_client
    # The last client also takes the remainder so no sample is dropped.
    end = total_samples if client_id == num_clients - 1 else start + samples_per_client
    x_client_train = x_train[start:end]
    y_client_train = y_train[start:end]

    # Split the client-specific data into training and validation sets.
    x_train, x_val, y_train, y_val = train_test_split(
        x_client_train, y_client_train, test_size=0.2, random_state=42
    )

    return (x_train, y_train), (x_val, y_val), (x_test, y_test), num_classes, input_shape


# Get command-line arguments: dataset, client_id (partition), num_clients,
# and server address.
_USAGE = "Usage: python3 client_1.py <dataset> <partition> <client_num> <IP:PORT>"
if len(sys.argv) != 5:
    print(_USAGE)
    sys.exit(1)

dataset_name = sys.argv[1]
try:
    client_id = int(sys.argv[2])
    num_clients = int(sys.argv[3])
except ValueError:
    print("<partition> and <client_num> must be integers.")
    print(_USAGE)
    sys.exit(1)
server_address = sys.argv[4]

# <partition> is zero-based: with <client_num> = 3 the valid values are 0, 1, 2.
if num_clients < 1 or not 0 <= client_id < num_clients:
    print(f"<partition> must be in [0, <client_num>); got partition={client_id}, client_num={num_clients}.")
    print(_USAGE)
    sys.exit(1)
# Load partitioned data for the client
(x_train, y_train), (x_val, y_val), (x_test, y_test), num_classes, input_shape = (
    load_partitioned_data(dataset_name, num_clients, client_id)
)

# Define the model (MobileNetV2), trained from scratch (weights=None).
model = tf.keras.applications.MobileNetV2(input_shape=input_shape, classes=num_classes, weights=None)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])


class FederatedClient(fl.client.NumPyClient):
    """Flower client that trains and evaluates the module-level ``model``.

    The model and the data splits are module globals created above; this
    class only moves weights in and out of ``model``.
    """

    def get_parameters(self, config):
        """Return the current model weights (``config`` is unused)."""
        return model.get_weights()

    def fit(self, parameters, config):
        """Load ``parameters``, train one local epoch, return the new weights.

        ``config`` is ignored: epochs and batch size are fixed here.

        Returns:
            ``(weights, number of local training examples, {})``.
        """
        model.set_weights(parameters)
        model.fit(x_train, y_train, epochs=1, batch_size=32, validation_data=(x_val, y_val))
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the full (unpartitioned) test set.

        Returns:
            ``(loss, number of test examples, {"accuracy": accuracy})``.
        """
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
        # Cast to built-in floats so Flower always receives plain Python
        # scalars, whatever scalar type Keras returns.
        return float(loss), len(x_test), {"accuracy": float(accuracy)}


# Start Flower client
if __name__ == "__main__":
    fl.client.start_client(
        server_address=server_address,
        # start_client takes a Client; convert the NumPyClient explicitly
        # (Flower's recommended pattern) instead of relying on implicit conversion.
        client=FederatedClient().to_client(),
        max_retries=3,  # maximum number of connection attempts before giving up
        max_wait_time=5,  # maximum total time (seconds) spent trying to connect
    )