"""Run a trained PolyModel surrogate on one input sample and save the prediction.

Reads a JSON config (``--param``), normalizes the configured input values with
the stored min/max parameters, runs the model on CPU, de-normalizes the result
and writes ``{output name: predicted value}`` to ``<path>/forecast.json``.
"""
import argparse
import json
import sys

import torch

from FC_ML_Data.FC_ML_Data_Process.Data_Process_Normalization import Normalizer
from FC_ML_NN_Model.Poly_Model import PolyModel
from FC_ML_Tool.Serialization import parse_json_file


def _load_model(model_file, input_size, output_size, device):
    """Build a PolyModel, load its saved state dict onto ``device`` and set eval mode."""
    model = PolyModel(input_size, output_size).to(device)
    # map_location lets a checkpoint that was saved from a GPU be loaded on a
    # CPU-only host instead of failing with a CUDA deserialization error.
    # NOTE(review): torch.load unpickles the file -- only load trusted model
    # files (consider weights_only=True once torch >= 1.13 is guaranteed).
    model.load_state_dict(torch.load(model_file, map_location=device))
    model.eval()
    return model


def _build_forecast(names, values):
    """Map each output name to its predicted value.

    Returns an empty dict (and warns on stderr) when the number of predicted
    values does not match the number of configured output names, so the
    mismatch is visible instead of silently producing an empty forecast.
    """
    if len(values) != len(names):
        print(
            f"Warning: model produced {len(values)} value(s) but "
            f"{len(names)} output name(s) are configured; "
            "writing an empty forecast.",
            file=sys.stderr,
        )
        return {}
    return dict(zip(names, values))


def main():
    """Parse arguments, run the prediction and write forecast.json."""
    parser = argparse.ArgumentParser(description='代理模型训练参数输入')
    # Raw string: the Windows backslashes must not be parsed as escape sequences.
    parser.add_argument('--param',
                        default=r'D:\liyong\project\ModelTrainingPython\FC_ML_Baseline\FC_ML_Baseline_Test\pred\param.json',
                        help='配置参数文件绝对路径')
    args = parser.parse_args()
    params = parse_json_file(args.param)
    print(params)

    source_dir = params["path"] + "/"
    model_file = source_dir + params["modelFile"]
    names = params["output"]["names"]
    # Collect the input feature values in config order.
    inputs = [input_value["value"] for input_value in params["input"]]

    # Load the model for prediction.
    model_params = params["modelParams"]
    input_size = model_params["inputSize"]
    output_size = model_params["outputSize"]
    device = torch.device('cpu')
    model = _load_model(model_file, input_size, output_size, device)

    # Load the data normalizer. The stored min/max lists hold the input
    # features first and the outputs last, so they are sliced per stage.
    normalization_type = model_params["normalizerType"]
    normalization_max = model_params["normalizerMax"]
    normalization_min = model_params["normalizerMin"]
    normalizer = Normalizer(method=normalization_type)
    normalizer.load_params(normalization_type,
                           normalization_min[0:input_size],
                           normalization_max[0:input_size])
    input_data = normalizer.transform(torch.tensor(inputs))

    # Run the model prediction without tracking gradients.
    with torch.no_grad():
        output_data = model(input_data)

    # Switch the normalizer to the output-column parameters and de-normalize.
    normalizer.load_params(normalization_type,
                           normalization_min[-output_size:],
                           normalization_max[-output_size:])
    output_data_ori = normalizer.inverse_transform(output_data)

    # Write the prediction results to the output file. Explicit UTF-8: with
    # ensure_ascii=False, non-ASCII output names would otherwise be encoded in
    # the platform's locale codepage (e.g. GBK on Chinese Windows).
    forecast = _build_forecast(names, output_data_ori.tolist())
    with open(source_dir + "forecast.json", "w", encoding="utf-8") as f:
        json.dump(forecast, f, indent=None, ensure_ascii=False)


if __name__ == "__main__":
    main()