修复预测脚本和训练脚本的执行bug
This commit is contained in:
@@ -182,7 +182,7 @@ def get_data_from_csv_feature(data_path,skip_rows = 100,sample_rows = 100,normal
|
||||
sampled_indices = torch.arange(0, len(data_ori), skip_rows) # 记录行号
|
||||
return label_name,source_data,normalizer.params["min"],normalizer.params["max"],normalizer.params["mean"],sampled_indices,data_sample
|
||||
|
||||
def get_train_data_from_csv(data_path,normalization = false,normalization_type = 'minmax'):
|
||||
def get_train_data_from_csv(data_path,normalization = True,normalization_type = 'minmax'):
|
||||
"""读取csv数据文件并生成标准化训练数据
|
||||
Args:
|
||||
data_path (str): 文件绝对路径
|
||||
@@ -196,6 +196,8 @@ def get_train_data_from_csv(data_path,normalization = false,normalization_type =
|
||||
|
||||
Examples:
|
||||
get_data_from_csv_feature("D://test.excel")
|
||||
:param normalization_type:
|
||||
:param normalization:
|
||||
"""
|
||||
# 读取前xx行数据
|
||||
df = pd.read_csv(data_path,encoding='gbk')
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch
|
||||
|
||||
def export_model_pt(model,target,name = "model"):
|
||||
script_model = torch.jit.script(model) # 或 torch.jit.trace(model, input)
|
||||
script_model.save(target + name + ".pt")
|
||||
script_model.save(target + name + ".pth")
|
||||
#2 通用格式导出
|
||||
def export_model_onnx(model,input_tensor,target,name="model"):
|
||||
torch.onnx.export(model, input_tensor, target+ name + ".onnx")
|
||||
@@ -11,12 +11,12 @@ def export_model_onnx(model,input_tensor,target,name="model"):
|
||||
def export_model_bin(model,target,name = "weights"):
|
||||
torch.save(model.state_dict(), target + name + ".bin")
|
||||
|
||||
def export_model(model,target,file_name,name):
|
||||
def export_model(model,target,file_name,name,input_tensor):
|
||||
if name == 'bin':
|
||||
return export_model_bin(model,target,file_name)
|
||||
if name == 'onnx':
|
||||
return export_model_onnx(model,target,file_name)
|
||||
if name == 'pt':
|
||||
return export_model_bin(model,target,file_name)
|
||||
return export_model_onnx(model,input_tensor,target,file_name)
|
||||
if name == 'pth':
|
||||
return export_model_pt(model,target,file_name)
|
||||
else:
|
||||
raise ValueError(f"不支持的导出类型")
|
||||
@@ -20,10 +20,10 @@ class Normalizer:
|
||||
self.params['max_abs'] = data.abs().max(dim=0)[0]
|
||||
return self
|
||||
|
||||
def load_params(self,method = "minmax",min_in = 0,max_in = 0,mean_in =0,std=0,max_abs=0):
|
||||
def load_params(self,method = "minmax",min_in = [],max_in = [],mean_in =[],std=[],max_abs=[]):
|
||||
self.method = method
|
||||
self.params['min'] = min_in
|
||||
self.params['max'] = max_in
|
||||
self.params['min'] = torch.tensor(min_in)
|
||||
self.params['max'] = torch.tensor(max_in)
|
||||
self.params['mean'] = mean_in
|
||||
self.params['std'] = std
|
||||
self.params['max_abs'] = max_abs
|
||||
|
||||
Reference in New Issue
Block a user