43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
|
||
|
|
|
||
|
|
class LossFunctionSelector:
|
||
|
|
def __init__(self):
|
||
|
|
self.available_losses = {
|
||
|
|
'mse': '均方误差',
|
||
|
|
'l1': '平均绝对误差',
|
||
|
|
'cross_entropy': '交叉熵',
|
||
|
|
'bce': '二分类交叉熵',
|
||
|
|
'smooth_l1': '平滑L1',
|
||
|
|
'kl_div': 'KL散度',
|
||
|
|
'hinge': '合页损失',
|
||
|
|
'triplet': '三元组损失'
|
||
|
|
}
|
||
|
|
|
||
|
|
def get_loss(self, name, **kwargs):
|
||
|
|
"""获取配置好的损失函数实例"""
|
||
|
|
if name == 'mse':
|
||
|
|
return nn.MSELoss(**kwargs)
|
||
|
|
elif name == 'l1':
|
||
|
|
return nn.L1Loss(**kwargs)
|
||
|
|
elif name == 'cross_entropy':
|
||
|
|
return nn.CrossEntropyLoss(**kwargs)
|
||
|
|
elif name == 'bce':
|
||
|
|
return nn.BCELoss(**kwargs)
|
||
|
|
elif name == 'smooth_l1':
|
||
|
|
return nn.SmoothL1Loss(**kwargs)
|
||
|
|
elif name == 'kl_div':
|
||
|
|
return nn.KLDivLoss(**kwargs)
|
||
|
|
elif name == 'hinge':
|
||
|
|
return nn.HingeEmbeddingLoss(**kwargs)
|
||
|
|
elif name == 'triplet':
|
||
|
|
return nn.TripletMarginLoss(**kwargs)
|
||
|
|
else:
|
||
|
|
raise ValueError(f"不支持的损失函数类型,可选: {list(self.available_losses.keys())}")
|
||
|
|
|
||
|
|
def print_available(self):
|
||
|
|
"""打印支持的损失函数列表"""
|
||
|
|
print("可用损失函数:")
|
||
|
|
for k, v in self.available_losses.items():
|
||
|
|
print(f"{k.ljust(15)} -> {v}")
|