# client_2.py (4.76 KiB) - authored by huongdm1896
# python3 client_1.py <dataset> <partition> <client_num> <IP:PORT>
# python3 client_1.py cifar10 1 3 172.16.66.55:8080
# Updated version: seeds the random number generators so experiments are reproducible, and checks whether a GPU is available.
import flwr as fl
import tensorflow as tf
from sklearn.model_selection import train_test_split
import numpy as np
import sys
import random
# Enumerate the GPUs that TensorFlow can see on this machine.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("No GPU detected, using CPU.")
else:
    # Expose only the first GPU to TensorFlow.
    tf.config.set_visible_devices(physical_devices[0], 'GPU')
    # Enable memory growth on that GPU to avoid over-allocating its memory.
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
    print("GPU detected and memory growth enabled. Using GPU:", physical_devices[0].name)
# Set random seed for reproducibility
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and TensorFlow with the same value.

    Args:
        seed: Seed passed unchanged to each of the three generators.
    """
    # Same order as before: Python stdlib, then NumPy, then TensorFlow.
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)


# Fix the seed once, at module load, so every run starts from the same state.
set_random_seed(42)
# Function to load and partition the dataset based on client_id
def load_partitioned_data(dataset_name, num_clients, client_id):
    """Load a Keras dataset and return this client's slice of the training set.

    The training set is cut into ``num_clients`` contiguous, equally sized
    slices; the last client additionally receives the remainder left over by
    the integer division. The client's slice is then split 80/20 into
    training and validation data with a fixed ``random_state`` so the split is
    the same on every run. The test set is returned whole (not partitioned).

    Args:
        dataset_name: One of ``"cifar10"``, ``"mnist"`` or ``"cifar100"``.
        num_clients: Total number of clients sharing the training set (>= 1).
        client_id: Zero-based index of this client, in ``[0, num_clients - 1]``.

    Returns:
        ``(x_train, y_train), (x_val, y_val), (x_test, y_test), num_classes,
        input_shape`` where the image arrays are float32 scaled to ``[0, 1]``.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1, if ``client_id`` is
            outside ``[0, num_clients - 1]``, or if ``dataset_name`` is not
            supported.
    """
    # Fail fast, before any dataset is downloaded: an out-of-range client_id
    # would otherwise silently select a truncated or empty slice, and
    # num_clients == 0 would surface as a bare ZeroDivisionError.
    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}.")
    if not 0 <= client_id < num_clients:
        raise ValueError(
            f"client_id must be in the range [0, {num_clients - 1}], got {client_id}."
        )

    if dataset_name == "cifar10":
        # 50000 training / 10000 test samples: animals, vehicles, etc.
        dataset = tf.keras.datasets.cifar10
        num_classes = 10
        input_shape = (32, 32, 3)
    elif dataset_name == "mnist":
        # 60000 training / 10000 test samples: handwritten digits.
        dataset = tf.keras.datasets.mnist
        num_classes = 10
        input_shape = (28, 28, 1)
    elif dataset_name == "cifar100":
        # 50000 training / 10000 test samples: animals, objects.
        dataset = tf.keras.datasets.cifar100
        num_classes = 100
        input_shape = (32, 32, 3)
    else:
        raise ValueError("Dataset not supported. Use 'cifar10', 'mnist', or 'cifar100'.")

    (x_train, y_train), (x_test, y_test) = dataset.load_data()
    # Reshaping adds the channel axis for MNIST and is a no-op for the CIFAR
    # sets; pixel values are then scaled from [0, 255] to [0, 1].
    x_train = x_train.reshape(-1, *input_shape).astype("float32") / 255.0
    x_test = x_test.reshape(-1, *input_shape).astype("float32") / 255.0

    # Partition the dataset among the clients (same data distribution with same number_clients setup)
    total_samples = x_train.shape[0]
    samples_per_client = total_samples // num_clients
    start = client_id * samples_per_client
    # The last client also takes the remainder of the integer division.
    if client_id == num_clients - 1:
        end = total_samples
    else:
        end = start + samples_per_client
    x_client_train = x_train[start:end]
    y_client_train = y_train[start:end]

    # Split the client-specific data into training and validation sets
    x_train, x_val, y_train, y_val = train_test_split(
        x_client_train, y_client_train, test_size=0.2, random_state=42
    )
    return (x_train, y_train), (x_val, y_val), (x_test, y_test), num_classes, input_shape
# Read the command-line arguments: dataset, client_id, num_clients, and server address
if len(sys.argv) != 5:
    print("Usage: python3 client_1.py <dataset> <partition> <client_num> <IP:PORT>")
    sys.exit(1)
# Arguments 1 and 4 are kept as strings; arguments 2 and 3 are parsed as integers.
dataset_name, server_address = sys.argv[1], sys.argv[4]
client_id, num_clients = map(int, sys.argv[2:4])
# Load this client's train/validation split, the shared test set, and the dataset metadata
(x_train, y_train), (x_val, y_val), (x_test, y_test), num_classes, input_shape = (
    load_partitioned_data(dataset_name, num_clients, client_id)
)
# Build an untrained MobileNetV2 (weights=None) sized for the chosen dataset.
# NOTE(review): Keras' MobileNetV2 appears to require inputs of at least 32x32,
# so the 28x28 "mnist" shape may be rejected here - confirm against the Keras docs.
model = tf.keras.applications.MobileNetV2(
    weights=None,
    input_shape=input_shape,
    classes=num_classes,
)
# Labels are used as integer class ids, hence the sparse categorical loss.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# Define Flower client
class FederatedClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates the module-level ``model``.

    All methods operate on the module-level ``model`` and data arrays; the
    ``config`` argument received from the server is not used.
    """

    def get_parameters(self, config):
        """Return the current weights of the model."""
        return model.get_weights()

    def fit(self, parameters, config):
        """Load ``parameters`` into the model and train for one local epoch.

        Returns:
            The updated weights, the number of local training samples, and an
            empty metrics dict.
        """
        model.set_weights(parameters)
        model.fit(
            x_train,
            y_train,
            validation_data=(x_val, y_val),
            batch_size=32,
            epochs=1,
        )
        updated_weights = model.get_weights()
        return updated_weights, len(x_train), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` into the model and evaluate it on the test set.

        Returns:
            The test loss, the number of test samples, and a metrics dict
            holding the test accuracy as a Python float.
        """
        model.set_weights(parameters)
        test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
        metrics = {"accuracy": float(test_accuracy)}
        return test_loss, len(x_test), metrics
# Start Flower client
if __name__ == "__main__":
    # Connection-retry settings passed to Flower: up to 3 reconnection attempts.
    # NOTE(review): per the Flower docs, max_wait_time looks like the total time
    # budget (in seconds) for connection attempts, not a delay between retries - confirm.
    retry_policy = {"max_retries": 3, "max_wait_time": 5}
    fl.client.start_client(
        server_address=server_address,
        client=FederatedClient(),
        **retry_policy,
    )