import torch import torch.optim as optim class OptimizerSelector: def __init__(self): self.available_optimizers = { 'sgd': '随机梯度下降', 'adam': '自适应矩估计', 'rmsprop': '均方根传播', 'adagrad': '自适应梯度', 'adamw': 'Adam权重衰减版' } def get_optimizer(self, model_params, name='sgd', **kwargs): """获取配置好的优化器实例 参数: model_params: 模型参数(可通过model.parameters()获取) name: 优化器类型 **kwargs: 优化器专属参数(lr, weight_decay等) """ if name == 'sgd': return optim.SGD(model_params, **kwargs) elif name == 'adam': return optim.Adam(model_params, **kwargs) elif name == 'rmsprop': return optim.RMSprop(model_params, **kwargs) elif name == 'adagrad': return optim.Adagrad(model_params, **kwargs) elif name == 'adamw': return optim.AdamW(model_params, **kwargs) else: raise ValueError(f"不支持的优化器,可选: {list(self.available_optimizers.keys())}") def print_available(self): """打印支持的优化器列表""" print("可用优化器:") for k, v in self.available_optimizers.items(): print(f"{k.ljust(10)} -> {v}") # 使用示例 if __name__ == "__main__": model = torch.nn.Linear(10, 2) # 示例模型 selector = OptimizerSelector() optimizer = selector.get_optimizer( model.parameters(), name='adamw', lr=0.001, weight_decay=0.01 )