diff --git a/Flower_v1/client_2.py b/Flower_v1/client_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd8b56e121a7f04ab8db7733e572dd44edf38b1f
--- /dev/null
+++ b/Flower_v1/client_2.py
@@ -0,0 +1,119 @@
+# Usage:   python3 client_2.py <dataset> <partition> <client_num> <IP:PORT>
+# Example: python3 client_2.py cifar10 1 3 172.16.66.55:8080
+# This version adds a fixed random seed for reproducible experiments and checks whether a GPU is available.
+
+import flwr as fl
+import tensorflow as tf
+from sklearn.model_selection import train_test_split
+import numpy as np
+import sys
+import random
+
+physical_devices = tf.config.list_physical_devices('GPU')
+if physical_devices:
+    tf.config.experimental.set_memory_growth(physical_devices[0], True)
+    print("GPU detected and memory growth enabled.")
+else:
+    print("No GPU detected, using CPU.")
+
+# Set random seeds for reproducibility
+def set_random_seed(seed):
+    random.seed(seed)         # Python random module seed
+    np.random.seed(seed)      # NumPy random seed
+    tf.random.set_seed(seed)  # TensorFlow random seed
+
+# Set a specific seed value
+set_random_seed(42)
+
+
+# Load the chosen dataset and return this client's partition, based on client_id
+def load_partitioned_data(dataset_name, num_clients, client_id):
+    if dataset_name == "cifar10":
+        '''
+        total samples train: 50000
+        total samples test: 10000
+        sample type: animals, vehicles, etc.
+        '''
+        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
+        x_train = x_train.astype("float32") / 255.0
+        x_test = x_test.astype("float32") / 255.0
+        num_classes = 10
+        input_shape = (32, 32, 3)
+    elif dataset_name == "mnist":
+        '''
+        total samples train: 60000
+        total samples test: 10000
+        sample type: handwritten digits
+        '''
+        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+        x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
+        x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0
+        num_classes = 10
+        input_shape = (28, 28, 1)
+    elif dataset_name == "cifar100":
+        '''
+        total samples train: 50000
+        total samples test: 10000
+        sample type: animals, objects
+        '''
+        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
+        x_train = x_train.astype("float32") / 255.0
+        x_test = x_test.astype("float32") / 255.0
+        num_classes = 100
+        input_shape = (32, 32, 3)
+    else:
+        raise ValueError("Dataset not supported. Use 'cifar10', 'mnist', or 'cifar100'.")
+
+    # Partition the training set evenly among the clients (the split is identical
+    # for every run with the same num_clients)
+    total_samples = x_train.shape[0]
+    samples_per_client = total_samples // num_clients
+
+    start = client_id * samples_per_client
+    end = (client_id + 1) * samples_per_client if client_id != num_clients - 1 else total_samples
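+    # Example: with cifar10 (50000 training samples) and num_clients = 3,
+    # samples_per_client = 16666, so clients 0 and 1 receive slices [0:16666]
+    # and [16666:33332], while the last client also takes the remainder,
+    # i.e. [33332:50000] (16668 samples).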
+
+    x_client_train = x_train[start:end]
+    y_client_train = y_train[start:end]
+
+    # Split the client-specific data into training and validation sets
+    x_train, x_val, y_train, y_val = train_test_split(x_client_train, y_client_train, test_size=0.2, random_state=42)
+
+    return (x_train, y_train), (x_val, y_val), (x_test, y_test), num_classes, input_shape
+
+# Get command-line arguments: dataset, client_id, num_clients, and server address
+if len(sys.argv) != 5:
+    print("Usage: python3 client_2.py <dataset> <partition> <client_num> <IP:PORT>")
+    sys.exit(1)
+
+dataset_name = sys.argv[1]
+client_id = int(sys.argv[2])
+num_clients = int(sys.argv[3])
+server_address = sys.argv[4]
+
+# Load the partition assigned to this client
+(x_train, y_train), (x_val, y_val), (x_test, y_test), num_classes, input_shape = load_partitioned_data(dataset_name, num_clients, client_id)
+
+# Define the model (MobileNetV2, trained from scratch: weights=None, no ImageNet pretraining)
+model = tf.keras.applications.MobileNetV2(input_shape=input_shape, classes=num_classes, weights=None)
+model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
+
+# Define Flower client
+class FederatedClient(fl.client.NumPyClient):
+    def get_parameters(self, config):
+        return model.get_weights()
+
+    def fit(self, parameters, config):
+        model.set_weights(parameters)
+        model.fit(x_train, y_train, epochs=1, batch_size=32, validation_data=(x_val, y_val))
+        return model.get_weights(), len(x_train), {}
+
+    def evaluate(self, parameters, config):
+        model.set_weights(parameters)
+        # Evaluate on the full, shared test set (the test split is not partitioned)
+        loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
+        return loss, len(x_test), {"accuracy": float(accuracy)}
+
+# Start Flower client
+if __name__ == "__main__":
+    fl.client.start_client(server_address=server_address,
+                           client=FederatedClient(),
+                           max_retries=3,    # retry the connection to the server up to 3 times
+                           max_wait_time=5)  # wait 5 seconds before retrying
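+
+# Example launch of a 3-client run (client IDs / partitions are 0-indexed, so the
+# valid values here are 0, 1 and 2; the server address is the illustrative one
+# from the usage comment at the top of this file):
+#   python3 client_2.py cifar10 0 3 172.16.66.55:8080
+#   python3 client_2.py cifar10 1 3 172.16.66.55:8080
+#   python3 client_2.py cifar10 2 3 172.16.66.55:8080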