你的位置:首页 > 信息动态 > 新闻中心
信息动态
联系我们

implemet FL(fedavg) for handwriting recognition

2022/9/14 0:45:06

sample.py

import numpy as np
from torchvision import datasets

def noniid(dataset,num_users):
    '''
    对MNIST数据集进行noniid划分
    思路:给labels排个序,然后在其中截取随机的两段(每段300个sample),近似noniid
    '''
    
    #变量声明
    a,b = 200,300   # 2*b是一个client拥有的样本数,a是为了设计非独立同分布引入的变量
    #MNIST->60000张图,一个client分600张,client数设为100
    list_a = [i for i in range(a)]
    dict_clients = {i:np.array([],dtype = 'int64') for i in range(num_users)}
    idx = [i for i in range(a*b)]
    labels = dataset.train_labels.numpy()#样本标签
    
    #标签排序
    tmp = np.vstack((idx,labels))#堆积成2行60000列的array
    tmp = tmp[:,tmp[1,:].argsort()]#按labels从小到大进行排序(对应索引也跟着排序)
    idx = tmp[0,:]
    
    #为每个client分配两段不同分布的数据
    for i in range(num_users):
        randset = set(np.random.choice(list_a,2,replace = False))#在1-200中随机抽两个不重复的数构成一个集合
        list_a = list(set(list_a)-randset)#去掉重复的,以防止clients被分配相同的sample
        for rand in randset:
            dict_clients[i] = np.concatenate((dict_clients[i],idx[rand*b,(rand+1)*b]),axis = 0)#一次分300个sample给client[i],分两次
    
    return dict_clients #将字典{客户编号:分配的样本集合}返回