



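Note that the Reader class does not live in the federatedml library; it is a component of the separate pipeline library. Looking through the source, Reader inherits from the Output class, and Output takes a data_type keyword argument: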

class Output(object):
    def __init__(self, name, data_type='single', has_data=True, has_model=True, has_cache=False, output_unit=1):
        if has_model:
            self.model = Model(name).model
            self.model_output = Model(name).get_all_output()

        if has_data:
            if data_type == "single":
                self.data = SingleOutputData(name).data
                self.data_output = SingleOutputData(name).get_all_output()
            elif data_type == "multi":
                self.data = TraditionalMultiOutputData(name)
                self.data_output = TraditionalMultiOutputData(name).get_all_output()
            else:
                self.data = NoLimitOutputData(name, output_unit)
                self.data_output = NoLimitOutputData(name, output_unit).get_all_output()

        if has_cache:
            self.cache = Cache(name).cache
            self.cache_output = Cache(name).get_all_output()

The three corresponding data-type classes merely partition the output data by name; none of them contains any step related to batching:

class SingleOutputData(object):
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def data(self):
        return ".".join([self.prefix, IODataType.SINGLE])

    @staticmethod
    def get_all_output():
        return ["data"]

class TraditionalMultiOutputData(object):
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def train_data(self):
        return ".".join([self.prefix, IODataType.TRAIN])

    @property
    def test_data(self):
        return ".".join([self.prefix, IODataType.TEST])

    @property
    def validate_data(self):
        return ".".join([self.prefix, IODataType.VALIDATE])

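    @staticmethod
    def get_all_output():
        # The list is cut off in the excerpt; the last two entries are inferred
        # from the validate_data and test_data accessors of this class.
        return [IODataType.TRAIN,
                IODataType.VALIDATE,
                IODataType.TEST]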

class NoLimitOutputData(object):
    def __init__(self, prefix, output_unit=1):
        self.prefix = prefix
        self.output_unit = output_unit

    @property
    def data(self):
        return [self.prefix + "." + "data_" + str(i) for i in range(self.output_unit)]

    def get_all_output(self):
        return ["data_" + str(i) for i in range(self.output_unit)]
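To make the point concrete, here is a minimal, self-contained sketch of the naming logic in these classes. The string values of IODataType are assumed for illustration (the excerpt does not show them), and the component name reader_0 is hypothetical. It only demonstrates that these classes build dotted output names and never touch, let alone batch, the data itself.

```python
class IODataType:
    # Assumed values for illustration; the real constants are not shown above.
    SINGLE = "data"


class SingleOutputData:
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def data(self):
        # Just a dotted name: "<component>.<output key>"
        return ".".join([self.prefix, IODataType.SINGLE])


class NoLimitOutputData:
    def __init__(self, prefix, output_unit=1):
        self.prefix = prefix
        self.output_unit = output_unit

    @property
    def data(self):
        # One name per output unit; no rows are read or split here.
        return [self.prefix + "." + "data_" + str(i) for i in range(self.output_unit)]


# "reader_0" is a hypothetical component name.
single_name = SingleOutputData("reader_0").data
multi_names = NoLimitOutputData("reader_0", output_unit=3).data
print(single_name)
print(multi_names)
```

Both results are plain strings (or lists of strings) that identify an output slot, which is why nothing batch-related shows up at this layer.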



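Eventually I found something called a job submitter, which also runs a Task by passing parameters and calling a service. All of these are just wrappers; none of them contains the actual implementation.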

Finally, in federatedml.nn.homo.trainer.fedavg_trainer I found FedAVGTrainer. Its docstring lists the parameters, and batch_size is among them:

class FedAVGTrainer(TrainerBase):
    """
    epochs: int >0, epochs to train
    batch_size: int, -1 means full batch
    secure_aggregate: bool, default is True, whether to use secure aggregation. if enabled, will add random number
                            mask to local models. These random number masks will eventually cancel out to get 0.
    weighted_aggregation: bool, whether add weight to each local model when doing aggregation.
                         if True, According to origin paper, weight of a client is: n_local / n_global, where n_local
                         is the sample number locally and n_global is the sample number of all clients.
                         if False, simply averaging these models.

    early_stop: None, 'diff' or 'abs'. if None, disable early stop; if 'diff', use the loss difference between
                two epochs as early stop condition, if differences < tol, stop training ; if 'abs', if loss < tol,
                stop training
    tol: float, tol value for early stop

    aggregate_every_n_epoch: None or int. if None, aggregate model on the end of every epoch, if int, aggregate
                             every n epochs.
    cuda: bool, use cuda or not
    pin_memory: bool, for pytorch DataLoader
    shuffle: bool, for pytorch DataLoader
    data_loader_worker: int, for pytorch DataLoader, number of workers when loading data
    validation_freqs: None or int. if int, validate your model and send validate results to fate-board every n epoch.
                      if is binary classification task, will use metrics 'auc', 'ks', 'gain', 'lift', 'precision'
                      if is multi classification task, will use metrics 'precision', 'recall', 'accuracy'
                      if is regression task, will use metrics 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'
    checkpoint_save_freqs: save model every n epoch, if None, will not save checkpoint.
    task_type: str, 'auto', 'binary', 'multi', 'regression'
               this option decides the return format of this trainer, and the evaluation type when running validation.
               if auto, will automatically infer your task type from labels and predict results.
    """
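The batch_size convention in this docstring can be illustrated with a short sketch. This is not FATE's code: resolve_batch_size and iter_batches are hypothetical helpers, written in plain Python, that only show one way "-1 means full batch" could be interpreted before a local dataset is split into mini-batches.

```python
def resolve_batch_size(batch_size, n_samples):
    """Map the docstring convention (-1 = full batch) to a concrete size."""
    if batch_size == -1:
        return n_samples
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive int or -1")
    return min(batch_size, n_samples)


def iter_batches(samples, batch_size):
    """Yield consecutive slices of `samples`, each at most the resolved size."""
    if not samples:
        return  # nothing to batch
    size = resolve_batch_size(batch_size, len(samples))
    for start in range(0, len(samples), size):
        yield samples[start:start + size]


samples = list(range(10))
full = list(iter_batches(samples, -1))  # a single batch holding every sample
mini = list(iter_batches(samples, 4))   # batches of 4, 4 and 2 samples
print(len(full), [len(b) for b in mini])
```

In the real trainer the batching is presumably delegated to a PyTorch DataLoader, as the pin_memory, shuffle and data_loader_worker parameters suggest; the sketch only mirrors the size convention.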

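The weighted_aggregation rule described in the docstring (client weight = n_local / n_global, otherwise a simple average) can be sketched as follows. The aggregate function and the flat-list model representation are assumptions made for illustration; secure aggregation and the actual tensor handling are left out.

```python
def aggregate(local_models, sample_counts, weighted=True):
    """Average per-client parameter lists.

    weighted=True  -> each client is weighted by n_local / n_global
    weighted=False -> every client gets the same weight
    """
    n_clients = len(local_models)
    if weighted:
        n_global = sum(sample_counts)
        weights = [n_local / n_global for n_local in sample_counts]
    else:
        weights = [1.0 / n_clients] * n_clients

    n_params = len(local_models[0])
    return [
        sum(w * model[i] for w, model in zip(weights, local_models))
        for i in range(n_params)
    ]


# Two hypothetical clients holding 1 and 3 samples respectively.
models = [[1.0, 2.0], [3.0, 6.0]]
counts = [1, 3]
weighted_avg = aggregate(models, counts, weighted=True)
simple_avg = aggregate(models, counts, weighted=False)
print(weighted_avg, simple_avg)
```

With weighting enabled, the client that holds more samples pulls the aggregated parameters toward its own values; with it disabled, both clients contribute equally.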

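The two early_stop modes in the docstring can also be written out as a small check. should_stop is a hypothetical helper, not FATE's implementation, and taking the absolute value of the loss difference in 'diff' mode is an assumption; the docstring only says "differences < tol".

```python
def should_stop(early_stop, tol, prev_loss, curr_loss):
    """Return True when training should stop under the given early_stop mode."""
    if early_stop is None:
        return False  # early stopping disabled
    if early_stop == "diff":
        if prev_loss is None:
            return False  # first epoch: nothing to compare against yet
        # Assumption: compare the absolute change between two epochs.
        return abs(prev_loss - curr_loss) < tol
    if early_stop == "abs":
        return curr_loss < tol
    raise ValueError("early_stop must be None, 'diff' or 'abs'")


stop_diff = should_stop("diff", tol=0.01, prev_loss=0.500, curr_loss=0.495)
stop_abs = should_stop("abs", tol=0.01, prev_loss=0.500, curr_loss=0.495)
print(stop_diff, stop_abs)
```

Here the loss has plateaued but is still far from zero, so 'diff' triggers a stop while 'abs' does not.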

