修复预测脚本和训练脚本的执行bug
This commit is contained in:
@@ -3,7 +3,7 @@ import torch
|
||||
|
||||
def export_model_pt(model,target,name = "model"):
|
||||
script_model = torch.jit.script(model) # 或 torch.jit.trace(model, input)
|
||||
script_model.save(target + name + ".pt")
|
||||
script_model.save(target + name + ".pth")
|
||||
#2 通用格式导出
|
||||
def export_model_onnx(model,input_tensor,target,name="model"):
|
||||
torch.onnx.export(model, input_tensor, target+ name + ".onnx")
|
||||
@@ -11,12 +11,12 @@ def export_model_onnx(model,input_tensor,target,name="model"):
|
||||
def export_model_bin(model,target,name = "weights"):
|
||||
torch.save(model.state_dict(), target + name + ".bin")
|
||||
|
||||
def export_model(model,target,file_name,name):
|
||||
def export_model(model,target,file_name,name,input_tensor):
|
||||
if name == 'bin':
|
||||
return export_model_bin(model,target,file_name)
|
||||
if name == 'onnx':
|
||||
return export_model_onnx(model,target,file_name)
|
||||
if name == 'pt':
|
||||
return export_model_bin(model,target,file_name)
|
||||
return export_model_onnx(model,input_tensor,target,file_name)
|
||||
if name == 'pth':
|
||||
return export_model_pt(model,target,file_name)
|
||||
else:
|
||||
raise ValueError(f"不支持的导出类型")
|
||||
Reference in New Issue
Block a user