import torch import torch.nn as nn class LossFunctionSelector: def __init__(self): self.available_losses = { 'mse': '均方误差', 'l1': '平均绝对误差', 'cross_entropy': '交叉熵', 'bce': '二分类交叉熵', 'smooth_l1': '平滑L1', 'kl_div': 'KL散度', 'hinge': '合页损失', 'triplet': '三元组损失' } def get_loss(self, name, **kwargs): """获取配置好的损失函数实例""" if name == 'mse': return nn.MSELoss(**kwargs) elif name == 'l1': return nn.L1Loss(**kwargs) elif name == 'cross_entropy': return nn.CrossEntropyLoss(**kwargs) elif name == 'bce': return nn.BCELoss(**kwargs) elif name == 'smooth_l1': return nn.SmoothL1Loss(**kwargs) elif name == 'kl_div': return nn.KLDivLoss(**kwargs) elif name == 'hinge': return nn.HingeEmbeddingLoss(**kwargs) elif name == 'triplet': return nn.TripletMarginLoss(**kwargs) else: raise ValueError(f"不支持的损失函数类型,可选: {list(self.available_losses.keys())}") def print_available(self): """打印支持的损失函数列表""" print("可用损失函数:") for k, v in self.available_losses.items(): print(f"{k.ljust(15)} -> {v}")