20 lines
708 B
Python
20 lines
708 B
Python
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
|
||
|
|
class LSTMModel(nn.Module):
|
||
|
|
def __init__(self, input_size, hidden_size, output_size, num_lay):
|
||
|
|
super().__init__()
|
||
|
|
self.lstm = nn.LSTM(
|
||
|
|
input_size=input_size,#输入特征维度
|
||
|
|
hidden_size=hidden_size,#隐藏层维度
|
||
|
|
num_layers=num_lay, # 隐藏层数
|
||
|
|
batch_first=True,
|
||
|
|
# bidirectional = False, # 是否使用双向LSTM
|
||
|
|
# dropout = 0.2 # 添加正则化
|
||
|
|
)
|
||
|
|
self.fc = nn.Linear(hidden_size, output_size)
|
||
|
|
|
||
|
|
def forward(self, x):
|
||
|
|
out, (h_n, c_n) = self.lstm(x)
|
||
|
|
out = self.fc(out[:, -1, :]) # 取最后一个时间步的输出
|
||
|
|
return out
|