当前位置: 首页 > 工具软件 > 2-plan > 使用案例 >

[PySyft-005] 联邦学习 PySyft 从入门到精通——使用 Plan

陆啸
2023-12-01
import torch
import torch.nn as nn
import torch.nn.functional as F
import syft as sy

'''
Part 8 - Introduction to Plans
http://localhost:8888/notebooks/git-home/github/PySyft/examples/tutorials/Part%2008%20-%20Introduction%20to%20Plans.ipynb
'''

'''
演示 plan
一个plan,表示若干个operation的组合,可以是一个函数,也可以是一个类。
plan可以发送到远程节点,可以异步执行。
'''


# Hook torch so that tensors gain the PySyft methods used below
# (e.g. .tag(), .send(), .get()).
hook = sy.TorchHook(torch)
# NOTE(review): presumably required so that the local worker can build/run
# plans (the PySyft Plans tutorial sets this flag for that purpose) — confirm.
hook.local_worker.is_client_worker = False
# Alias for the local worker; not used anywhere in the visible code.
server = hook.local_worker

# Example tensors, tagged so they can later be found on a worker via
# worker.search('input_data') (see test_func_plan / test_class_plan).
# x11/x21 carry tag 'input_data', x12/x22 carry tag 'input_data2'.
x11 = torch.tensor([-1, 2.]).tag('input_data')
x12 = torch.tensor([1, -2.]).tag('input_data2')
x21 = torch.tensor([-1, 2.]).tag('input_data')
x22 = torch.tensor([1, -2.]).tag('input_data2')

# Create the remote nodes: two in-process virtual workers, each pre-loaded
# with its own pair of tagged tensors.
device_1 = sy.VirtualWorker(hook, id="device_1", data=(x11, x12))
device_2 = sy.VirtualWorker(hook, id="device_2", data=(x21, x22))
# Tuple of both workers; not used anywhere in the visible code.
devices = device_1, device_2

# A plan can be a plain function: decorating it with sy.func2plan() turns it
# into a Plan object, which has to be built before it is run.
@sy.func2plan()
def plan_double_abs(x):
    """Return the element-wise absolute value of ``x + x``."""
    doubled = x + x
    return torch.abs(doubled)


def test_func_plan():
    """Build the function plan, run it on ``device_1`` and print the result.

    Workflow shown: build (trace) ``plan_double_abs`` locally with an example
    tensor, send it to ``device_1``, execute it on the tensor tagged
    ``'input_data'`` that already lives there, then fetch the result back.
    """
    # A plan must be built with an example input before it can be run.
    print(plan_double_abs.is_built)
    plan_double_abs.build(torch.tensor([1., -2.]))
    print(plan_double_abs.is_built)

    # Send the plan to the remote node; the returned object is used below
    # as a callable pointer to the remote plan.
    pointer_plan = plan_double_abs.send(device_1)

    # Execute the plan remotely (not a build step) on data already on device_1.
    pointer_to_data = device_1.search('input_data')[0]
    pointer_to_result = pointer_plan(pointer_to_data)
    print(pointer_to_result)

    # Fetch and show the result. The original discarded this value, which
    # only displayed in a notebook cell and was lost when run as a script.
    print(pointer_to_result.get())


# A plan can also be a class: subclass sy.Plan and define forward().
class Net(sy.Plan):
    """Small two-layer network (2 -> 3 -> 2 features) defined as a Plan.

    In this file an instance is built with an example input and then sent
    to a remote worker for execution (see ``test_class_plan``).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        """Return log-probabilities over the 2 output features.

        Args:
            x: tensor whose last dimension has size 2 — shape ``(2,)`` as used
               in this file, or ``(batch, 2)``.
        """
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # dim=-1 normalises over the output features. For the 1-D inputs used
        # in this file it equals the original dim=0, but it stays correct for
        # batched (batch, 2) input, where dim=0 would normalise across the batch.
        return F.log_softmax(x, dim=-1)

def test_class_plan():
    """Build the ``Net`` plan, run it on ``device_1`` and print the result.

    Prints the pointer to the remote result, then the result tensor fetched
    back with ``.get()``.
    """
    model = Net()
    # Build (trace) the plan with an example input before sending it.
    model.build(torch.tensor([1., 2.]))
    remote_model = model.send(device_1)

    # Run the plan on the tensor tagged 'input_data' that lives on device_1.
    remote_input = device_1.search('input_data')[0]
    remote_output = remote_model(remote_input)

    print(remote_output)
    fetched = remote_output.get()
    print(fetched)

    
if __name__ == '__main__':
    # Run the function-plan demo first, then the class-plan demo.
    for demo in (test_func_plan, test_class_plan):
        demo()

 

 类似资料: