guog算法笔记

V1

2023/04/06 · 阅读：12 · 主题：默认主题

# 自编码器（AE）+ 联邦学习（FedAvg）图像训练

import copy

import syft as sy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

class AE(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images (MNIST here).

    Spatial sizes for a 1x28x28 input:
        encoder: 1x28x28 -> 16x14x14 -> 32x7x7 -> 64x1x1 (code, no activation)
        decoder: 64x1x1 -> 32x7x7 -> 16x14x14 -> 1x28x28
    The closing Sigmoid keeps every reconstructed pixel in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Two stride-2 convolutions halve the image twice (28 -> 14 -> 7);
        # the final 7x7 kernel collapses the 7x7 map into a 1x1 code vector.
        encoder_layers = [
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=7),
        ]
        # Mirror image of the encoder; output_padding=1 makes each stride-2
        # transposed convolution exactly double the size (7 -> 14 -> 28).
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck code and decode it back to an image."""
        return self.decoder(self.encoder(x))

``# 训练函数def train(model_ptr, optimizer, criterion, data_loader, device):    model_ptr.train()    for batch_idx, (data, _) in enumerate(data_loader):        # 发送数据到客户端        data = data.send(model_ptr.location)        target = data.clone().detach()        # 在客户端上进行训练        optimizer.zero_grad()        output = model_ptr(data)        loss = criterion(output, target)        loss.backward()        optimizer.step()        # 获取客户端权重并加权平均        model_ptr.weight.data = model_ptr.weight.data.get() + model_ptr.weight.grad.data        model_ptr.weight.grad.data.zero_()    # 将客户端权重加权平均    model_ptr.weight.data /= len(data_loader)``

``# 测试函数def test(model_ptr, data_loader, device):    model_ptr.eval()    test_loss = 0    with torch.no_grad():        for data, _ in data_loader:            # 发送数据到客户端            data = data.send(model_ptr.location)            target = data.clone().detach()            # 使用联邦学习的模型进行图像重建            output = model_ptr(data)            test_loss += F.mse_loss(output.get(), target, reduction='sum').item()    # 计算平均测试损失    test_loss /= len(data_loader.dataset)    return test_loss``

``# 创建虚拟工人hook = sy.TorchHook(torch)workers = [sy.VirtualWorker(hook, id="worker{}".format(i)) for i in range(3)]# 将数据分配给不同的客户端transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_data = MNIST(root='./data', train=True, download=True, transform=transform)federated_train_loader = sy.FederatedDataLoader(train_data.federate(workers), batch_size=64, shuffle=True, num_workers=0, drop_last=True)``

``# 初始化模型指针model = AE().to(device)model_ptr = model.send(workers[0])# 设置超参数criterion = nn.MSELoss()learning_rate = 0.01optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型num_epochs = 10for epoch in range(num_epochs):    train(model_ptr, optimizer, criterion, federated_train_loader, device)    test_loss = test(model_ptr, federated_train_loader, device)    print('Epoch [{}/{}], Test Loss: {:.4f}'.format(epoch+1, num_epochs, test_loss))# 获取加权平均模型并在本地进行测试avg_model_ptr = model_ptr.copy().move(workers[0])avg_model_ptr.weight.data = torch.zeros_like(avg_model_ptr.weight.data)avg_model_ptr.weight.requires_grad = Falsefor ptr in model_ptr.pointers():    avg_model_ptr.weight.data += ptr.weight.data / len(workers)test_loss = test(avg_model_ptr, federated_train_loader, device)print('Final Test Loss: {:.4f}'.format(test_loss))``

V1