Revert "新建仓库维护数据预测项目"

This reverts commit 516126d2a5.
This commit is contained in:
2025-10-17 15:08:09 +08:00
parent 516126d2a5
commit c07dbb586d
72 changed files with 0 additions and 82332 deletions

View File

@@ -1,52 +0,0 @@
import torch
import torch.optim as optim
class OptimizerSelector:
def __init__(self):
self.available_optimizers = {
'sgd': '随机梯度下降',
'adam': '自适应矩估计',
'rmsprop': '均方根传播',
'adagrad': '自适应梯度',
'adamw': 'Adam权重衰减版'
}
def get_optimizer(self, model_params, name='sgd', **kwargs):
"""获取配置好的优化器实例
参数:
model_params: 模型参数(可通过model.parameters()获取)
name: 优化器类型
**kwargs: 优化器专属参数(lr, weight_decay等)
"""
if name == 'sgd':
return optim.SGD(model_params, **kwargs)
elif name == 'adam':
return optim.Adam(model_params, **kwargs)
elif name == 'rmsprop':
return optim.RMSprop(model_params, **kwargs)
elif name == 'adagrad':
return optim.Adagrad(model_params, **kwargs)
elif name == 'adamw':
return optim.AdamW(model_params, **kwargs)
else:
raise ValueError(f"不支持的优化器,可选: {list(self.available_optimizers.keys())}")
def print_available(self):
"""打印支持的优化器列表"""
print("可用优化器:")
for k, v in self.available_optimizers.items():
print(f"{k.ljust(10)} -> {v}")
# 使用示例
if __name__ == "__main__":
model = torch.nn.Linear(10, 2) # 示例模型
selector = OptimizerSelector()
optimizer = selector.get_optimizer(
model.parameters(),
name='adamw',
lr=0.001,
weight_decay=0.01
)