WebsocketServerWorker-side code: start_worker.py
import argparse
import torch as th
from syft.workers.websocket_server import WebsocketServerWorker
import syft as sy
# Arguments
parser = argparse.ArgumentParser(description="Run websocket server worker.")
parser.add_argument(
    "--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)
parser.add_argument("--host", type=str, default="localhost", help="host for the connection")
parser.add_argument(
    "--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)
parser.add_argument(
    "--verbose",
    "-v",
    action="store_true",
    help="if set, websocket server worker will be started in verbose mode",
)
def main(**kwargs):  # pragma: no cover
    """Helper function for spinning up a websocket participant."""
    # Create websocket worker
    worker = WebsocketServerWorker(**kwargs)
    # Setup toy data (xor example)
    data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
    target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)
    # Create a dataset using the toy data
    dataset = sy.BaseDataset(data, target)
    # Tell the worker about the dataset
    worker.add_dataset(dataset, key="xor")
    # Start worker
    worker.start()
    return worker
if __name__ == "__main__":
    hook = sy.TorchHook(th)
    args = parser.parse_args()
    kwargs = {
        "id": args.id,
        "host": args.host,
        "port": args.port,
        "hook": hook,
        "verbose": args.verbose,
    }
    main(**kwargs)
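As an aside, add_dataset just registers the BaseDataset on the worker under the given key; that key ("xor" here) is what the client later passes as fit(dataset_key="xor"). A minimal conceptual sketch of that bookkeeping (illustrative only, not PySyft source):
class ToyWorker:
    """Illustrative stand-in for the dataset bookkeeping a worker performs."""
    def __init__(self):
        self.datasets = {}  # maps dataset key -> dataset

    def add_dataset(self, dataset, key):
        self.datasets[key] = dataset  # later looked up by fit(dataset_key=key)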
Start the worker:
python start_worker.py --host 172.16.5.45 --port 8777 --id alice
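The same script can host additional participants, one per port; for example, a second worker could be launched like this (the id "bob" and port 8778 are illustrative):
python start_worker.py --host 172.16.5.45 --port 8778 --id bob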
Client code: worker-client.py
import inspect
import start_worker

# Print the worker-side main() for reference
print(inspect.getsource(start_worker.main))
# Dependencies
import torch as th
import torch.nn.functional as F
from torch import nn
use_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")
import syft as sy
from syft import workers
hook = sy.TorchHook(th) # hook torch as always :)
class Net(th.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Instantiate the model
model = Net()
# The data itself doesn't matter as long as the shape is right
mock_data = th.zeros(1, 2)
# Create a jit version of the model
traced_model = th.jit.trace(model, mock_data)
print(type(traced_model))
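# Optional sanity check (not in the original tutorial): the traced module is
# callable just like the plain model and yields a (1, 1) output here.
assert traced_model(mock_data).shape == (1, 1)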
# Loss function
@th.jit.script
def loss_fn(target, pred):
    return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()

print(type(loss_fn))
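# Quick local check of the scripted loss (illustrative): the MSE over two
# samples is ((1.0 - 0.8)**2 + (0.0 - 0.2)**2) / 2 = 0.04.
print(loss_fn(target=th.tensor([[1.0], [0.0]]), pred=th.tensor([[0.8], [0.2]])))  # tensor(0.0400)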
optimizer = "SGD"
batch_size = 4
optimizer_args = {"lr": 0.1, "weight_decay": 0.01}
epochs = 1
max_nr_batches = -1 # not used in this example
shuffle = True
train_config = sy.TrainConfig(
    model=traced_model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    batch_size=batch_size,
    optimizer_args=optimizer_args,
    epochs=epochs,
    shuffle=shuffle,
)
kwargs_websocket = {"host": "172.16.5.45", "hook": hook, "verbose": False}
alice = workers.websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
# Send train config
train_config.send(alice)
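# Sending the TrainConfig serializes the traced model and the scripted loss
# function and ships them to alice; every fit() call below then runs the
# training on alice's side, against alice's local "xor" dataset.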
# Setup toy data (xor example)
data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)
print("\nEvaluation before training")
pred = model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))
for epoch in range(10):
    loss = alice.fit(dataset_key="xor")  # ask alice to train using the "xor" dataset
    print("-" * 50)
    print("Iteration %s: alice's loss: %s" % (epoch, loss))
new_model = train_config.model_ptr.get()
print("\nEvaluation after training:")
pred = new_model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))
Run:
python worker-client.py
Output:
Evaluation before training
Loss: 0.4933376908302307
Target: tensor([[1.],
[1.],
[0.],
[0.]])
Pred: tensor([[ 0.1258],
[-0.0994],
[ 0.0033],
[ 0.0210]], grad_fn=<AddmmBackward>)
--------------------------------------------------
Iteration 0: alice's loss: tensor(0.4933, requires_grad=True)
--------------------------------------------------
Iteration 1: alice's loss: tensor(0.3484, requires_grad=True)
--------------------------------------------------
Iteration 2: alice's loss: tensor(0.2858, requires_grad=True)
--------------------------------------------------
Iteration 3: alice's loss: tensor(0.2626, requires_grad=True)
--------------------------------------------------
Iteration 4: alice's loss: tensor(0.2529, requires_grad=True)
--------------------------------------------------
Iteration 5: alice's loss: tensor(0.2474, requires_grad=True)
--------------------------------------------------
Iteration 6: alice's loss: tensor(0.2441, requires_grad=True)
--------------------------------------------------
Iteration 7: alice's loss: tensor(0.2412, requires_grad=True)
--------------------------------------------------
Iteration 8: alice's loss: tensor(0.2388, requires_grad=True)
--------------------------------------------------
Iteration 9: alice's loss: tensor(0.2368, requires_grad=True)
Evaluation after training:
Loss: 0.23491761088371277
Target: tensor([[1.],
[1.],
[0.],
[0.]])
Pred: tensor([[0.6553],
[0.3781],
[0.4834],
[0.4477]], grad_fn=<DifferentiableGraphBackward>)