# client_2.py
# Usage:   python3 client_1.py <dataset> <partition> <client_num> <IP:PORT>
#          (<partition> is this client's zero-based index, <client_num> is the total number of clients)
# Example: python3 client_1.py cifar10 1 3 172.16.66.55:8080
# Updated version: the random seeds are fixed so that experiments are reproducible, and GPU availability is checked at startup.

import flwr as fl
import tensorflow as tf
from sklearn.model_selection import train_test_split
import numpy as np
import sys
import random

# Ask TensorFlow which GPUs it can see
physical_devices = tf.config.list_physical_devices('GPU')

if not physical_devices:
    print("No GPU detected, using CPU.")
else:
    primary_gpu = physical_devices[0]

    # Restrict TensorFlow to the first GPU only
    tf.config.set_visible_devices(primary_gpu, 'GPU')

    # Enable memory growth on that GPU to avoid over-allocating its memory
    tf.config.experimental.set_memory_growth(primary_gpu, True)
    print("GPU detected and memory growth enabled. Using GPU:", primary_gpu.name)
    
# Seed every random number generator this script relies on
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and TensorFlow with the same value."""
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)

# Use one fixed seed so that repeated runs are reproducible
set_random_seed(42)


# Function to load and partition the dataset based on client_id
def load_partitioned_data(dataset_name, num_clients, client_id):
    """Load a dataset and carve out the shard that belongs to one client.

    The training set is cut into ``num_clients`` contiguous slices of equal
    size; the last client additionally receives the remainder left over by
    the integer division. The client's slice is then split 80/20 into
    training and validation subsets with a fixed ``random_state`` so the
    split is the same on every run. The test set is returned whole.

    Args:
        dataset_name: One of ``"cifar10"``, ``"mnist"`` or ``"cifar100"``.
        num_clients: Total number of clients sharing the training set.
        client_id: Zero-based index of this client
            (``0 <= client_id < num_clients``).

    Returns:
        ``((x_train, y_train), (x_val, y_val), (x_test, y_test),
        num_classes, input_shape)``. Image arrays are cast to float32 and
        divided by 255.

    Raises:
        ValueError: If ``num_clients`` is not positive, ``client_id`` is out
            of range, or ``dataset_name`` is not supported.
    """
    # Validate the partition arguments before touching the dataset so that a
    # bad command line fails immediately with a clear message instead of
    # producing an empty slice after the dataset has been loaded.
    if num_clients <= 0:
        raise ValueError(f"num_clients must be a positive integer, got {num_clients}.")
    if not 0 <= client_id < num_clients:
        raise ValueError(
            f"client_id must be in the range [0, {num_clients - 1}], got {client_id}."
        )

    if dataset_name == "cifar10":
        # total samples train: 50000
        # total samples test: 10000
        # sample type: animals, vehicles, etc.
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
        x_train = x_train.astype("float32") / 255.0
        x_test = x_test.astype("float32") / 255.0
        num_classes = 10
        input_shape = (32, 32, 3)
    elif dataset_name == "mnist":
        # total samples train: 60000
        # total samples test: 10000
        # sample type: handwritten digits
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        # Add an explicit single-channel axis to the grayscale images
        x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
        x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0
        num_classes = 10
        input_shape = (28, 28, 1)
    elif dataset_name == "cifar100":
        # total samples train: 50000
        # total samples test: 10000
        # sample type: animals, objects
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
        x_train = x_train.astype("float32") / 255.0
        x_test = x_test.astype("float32") / 255.0
        num_classes = 100
        input_shape = (32, 32, 3)
    else:
        raise ValueError("Dataset not supported. Use 'cifar10', 'mnist', or 'cifar100'.")

    # Partition the dataset among the clients (same data distribution with same number_clients setup)
    total_samples = x_train.shape[0]
    samples_per_client = total_samples // num_clients

    start = client_id * samples_per_client
    # The last client also takes whatever the integer division left over
    is_last_client = client_id == num_clients - 1
    end = total_samples if is_last_client else start + samples_per_client

    x_client_train = x_train[start:end]
    y_client_train = y_train[start:end]

    # Split the client-specific data into training and validation sets
    x_train, x_val, y_train, y_val = train_test_split(
        x_client_train, y_client_train, test_size=0.2, random_state=42
    )

    return (x_train, y_train), (x_val, y_val), (x_test, y_test), num_classes, input_shape

# Read the four positional arguments: dataset, client_id, num_clients, server address
if len(sys.argv) != 5:
    print("Usage: python3 client_1.py <dataset> <partition> <client_num> <IP:PORT>")
    sys.exit(1)

dataset_name, client_id, num_clients, server_address = (
    sys.argv[1],
    int(sys.argv[2]),
    int(sys.argv[3]),
    sys.argv[4],
)

# Fetch this client's train/validation shard plus the shared test set
(x_train, y_train), (x_val, y_val), (x_test, y_test), num_classes, input_shape = (
    load_partitioned_data(dataset_name, num_clients, client_id)
)

# Build an untrained MobileNetV2 sized for the chosen dataset
# NOTE(review): Keras' MobileNetV2 appears to require inputs of at least 32x32,
# so the 28x28 "mnist" input shape may be rejected here -- confirm.
model = tf.keras.applications.MobileNetV2(
    input_shape=input_shape,
    classes=num_classes,
    weights=None,
)
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Define Flower client
class FederatedClient(fl.client.NumPyClient):
    """Flower NumPy client backed by the module-level Keras ``model``.

    All methods operate on the module-level ``model`` and on the module-level
    data arrays (``x_train``/``y_train``, ``x_val``/``y_val``,
    ``x_test``/``y_test``); the instance itself holds no state.
    """

    def get_parameters(self, config):
        """Return the model's current weights; ``config`` is not used."""
        return model.get_weights()

    def fit(self, parameters, config):
        """Train for one epoch starting from ``parameters``.

        Returns the updated weights, the number of local training samples
        and an empty metrics dict. ``config`` is not used.
        """
        model.set_weights(parameters)
        model.fit(x_train, y_train, epochs=1, batch_size=32, validation_data=(x_val, y_val))
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the test set.

        Returns the loss, the number of test samples and a metrics dict with
        the ``"accuracy"`` value. ``config`` is not used.
        """
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
        # Cast both values to built-in floats so the loss is handled the same
        # way as the accuracy, whatever scalar type Keras hands back.
        return float(loss), len(x_test), {"accuracy": float(accuracy)}

# Start Flower client
if __name__ == "__main__":
    # start_client is documented to take a `Client`, so the NumPyClient is
    # converted explicitly with to_client() instead of being passed as-is.
    fl.client.start_client(
        server_address=server_address,
        client=FederatedClient().to_client(),
        max_retries=3,  # give up after 3 failed connection attempts
        max_wait_time=5,  # stop trying to connect after 5 seconds in total
    )