新建仓库维护数据预测项目
This commit is contained in:
20
FC_ML_NN/NN_LSTM.py
Normal file
20
FC_ML_NN/NN_LSTM.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class LSTMModel(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, output_size, num_lay):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(
|
||||
input_size=input_size,#输入特征维度
|
||||
hidden_size=hidden_size,#隐藏层维度
|
||||
num_layers=num_lay, # 隐藏层数
|
||||
batch_first=True,
|
||||
# bidirectional = False, # 是否使用双向LSTM
|
||||
# dropout = 0.2 # 添加正则化
|
||||
)
|
||||
self.fc = nn.Linear(hidden_size, output_size)
|
||||
|
||||
def forward(self, x):
|
||||
out, (h_n, c_n) = self.lstm(x)
|
||||
out = self.fc(out[:, -1, :]) # 取最后一个时间步的输出
|
||||
return out
|
||||
Reference in New Issue
Block a user