Switch git user and redo the project's initial archive commit
52  FC_ML_Optim_Function/Optimizer_Selector.py  Normal file
@@ -0,0 +1,52 @@
import torch
import torch.optim as optim


class OptimizerSelector:
    """Small factory that maps a string name to a configured torch optimizer."""

    def __init__(self):
        self.available_optimizers = {
            'sgd': 'stochastic gradient descent',
            'adam': 'adaptive moment estimation',
            'rmsprop': 'root mean square propagation',
            'adagrad': 'adaptive gradient',
            'adamw': 'Adam with decoupled weight decay'
        }

    def get_optimizer(self, model_params, name='sgd', **kwargs):
        """Return a configured optimizer instance.

        Args:
            model_params: model parameters (e.g. obtained via model.parameters())
            name: optimizer type
            **kwargs: optimizer-specific options (lr, weight_decay, etc.)
        """
        if name == 'sgd':
            return optim.SGD(model_params, **kwargs)
        elif name == 'adam':
            return optim.Adam(model_params, **kwargs)
        elif name == 'rmsprop':
            return optim.RMSprop(model_params, **kwargs)
        elif name == 'adagrad':
            return optim.Adagrad(model_params, **kwargs)
        elif name == 'adamw':
            return optim.AdamW(model_params, **kwargs)
        else:
            raise ValueError(f"Unsupported optimizer; choose from: {list(self.available_optimizers.keys())}")

    def print_available(self):
        """Print the list of supported optimizers."""
        print("Available optimizers:")
        for k, v in self.available_optimizers.items():
            print(f"{k.ljust(10)} -> {v}")


# Usage example
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)  # toy model
    selector = OptimizerSelector()
    optimizer = selector.get_optimizer(
        model.parameters(),
        name='adamw',
        lr=0.001,
        weight_decay=0.01
    )
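For context, here is a minimal sketch of how an optimizer returned by OptimizerSelector plugs into a standard PyTorch training step. The toy data, the MSELoss criterion, and the FC_ML_Optim_Function.Optimizer_Selector import path are illustrative assumptions, not part of the commit.

import torch
from FC_ML_Optim_Function.Optimizer_Selector import OptimizerSelector  # assumed package path

model = torch.nn.Linear(10, 2)
optimizer = OptimizerSelector().get_optimizer(
    model.parameters(), name='adamw', lr=0.001, weight_decay=0.01
)
criterion = torch.nn.MSELoss()  # illustrative loss choice

x = torch.randn(32, 10)  # dummy input batch
y = torch.randn(32, 2)   # dummy targets

for epoch in range(5):
    optimizer.zero_grad()          # clear gradients from the previous step
    loss = criterion(model(x), y)  # forward pass and loss
    loss.backward()                # backpropagation
    optimizer.step()               # apply the AdamW update

As a design note, the if/elif chain in get_optimizer could equally be a name-to-class dict (e.g. {'sgd': optim.SGD, ...}) so that adding an optimizer only touches one mapping; the chain as committed behaves identically.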
0  FC_ML_Optim_Function/__init__.py  Normal file