当前位置: 首页 > 工具软件 > Flower > 使用案例 >

flower联邦学习框架学习笔记

邓阳炎
2023-12-01

FLower联邦学习框架学习笔记

链接: 官方文档官方github主页

部署环境

  在命令行输入如下的命令以安装必要的包PyTorch (torch and torchvision) and Flower (flwr)

pip install -q flwr[simulation] torch torchvision matplotlib

创建客户类

  联邦学习系统包含一个服务器和多个客户端。在Flower中,通过创建flwr.client.Client或flwr.client.NumpyClient等的子类。
  以flwr.client.NumpyClient为例,其子类需要实现三个方法:get_parameters, fit, 和 evaluate

函数名需要实现的功能输入输出
get_parameters返回当前客户端的模型参数self, configList[np.ndarray]
fit从服务器接收并本地训练模型参数,返回更新后的模型参数self, parameters (NDArrays), config (Dict[str, Scalar])parameters (NDArrays),num_examples (int),metrics (Dict[str, Scalar])]
evaluate从服务器接收模型参数,并通过本地数据评价,返回评价结果self, parameters (NDArrays), config (Dict[str, Scalar])loss (float),num_examples (int),metrics (Dict[str, Scalar])]

示例代码:

class FlowerClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

Virtual Client Engine

  当在一台机器上仿真多个客户端时,它们需要共享CPU, GPU, 内存等。这可能会导致迅速耗尽可用的内存资源。
  Flower提供了特殊的模拟功能,仅在实际需要训练或评估时才创建FlowerClient实例。为了使Flower能在必要时创建客户端,需要提供client_fn函数。Flower在需要特定客户端实例来调用fit或evaluate函数时,会调用client_fn创建实例,并通常在使用后被丢弃,故不应该保留任何本地状态。客户端通过客户端ID(cid)标志

def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""

    # Load model
    net = Net().to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]

    # Create a  single Flower client representing a single organization
    return FlowerClient(net, trainloader, valloader)

学习策略

  Flwr.server.strategy中封存了联邦学习方法/算法,例如FedAvg或FedAdagrad。

# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=10,  # Never sample less than 10 clients for training
    min_evaluate_clients=5,  # Never sample less than 5 clients for evaluation
    min_available_clients=10,  # Wait until all 10 clients are available
)

# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
    client_resources=client_resources,
)

评价策略

  Flower即可以在服务器端也可以在客户端评价聚合模型,即集中评估(Centralized Evaluation)和联合评估(Federated Evaluation):
  集中评估与集中式机器学习评价类似。如果有服务器端的评价数据集,可以不必在评价时将聚合模型发送给客户。在flower中,需要实现evaluate函数,并将该函数名作为参数输入到strategy的evaluate_fn参数中。
  联合评估更复杂但也更强大,这使得我们可以在更大的数据集上评估模型,并得到更贴近现实的结果,然而这种能力是有代价的:如果这些客户端并不总是可用,我们的评价数据集会在连续的学习中发生变化。此外,每个客户端持有的数据集也可以在连续的轮中更改。在flower中通过为FlowerClient部署evaluate方法实现

# The `evaluate` function will be by Flower called after every round
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    net = Net().to(DEVICE)
    valloader = valloaders[0]
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, valloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    evaluate_fn=evaluate,  # Pass the evaluation function
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)

参数传递

  有时,我们希望服务器端能对客户端 (fit, evaluate)进行配置。如服务器要求客户端训练一定数量的本地epoch。可以使用fit(或evaluate)函数的config参数接受配置字典,通过读取配置字典来调整本地的运行。对于server端,需要编写以round数为输入的函数,并将函数名输入到on-fit-config_fn及on_evaluate_config_fn
客户端使用参数:

class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        # Read values from config
        server_round = config["server_round"]
        local_epochs = config["local_epochs"]

        # Use values provided by the config
        print(f"[Client {self.cid}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

服务端传递参数函数:

def fit_config(server_round: int):
    """Return training configuration dict for each round.

    Perform two rounds of training with one local epoch, increase to two local
    epochs afterwards.
    """
    config = {
        "server_round": server_round,  # The current round of federated learning
        "local_epochs": 1 if server_round < 2 else 2,  #
    }
    return config
 类似资料: