54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 使用黑体显示中文:ml-citation{ref="1,2" data="citationList"}
|
||
plt.rcParams['axes.unicode_minus'] = False #
|
||
|
||
# 手动实现RNN单元
|
||
class SimpleRNNCell:
|
||
def __init__(self, input_size, hidden_size):
|
||
# 权重初始化
|
||
self.W_xh = torch.randn(input_size, hidden_size) * 0.01
|
||
self.W_hh = torch.randn(hidden_size, hidden_size) * 0.01
|
||
self.b_h = torch.zeros(1, hidden_size)
|
||
|
||
def forward(self, x, h_prev):
|
||
"""
|
||
x: 当前输入 (1, input_size)
|
||
h_prev: 前一刻隐藏状态 (1, hidden_size)
|
||
"""
|
||
# RNN核心计算
|
||
h_next = torch.tanh(torch.mm(x, self.W_xh) +
|
||
torch.mm(h_prev, self.W_hh) +
|
||
self.b_h)
|
||
return h_next
|
||
|
||
|
||
# 示例:处理序列数据
|
||
input_size = 3
|
||
hidden_size = 4
|
||
seq_length = 5
|
||
# 创建RNN单元
|
||
rnn_cell = SimpleRNNCell(input_size, hidden_size)
|
||
# 初始化隐藏状态
|
||
h = torch.zeros(1, hidden_size)
|
||
# 模拟输入序列 (5个时间步,每个时间步3维向量)
|
||
inputs = [torch.randn(1, input_size) for _ in range(seq_length)]
|
||
# 循环处理序列
|
||
hidden_states = []
|
||
for t in range(seq_length):
|
||
h = rnn_cell.forward(inputs[t], h)
|
||
hidden_states.append(h.detach().numpy())
|
||
print(f"时间步 {t + 1}, 隐藏状态: {h}")
|
||
# 可视化隐藏状态变化
|
||
plt.figure(figsize=(10, 6))
|
||
for i in range(hidden_size):
|
||
plt.plot(range(1, seq_length + 1), [h[0, i] for h in hidden_states],
|
||
label=f'隐藏单元 {i + 1}')
|
||
plt.title('RNN隐藏状态随时间变化')
|
||
plt.xlabel('时间步')
|
||
plt.ylabel('隐藏状态值')
|
||
plt.legend()
|
||
plt.grid(True)
|
||
plt.show() |