Files
ModelTrainingPython/FC_ML_NN/NN_LSTM.py

20 lines
708 B
Python
Raw Normal View History

import torch
import torch.nn as nn
class LSTMModel(nn.Module):
    """Stacked LSTM whose final-time-step features feed a linear output layer.

    Args:
        input_size: Number of features per time step (LSTM input dimension).
        hidden_size: Number of hidden units in each LSTM layer.
        output_size: Number of outputs produced by the final linear layer.
        num_lay: Number of stacked LSTM layers.
        dropout: Dropout probability that ``nn.LSTM`` applies to the outputs of
            every LSTM layer except the last (regularization). It only has an
            effect when ``num_lay > 1``. The default 0.0 means no dropout, which
            matches the original behaviour.
        bidirectional: If True, use a bidirectional LSTM. The final forward
            state and the final backward state are concatenated and fed to the
            linear layer. The default False matches the original behaviour.
    """

    def __init__(self, input_size, hidden_size, output_size, num_lay, *,
                 dropout=0.0, bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,  # input feature dimension
            hidden_size=hidden_size,  # hidden state dimension
            num_layers=num_lay,  # number of stacked LSTM layers
            batch_first=True,  # tensors are laid out as (batch, seq, feature)
            dropout=dropout,
            bidirectional=bidirectional,
        )
        num_directions = 2 if bidirectional else 1
        # A bidirectional LSTM emits forward and backward features side by
        # side, so the linear layer must accept twice the hidden size.
        self.fc = nn.Linear(hidden_size * num_directions, output_size)

    def forward(self, x):
        """Map a batch of sequences to one output vector per sequence.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``; the LSTM is
                built with ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # out: (batch, seq_len, num_directions * hidden_size)
        out, _ = self.lstm(x)
        if self.lstm.bidirectional:
            hidden = self.lstm.hidden_size
            # The forward direction finishes at the last time step, while the
            # backward direction finishes at the first one. Using
            # out[:, -1, :] alone would yield a backward state that has seen
            # only a single time step.
            last = torch.cat((out[:, -1, :hidden], out[:, 0, hidden:]), dim=1)
        else:
            last = out[:, -1, :]  # output of the last time step
        return self.fc(last)