AIJack

Support MPI

Status: Open — Koukyosyumei opened this issue 2 years ago · 3 comments

Koukyosyumei commented on Jun 8, 2022 at 13:06:


class Client:
    """Federated-learning client that trains locally and ships parameter deltas over MPI.

    One round (see ``action``):
      1. receive the global parameters from the server (``download_parameters``),
      2. run one pass over ``dataloader`` with ``optimizer`` (``train``),
      3. send ``trained - received`` parameters back to the server, either
         dense (``send_gradient``) or top-k sparsified (``send_sparse_gradient``).

    Args:
        comm: MPI communicator used for ``send``/``recv``.
        model: ``torch.nn.Module`` trained locally. Its output is fed to
            ``F.nll_loss``, so it is expected to return log-probabilities.
        optimizer: optimizer over ``model``'s parameters.
        myid: MPI rank of this client.
        dataloader: iterable of ``(data, target)`` batches.
        sparse: if True, send only the top-k entries (by magnitude) of each delta.
        k: fraction of entries kept per parameter tensor when ``sparse`` is True.
        device: device the received parameters are moved to.
    """

    def __init__(self, comm, model, optimizer, myid, dataloader,
                 sparse=False, k=0.3, device="cpu"):
        self.comm = comm
        self.model = model
        self.optimizer = optimizer
        self.myid = myid
        self.dataloader = dataloader
        self.device = device
        self.sparse = sparse
        self.k = k

        self.round = 1

    def action(self):
        """Run one federated round: download, train locally, upload the delta."""
        self.download_parameters()
        self.train()
        if self.sparse:
            self.send_sparse_gradient()
        else:
            self.send_gradient()
        self.model.zero_grad()
        self.round += 1

    def _parameter_deltas(self):
        """Return the flattened ``current - prev`` delta of every parameter.

        The result is detached from autograd. ``prev_parameters`` is the
        snapshot taken at the start of ``train``.
        """
        with torch.no_grad():
            return [
                param.detach().reshape(-1) - prev.reshape(-1)
                for param, prev in zip(self.model.parameters(), self.prev_parameters)
            ]

    @staticmethod
    def _abort_on_nan(delta):
        """Abort the whole MPI job (``SEND_NAN_CODE``) if ``delta`` contains NaN."""
        if torch.isnan(delta).any():
            print("catch nan within the gradients")
            MPI.COMM_WORLD.Abort(SEND_NAN_CODE)

    def _num_topk(self, numel):
        """Return how many entries of a ``numel``-sized tensor to keep in sparse mode.

        ``int(numel * k)`` truncates to 0 for tensors with fewer than ``1/k``
        entries (e.g. a small bias). Such tensors would then never receive an
        update, so at least one entry is kept for any non-empty tensor. The
        count is also capped at ``numel`` so that ``k > 1`` cannot make
        ``torch.topk`` fail.
        """
        if numel == 0:
            return 0
        return max(1, min(numel, int(numel * self.k)))

    def send_gradient(self, destination_id=0):
        """Send the dense per-parameter deltas to ``destination_id``.

        The payload is one flat list of floats per parameter, in
        ``model.parameters()`` order, tagged ``GRADIENTS_TAG``. The MPI job is
        aborted if any delta contains NaN.
        """
        self.gradients = []
        for delta in self._parameter_deltas():
            self._abort_on_nan(delta)
            self.gradients.append(delta.tolist())
        self.comm.send(self.gradients, dest=destination_id, tag=GRADIENTS_TAG)

    def send_sparse_gradient(self, destination_id=0):
        """Send the top-k (by magnitude) entries of each delta to ``destination_id``.

        The payload is ``[values, indices]``. Each is a per-parameter list of
        flat lists; the indices refer to positions in the flattened parameter.
        It is tagged ``SPARSEGRADIENTS_TAG``. The MPI job is aborted if any
        delta contains NaN.
        """
        self.sparse_gradients = []
        self.sparse_indices = []
        for delta in self._parameter_deltas():
            self._abort_on_nan(delta)
            topk_indices = torch.topk(delta.abs(), k=self._num_topk(delta.numel())).indices
            self.sparse_gradients.append(delta[topk_indices].tolist())
            self.sparse_indices.append(topk_indices.tolist())
        data = [self.sparse_gradients, self.sparse_indices]
        self.comm.send(data, dest=destination_id, tag=SPARSEGRADIENTS_TAG)

    def download_parameters(self):
        """Receive the global parameters (``PARAMETERS_TAG``) and load them into ``model``.

        Each received flat list is reshaped to the parameter's shape. It is
        converted with the parameter's own dtype, rather than always float32.
        """
        new_parameters = self.comm.recv(tag=PARAMETERS_TAG)
        for params, new_params in zip(self.model.parameters(), new_parameters):
            params.data = (
                torch.as_tensor(new_params, dtype=params.dtype)
                .reshape(params.shape)
                .to(self.device)
            )

    def train(self):
        """Snapshot the current parameters, then train one pass over ``dataloader``.

        The snapshot, stored in ``prev_parameters``, is the baseline the
        uploaded delta is computed against.
        """
        # Detached clones: the snapshot only serves as a subtraction baseline
        # and must not take part in autograd.
        self.prev_parameters = [param.detach().clone() for param in self.model.parameters()]

        # Ensure dropout/batch-norm layers are in training mode even if the
        # model was previously switched to eval mode.
        self.model.train()
        for data, target in self.dataloader:
            self.optimizer.zero_grad()
            data = data.to(self.device)
            target = target.to(self.device)
            output = self.model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            self.optimizer.step()

