Links: official documentation, official GitHub page
Run the following command to install the required packages: PyTorch (torch and torchvision), Flower (flwr), and matplotlib.
pip install -q flwr[simulation] torch torchvision matplotlib
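A federated learning system consists of one server and multiple clients. In Flower, a client is implemented by subclassing flwr.client.Client or flwr.client.NumPyClient.
Taking flwr.client.NumPyClient as an example, the subclass must implement three methods: get_parameters, fit, and evaluate.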
Method | What it must do | Inputs | Outputs |
---|---|---|---|
get_parameters | Return the client's current model parameters | self, config | List[np.ndarray] |
fit | Receive model parameters from the server, train them locally, and return the updated parameters | self, parameters (NDArrays), config (Dict[str, Scalar]) | parameters (NDArrays), num_examples (int), metrics (Dict[str, Scalar]) |
evaluate | Receive model parameters from the server, evaluate them on local data, and return the result | self, parameters (NDArrays), config (Dict[str, Scalar]) | loss (float), num_examples (int), metrics (Dict[str, Scalar]) |
Example code:
import flwr as fl

# get_parameters, set_parameters, train and test are helper functions
# assumed to be defined elsewhere (they are not shown in these notes).
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
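The num_examples value returned by fit matters because averaging strategies such as FedAvg weight each client's parameters by it. The following is a minimal pure-Python sketch of that weighted average, not Flower's actual implementation: the `Weights` alias, the `weighted_average` function, and the toy client data are hypothetical and use plain lists instead of NumPy arrays.

```python
from typing import List, Tuple

# Hypothetical alias: a model is a list of layers, each a flat list of floats.
Weights = List[List[float]]


def weighted_average(results: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its num_examples."""
    total_examples = sum(num_examples for _, num_examples in results)
    first_weights, _ = results[0]
    averaged: Weights = []
    for layer_idx, layer in enumerate(first_weights):
        averaged_layer = []
        for param_idx in range(len(layer)):
            weighted_sum = sum(
                weights[layer_idx][param_idx] * num_examples
                for weights, num_examples in results
            )
            averaged_layer.append(weighted_sum / total_examples)
        averaged.append(averaged_layer)
    return averaged


# Two toy clients, each with one layer of two parameters.
client_results = [
    ([[1.0, 2.0]], 10),  # client A reports 10 examples
    ([[3.0, 6.0]], 30),  # client B reports 30 examples
]
aggregated = weighted_average(client_results)
print(aggregated)  # client B counts three times as much as client A
```

Because client B reports three times as many examples, the result sits closer to its parameters than to client A's.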
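When multiple clients are simulated on a single machine, they have to share CPU, GPU, and memory, which can quickly exhaust the available memory.
Flower provides a dedicated simulation feature that creates FlowerClient instances only when they are actually needed for training or evaluation. To let Flower create clients on demand, you provide a client_fn function. Whenever Flower needs a particular client instance to call fit or evaluate, it calls client_fn to create that instance. The instance is usually discarded after use, so it should not keep any local state. Each client is identified by a client ID (cid).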
def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""
    # Load model
    net = Net().to(DEVICE)
    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    # Create a single Flower client representing a single organization
    return FlowerClient(net, trainloader, valloader)
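The create-on-demand idea can be illustrated without Flower at all. The sketch below is only a toy model of one round, not how Flower's simulation engine is implemented: `DummyClient`, `dummy_client_fn`, and `run_round` are hypothetical names introduced for illustration.

```python
import random
from typing import Callable, List


class DummyClient:
    """Hypothetical stand-in for FlowerClient; keeps no state between rounds."""

    def __init__(self, cid: str) -> None:
        self.cid = cid

    def fit(self) -> str:
        return f"client {self.cid} trained"


def dummy_client_fn(cid: str) -> DummyClient:
    """Build a fresh client for the given client ID."""
    return DummyClient(cid)


def run_round(
    client_fn: Callable[[str], DummyClient],
    cids: List[str],
    num_sampled: int,
    rng: random.Random,
) -> List[str]:
    """Sample some client IDs, create each client only when needed, then drop it."""
    results = []
    for cid in rng.sample(cids, num_sampled):
        client = client_fn(cid)  # the instance exists only for this call
        results.append(client.fit())
        del client  # nothing survives to the next round
    return results


all_cids = [str(i) for i in range(10)]
round_results = run_round(dummy_client_fn, all_cids, 3, random.Random(0))
print(round_results)
```

Only three of the ten clients are ever instantiated in this round, which is why anything a client needs across rounds has to live outside the instance.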
The flwr.server.strategy module packages federated learning methods/algorithms such as FedAvg and FedAdagrad.
# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=10,  # Never sample fewer than 10 clients for training
    min_evaluate_clients=5,  # Never sample fewer than 5 clients for evaluation
    min_available_clients=10,  # Wait until all 10 clients are available
)

# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
    client_resources=client_resources,
)
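How the fraction and minimum settings combine can be sketched in a few lines. This assumes the rule that the number of sampled clients is the fraction of available clients, truncated to an integer, but never below the configured minimum; `num_sampled_clients` is a hypothetical helper written for illustration, not a Flower function.

```python
def num_sampled_clients(num_available: int, fraction: float, min_clients: int) -> int:
    """Clients sampled in a round: a fraction of those available, with a floor."""
    return max(int(num_available * fraction), min_clients)


# Settings from the strategy above, with 10 available clients.
fit_clients = num_sampled_clients(10, fraction=1.0, min_clients=10)
evaluate_clients = num_sampled_clients(10, fraction=0.5, min_clients=5)

# A small fraction is overridden by the minimum.
floored = num_sampled_clients(10, fraction=0.1, min_clients=3)

print(fit_clients, evaluate_clients, floored)
```

With the values above, all 10 clients train each round and 5 of them evaluate; in the last case the fraction alone would select a single client, so the minimum of 3 applies.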
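Flower can evaluate the aggregated model either on the server side or on the client side, that is, centralized evaluation and federated evaluation:
Centralized evaluation resembles evaluation in centralized machine learning. If an evaluation dataset is available on the server, the aggregated model does not have to be sent to the clients for evaluation. In Flower, you implement an evaluate function and pass it to the strategy through its evaluate_fn parameter.
Federated evaluation is more complex but also more powerful: it lets us evaluate the model on a larger body of data and obtain results closer to real-world conditions. This comes at a cost. If clients are not always available, the evaluation dataset changes from one round to the next, and the dataset held by each client can also change between rounds. In Flower, federated evaluation is implemented through the evaluate method of FlowerClient.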
from typing import Dict, Optional, Tuple

# The `evaluate` function will be called by Flower after every round
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    net = Net().to(DEVICE)
    valloader = valloaders[0]
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, valloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}

strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    evaluate_fn=evaluate,  # Pass the evaluation function
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)
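For the federated evaluation path, each sampled client returns its own loss together with num_examples, and the server has to combine them. A minimal sketch of one common way to do this, a loss average weighted by example count, is shown below; `weighted_loss` and the toy numbers are hypothetical and only illustrate the idea.

```python
from typing import List, Tuple


def weighted_loss(results: List[Tuple[int, float]]) -> float:
    """Combine per-client (num_examples, loss) pairs into one weighted loss."""
    total_examples = sum(num_examples for num_examples, _ in results)
    return sum(num_examples * loss for num_examples, loss in results) / total_examples


# Round 1: clients A and B were sampled for evaluation.
round_1 = weighted_loss([(10, 0.5), (30, 1.5)])

# Round 2: only client A was available, so the effective evaluation set changed.
round_2 = weighted_loss([(10, 0.5)])

print(round_1, round_2)
```

The two rounds report different losses even though no client's loss changed, which illustrates why federated evaluation results can shift when the set of participating clients shifts.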
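Sometimes we want the server to configure the client-side fit and evaluate calls, for example to tell clients how many local epochs to train. The config parameter of fit (or evaluate) receives a configuration dictionary, and the client reads it to adjust its local run. On the server side, write a function that takes the round number as input and pass it to the strategy as on_fit_config_fn or on_evaluate_config_fn.
Client that reads the configuration: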
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        # Read values from config
        server_round = config["server_round"]
        local_epochs = config["local_epochs"]
        # Use values provided by the config
        print(f"[Client {self.cid}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)
Server-side function that supplies the configuration:
def fit_config(server_round: int):
    """Return training configuration dict for each round.

    Train with one local epoch in the first round, then increase to two
    local epochs afterwards.
    """
    config = {
        "server_round": server_round,  # The current round of federated learning
        "local_epochs": 1 if server_round < 2 else 2,  # 1 epoch in round 1, then 2
    }
    return config
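To see what the clients receive, the configuration function can be called directly for a few rounds. The sketch below repeats fit_config so that it runs on its own; `evaluate_config` and the `Scalar` alias are hypothetical additions for illustration, and the strategy wiring is shown only as a comment because it needs Flower installed.

```python
from typing import Dict, Union

# Local alias standing in for Flower's Scalar type (an assumption of this sketch).
Scalar = Union[bool, bytes, float, int, str]


def fit_config(server_round: int) -> Dict[str, Scalar]:
    """Training configuration: 1 local epoch in round 1, 2 afterwards."""
    return {
        "server_round": server_round,
        "local_epochs": 1 if server_round < 2 else 2,
    }


def evaluate_config(server_round: int) -> Dict[str, Scalar]:
    """Hypothetical evaluation configuration: only forwards the round number."""
    return {"server_round": server_round}


# Rounds are numbered from 1, so inspect the first three rounds.
configs = [fit_config(server_round) for server_round in range(1, 4)]
epochs_per_round = [config["local_epochs"] for config in configs]
print(epochs_per_round)

# With Flower installed, the functions are passed to the strategy by name:
# strategy = fl.server.strategy.FedAvg(
#     on_fit_config_fn=fit_config,
#     on_evaluate_config_fn=evaluate_config,
# )
```

The function objects themselves are passed, not their return values, so the strategy can call them again with each new round number.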