Files
ModelTrainingPython/FC_ML_NN/NN_RNN.py

54 lines
1.7 KiB
Python
Raw Normal View History

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # 使用黑体显示中文:ml-citation{ref="1,2" data="citationList"}
plt.rcParams['axes.unicode_minus'] = False #
# 手动实现RNN单元
class SimpleRNNCell:
def __init__(self, input_size, hidden_size):
# 权重初始化
self.W_xh = torch.randn(input_size, hidden_size) * 0.01
self.W_hh = torch.randn(hidden_size, hidden_size) * 0.01
self.b_h = torch.zeros(1, hidden_size)
def forward(self, x, h_prev):
"""
x: 当前输入 (1, input_size)
h_prev: 前一刻隐藏状态 (1, hidden_size)
"""
# RNN核心计算
h_next = torch.tanh(torch.mm(x, self.W_xh) +
torch.mm(h_prev, self.W_hh) +
self.b_h)
return h_next
# 示例:处理序列数据
input_size = 3
hidden_size = 4
seq_length = 5
# 创建RNN单元
rnn_cell = SimpleRNNCell(input_size, hidden_size)
# 初始化隐藏状态
h = torch.zeros(1, hidden_size)
# 模拟输入序列 (5个时间步每个时间步3维向量)
inputs = [torch.randn(1, input_size) for _ in range(seq_length)]
# 循环处理序列
hidden_states = []
for t in range(seq_length):
h = rnn_cell.forward(inputs[t], h)
hidden_states.append(h.detach().numpy())
print(f"时间步 {t + 1}, 隐藏状态: {h}")
# 可视化隐藏状态变化
plt.figure(figsize=(10, 6))
for i in range(hidden_size):
plt.plot(range(1, seq_length + 1), [h[0, i] for h in hidden_states],
label=f'隐藏单元 {i + 1}')
plt.title('RNN隐藏状态随时间变化')
plt.xlabel('时间步')
plt.ylabel('隐藏状态值')
plt.legend()
plt.grid(True)
plt.show()