同步本地最新代码
This commit is contained in:
@@ -3,6 +3,19 @@ import json
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 获取当前脚本(Data_Load.py)所在的目录
|
||||||
|
current_script_dir = os.path.dirname(__file__) # 结果:/home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler
|
||||||
|
|
||||||
|
# 从当前目录回退 2 级,得到项目根目录 ModelTrainingPython
|
||||||
|
root_path = os.path.abspath(os.path.join(current_script_dir, "..", ".."))
|
||||||
|
|
||||||
|
# 将根目录添加到 Python 搜索路径
|
||||||
|
sys.path.append(root_path)
|
||||||
|
|
||||||
|
|
||||||
from FC_ML_Data.FC_ML_Data_Process.Data_Process_Normalization import Normalizer
|
from FC_ML_Data.FC_ML_Data_Process.Data_Process_Normalization import Normalizer
|
||||||
from FC_ML_NN_Model.Poly_Model import PolyModel
|
from FC_ML_NN_Model.Poly_Model import PolyModel
|
||||||
from FC_ML_Tool.Serialization import parse_json_file
|
from FC_ML_Tool.Serialization import parse_json_file
|
||||||
|
|||||||
@@ -2,22 +2,22 @@
|
|||||||
"files":["sample1.CSV"],
|
"files":["sample1.CSV"],
|
||||||
"path": "/home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Test/Train/",
|
"path": "/home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Test/Train/",
|
||||||
"algorithmParam": {
|
"algorithmParam": {
|
||||||
"inputSize": 8,
|
"inputSize": 9,
|
||||||
"outputSize": 8,
|
"outputSize": 8,
|
||||||
"algorithm": "多项式拟合",
|
"algorithm": "多项式拟合",
|
||||||
"activateFun": "sigmod",
|
"activateFun": "sigmod",
|
||||||
"lossFun": "l1",
|
"lossFun": "l1",
|
||||||
"optimizeFun": "sgd",
|
"optimizeFun": "sgd",
|
||||||
"exportFormat": ".onnx",
|
"exportFormat": "bin",
|
||||||
"trainingRatio": 80,
|
"trainingRatio": 80,
|
||||||
"loadSize": 32,
|
"loadSize": 32,
|
||||||
"studyPercent": 0.001,
|
"studyPercent": 0.001,
|
||||||
"stepCounts": 3,
|
"stepCounts": 3,
|
||||||
"roundPrint": 11,
|
"roundPrint": 10,
|
||||||
"round": 1001,
|
"round": 300,
|
||||||
"preDisposeData": false,
|
"preDisposeData": true,
|
||||||
"disposeMethod": "minmax",
|
"disposeMethod": "minmax",
|
||||||
"dataNoOrder": false
|
"dataNoOrder": true
|
||||||
},
|
},
|
||||||
"algorithm": "基础神经网络NN"
|
"algorithm": "基础神经网络NN"
|
||||||
}
|
}
|
||||||
@@ -19,4 +19,4 @@ def export_model(model,target,file_name,name,input_tensor):
|
|||||||
if name == 'pth':
|
if name == 'pth':
|
||||||
return export_model_pt(model,target,file_name)
|
return export_model_pt(model,target,file_name)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"不支持的导出类型")
|
raise ValueError(f"不支持的导出类型")
|
||||||
Reference in New Issue
Block a user