Koukyosyumei commented on Jun 8, 2022 at 13:06:


class Server:
    """Federated-learning server that broadcasts the global model and averages client deltas.

    One round (see ``action``):
      1. send the current global parameters to every client,
      2. collect one delta per client (dense or sparse, depending on ``sparse``),
      3. average the deltas uniformly (``1 / num_clients`` each) and add the
         average to the global parameters,
      4. evaluate the updated model on ``dataloader`` and print the result.

    Args:
        comm: MPI communicator used for ``send``/``recv``.
        model: the global ``torch.nn.Module``. Its output is fed to
            ``F.nll_loss``, so it is expected to return log-probabilities.
        myid: MPI rank of the server.
        client_ids: ranks of the participating clients.
        dataloader: evaluation loader of ``(data, target)`` batches. It must
            expose ``.dataset`` for the sample count.
        sparse: if True, expect top-k sparse deltas (``[values, indices]``).
        device: device the received deltas are placed on.
    """

    def __init__(self, comm, model, myid, client_ids, dataloader,
                 sparse=False, device="cpu"):
        self.comm = comm
        self.model = model
        self.myid = myid
        self.client_ids = client_ids
        self.dataloader = dataloader
        self.device = device

        self.round = 1
        self.sparse = sparse
        self.num_clients = len(client_ids)

    def send_parameters(self):
        """Send every parameter, flattened to a Python list, to each client (``PARAMETERS_TAG``)."""
        global_parameters = [
            params.detach().reshape(-1).tolist() for params in self.model.parameters()
        ]
        for client_id in self.client_ids:
            self.comm.send(global_parameters, dest=client_id, tag=PARAMETERS_TAG)

    def action(self):
        """Run one federated round: broadcast, gather, aggregate, update, evaluate."""
        self.send_parameters()
        if self.sparse:
            self._gather_sparse_gradients()
        else:
            self._gather_gradients()
        self._aggregate()
        self._update_parameters()
        self._evaluate_global_model()
        self.round += 1

    @staticmethod
    def _abort_on_nan(tensor):
        """Abort the whole MPI job (``RECEIVE_NAN_CODE``) if ``tensor`` contains NaN."""
        if torch.isnan(tensor).any():
            print("the received gradients contains nan")
            MPI.COMM_WORLD.Abort(RECEIVE_NAN_CODE)

    def _gather_gradients(self):
        """Receive ``num_clients`` dense deltas (``GRADIENTS_TAG``) into ``received_gradients``.

        Each delta is a list of flat lists, one per parameter. Each is
        reshaped to the parameter's shape and converted with its dtype.
        """
        self.received_gradients = []

        while len(self.received_gradients) < self.num_clients:
            gradients_flattend = self.comm.recv(tag=GRADIENTS_TAG)
            gradients_reshaped = []
            for params, grad in zip(self.model.parameters(), gradients_flattend):
                tensor = (
                    torch.as_tensor(grad, dtype=params.dtype, device=self.device)
                    .reshape(params.shape)
                )
                self._abort_on_nan(tensor)
                gradients_reshaped.append(tensor)
            self.received_gradients.append(gradients_reshaped)

    def _gather_sparse_gradients(self):
        """Receive ``num_clients`` sparse deltas (``SPARSEGRADIENTS_TAG``) and densify them.

        Each message is ``[values, indices]``. These are per-parameter lists
        whose indices refer to the flattened parameter. Entries that were not
        sent are zero.
        """
        self.received_gradients = []

        while len(self.received_gradients) < self.num_clients:
            sparse_gradients_flattend, sparse_indices = self.comm.recv(tag=SPARSEGRADIENTS_TAG)
            gradients_reshaped = []
            for params, grad, idx in zip(self.model.parameters(), sparse_gradients_flattend, sparse_indices):
                dense = torch.zeros_like(params).reshape(-1)
                # Explicit long index: an empty Python list would otherwise be
                # inferred as a float tensor.
                index = torch.as_tensor(idx, dtype=torch.long, device=dense.device)
                dense[index] = torch.as_tensor(grad, dtype=dense.dtype, device=dense.device)
                self._abort_on_nan(dense)
                gradients_reshaped.append(dense.reshape(params.shape))
            self.received_gradients.append(gradients_reshaped)

    def _aggregate(self):
        """Uniformly average ``received_gradients`` into ``aggregated_gradients``.

        Each client delta is weighted by ``1 / num_clients``.
        """
        self.aggregated_gradients = [torch.zeros_like(params) for params in self.model.parameters()]

        for gradients in self.received_gradients:
            for total, gradient in zip(self.aggregated_gradients, gradients):
                # In-place add: ``total`` is the very tensor stored in the list.
                total += (1 / self.num_clients) * gradient

    def _update_parameters(self):
        """Add the averaged delta to the global parameters in place."""
        with torch.no_grad():
            for params, grads in zip(self.model.parameters(), self.aggregated_gradients):
                params += grads

    def _evaluate_global_model(self):
        """Evaluate the global model on ``dataloader``, print the result and return it.

        Returns:
            ``(average_nll_loss, accuracy_percent)``. Both are NaN when the
            evaluation dataset is empty, instead of raising ZeroDivisionError.

        The model's previous train/eval mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        test_loss = 0.0
        correct = 0
        try:
            with torch.no_grad():
                for data, target in self.dataloader:
                    data, target = data.to(self.device), target.to(self.device)
                    output = self.model(data)
                    test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
                    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                    correct += pred.eq(target.view_as(pred)).sum().item()
        finally:
            self.model.train(was_training)

        num_samples = len(self.dataloader.dataset)
        if num_samples == 0:
            test_loss = accuracy = float("nan")
        else:
            test_loss /= num_samples
            accuracy = 100. * correct / num_samples
        print(f"Round: {self.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}")
        return test_loss, accuracy

Koukyosyumei commented on Jun 8, 2022 at 13:06:


def fedavg():
    """Run federated averaging over MPI: rank 0 acts as the server, every other rank as a client.

    Uses module-level settings (``seed``, ``lr``, ``sparse``, ``k``,
    ``num_rounds``) and helpers (``fix_seed``, ``Net``, ``prepare_dataloader``)
    defined elsewhere. After ``num_rounds`` rounds, the largest per-rank
    wall-clock time is reduced onto rank 0 and printed there.
    """
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()
    is_server = rank == 0
    num_clients = world_size - 1

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    print(f"myid={rank} constructs a model with device={device}")

    # Both roles expose the same per-round ``action()`` entry point.
    if is_server:
        loader = prepare_dataloader(num_clients, rank, train=False)
        client_ids = list(range(1, world_size))
        print("clients = ", client_ids)
        participant = Server(comm, model, rank, client_ids, loader,
                             sparse=sparse, device=device)
    else:
        loader = prepare_dataloader(num_clients, rank, train=True)
        participant = Client(comm, model, optimizer, rank, loader,
                             sparse=sparse, k=k, device=device)

    start = MPI.Wtime()
    for _ in range(num_rounds):
        participant.action()
        # Every rank waits here before starting the next round.
        comm.Barrier()
    elapsed = MPI.Wtime() - start

    local_time = np.array([elapsed], dtype='float64')
    slowest_time = np.zeros(1, dtype='float64')
    comm.Reduce(local_time, slowest_time, op=MPI.MAX, root=0)
    if is_server:
        print("  execution time = : ", slowest_time[0], "  [sec.] \n")

Koukyosyumei commented on Jun 8, 2022 at 13:06: