53 lines
1.6 KiB
Python
53 lines
1.6 KiB
Python
|
|
import torch
|
||
|
|
import torch.optim as optim
|
||
|
|
|
||
|
|
|
||
|
|
class OptimizerSelector:
|
||
|
|
def __init__(self):
|
||
|
|
self.available_optimizers = {
|
||
|
|
'sgd': '随机梯度下降',
|
||
|
|
'adam': '自适应矩估计',
|
||
|
|
'rmsprop': '均方根传播',
|
||
|
|
'adagrad': '自适应梯度',
|
||
|
|
'adamw': 'Adam权重衰减版'
|
||
|
|
}
|
||
|
|
|
||
|
|
def get_optimizer(self, model_params, name='sgd', **kwargs):
|
||
|
|
"""获取配置好的优化器实例
|
||
|
|
|
||
|
|
参数:
|
||
|
|
model_params: 模型参数(可通过model.parameters()获取)
|
||
|
|
name: 优化器类型
|
||
|
|
**kwargs: 优化器专属参数(lr, weight_decay等)
|
||
|
|
"""
|
||
|
|
if name == 'sgd':
|
||
|
|
return optim.SGD(model_params, **kwargs)
|
||
|
|
elif name == 'adam':
|
||
|
|
return optim.Adam(model_params, **kwargs)
|
||
|
|
elif name == 'rmsprop':
|
||
|
|
return optim.RMSprop(model_params, **kwargs)
|
||
|
|
elif name == 'adagrad':
|
||
|
|
return optim.Adagrad(model_params, **kwargs)
|
||
|
|
elif name == 'adamw':
|
||
|
|
return optim.AdamW(model_params, **kwargs)
|
||
|
|
else:
|
||
|
|
raise ValueError(f"不支持的优化器,可选: {list(self.available_optimizers.keys())}")
|
||
|
|
|
||
|
|
def print_available(self):
|
||
|
|
"""打印支持的优化器列表"""
|
||
|
|
print("可用优化器:")
|
||
|
|
for k, v in self.available_optimizers.items():
|
||
|
|
print(f"{k.ljust(10)} -> {v}")
|
||
|
|
|
||
|
|
|
||
|
|
# 使用示例
|
||
|
|
if __name__ == "__main__":
|
||
|
|
model = torch.nn.Linear(10, 2) # 示例模型
|
||
|
|
selector = OptimizerSelector()
|
||
|
|
optimizer = selector.get_optimizer(
|
||
|
|
model.parameters(),
|
||
|
|
name='adamw',
|
||
|
|
lr=0.001,
|
||
|
|
weight_decay=0.01
|
||
|
|
)
|