import torch import torch.nn as nn import numpy as np import matplotlib.pyplot as plt plt.rcParams['font.sans-serif'] = ['SimHei'] # 使用黑体显示中文:ml-citation{ref="1,2" data="citationList"} plt.rcParams['axes.unicode_minus'] = False # # 手动实现RNN单元 class SimpleRNNCell: def __init__(self, input_size, hidden_size): # 权重初始化 self.W_xh = torch.randn(input_size, hidden_size) * 0.01 self.W_hh = torch.randn(hidden_size, hidden_size) * 0.01 self.b_h = torch.zeros(1, hidden_size) def forward(self, x, h_prev): """ x: 当前输入 (1, input_size) h_prev: 前一刻隐藏状态 (1, hidden_size) """ # RNN核心计算 h_next = torch.tanh(torch.mm(x, self.W_xh) + torch.mm(h_prev, self.W_hh) + self.b_h) return h_next # 示例:处理序列数据 input_size = 3 hidden_size = 4 seq_length = 5 # 创建RNN单元 rnn_cell = SimpleRNNCell(input_size, hidden_size) # 初始化隐藏状态 h = torch.zeros(1, hidden_size) # 模拟输入序列 (5个时间步,每个时间步3维向量) inputs = [torch.randn(1, input_size) for _ in range(seq_length)] # 循环处理序列 hidden_states = [] for t in range(seq_length): h = rnn_cell.forward(inputs[t], h) hidden_states.append(h.detach().numpy()) print(f"时间步 {t + 1}, 隐藏状态: {h}") # 可视化隐藏状态变化 plt.figure(figsize=(10, 6)) for i in range(hidden_size): plt.plot(range(1, seq_length + 1), [h[0, i] for h in hidden_states], label=f'隐藏单元 {i + 1}') plt.title('RNN隐藏状态随时间变化') plt.xlabel('时间步') plt.ylabel('隐藏状态值') plt.legend() plt.grid(True) plt.show()