直接自己写

model是要正则的模型,reg_type选择'l1'还是l2,coef是系数。

def regularization(model:nn.Module, reg_type,coef):
    int_type=int(reg_type[1])
    reg_loss = 0
    for module in model.modules():
        for param in module.parameters():
            reg_loss+=torch.norm(param,int_type)
        
    return reg_loss*coef

代码是一个小例子,对哪个module进行正则,这都可以自己挑选,不必对每一个module都正则。

优化器中添加

一个是Adam或者AdamW优化器里面有weight_decay参数,那个是l2的正则系数


Yonggie
95 声望4 粉丝