Files
ModelTrainingPython/FC_ML_NN/NN_RNN.py

54 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Configure matplotlib so the Chinese titles and labels used in the plot below render.
plt.rcParams['font.sans-serif'] = ['SimHei'] # use the SimHei font to display Chinese text
plt.rcParams['axes.unicode_minus'] = False # draw minus signs as ASCII '-' rather than the Unicode minus glyph
# Hand-written RNN cell built from plain tensors (no torch.nn layers).
class SimpleRNNCell:
    """Single-step vanilla RNN cell.

    One step computes ``h_next = tanh(x @ W_xh + h_prev @ W_hh + b_h)``.

    Attributes:
        W_xh: Input-to-hidden weights, shape ``(input_size, hidden_size)``.
        W_hh: Hidden-to-hidden weights, shape ``(hidden_size, hidden_size)``.
        b_h: Hidden bias, shape ``(1, hidden_size)``.
    """

    def __init__(self, input_size: int, hidden_size: int) -> None:
        """Create the cell's weights.

        Args:
            input_size: Number of features in each input vector.
            hidden_size: Number of units in the hidden state.
        """
        # Weight initialisation: small Gaussian values (scaled by 0.01) and a
        # zero bias. W_xh is drawn before W_hh so seeded runs are reproducible.
        self.W_xh = torch.randn(input_size, hidden_size) * 0.01
        self.W_hh = torch.randn(hidden_size, hidden_size) * 0.01
        self.b_h = torch.zeros(1, hidden_size)

    def forward(self, x: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
        """Advance the hidden state by one time step.

        Args:
            x: Current input, shape ``(batch, input_size)``; the original
                single-sample ``(1, input_size)`` layout is the common case.
            h_prev: Hidden state from the previous step, shape
                ``(batch, hidden_size)``.

        Returns:
            The next hidden state. For 2-D inputs this has shape
            ``(batch, hidden_size)``.
        """
        # Core RNN computation. The `@` operator (torch.matmul) gives the same
        # result as torch.mm for 2-D operands but also accepts 1-D or batched
        # inputs; the (1, hidden_size) bias broadcasts across the batch.
        pre_activation = x @ self.W_xh + h_prev @ self.W_hh + self.b_h
        return torch.tanh(pre_activation)
# Example: run the hand-written cell over a short random sequence and plot
# how the hidden state evolves.
input_size = 3
hidden_size = 4
seq_length = 5

# Create the RNN cell.
rnn_cell = SimpleRNNCell(input_size, hidden_size)

# Initialise the hidden state to zeros, shape (1, hidden_size).
h = torch.zeros(1, hidden_size)

# Simulated input sequence: seq_length time steps, each a (1, input_size) vector.
inputs = [torch.randn(1, input_size) for _ in range(seq_length)]

# Process the sequence step by step, feeding each new hidden state back in.
hidden_states = []
for t, x_t in enumerate(inputs):
    h = rnn_cell.forward(x_t, h)
    # Keep a NumPy copy of every step's state for plotting afterwards.
    hidden_states.append(h.detach().numpy())
    print(f"时间步 {t + 1}, 隐藏状态: {h}")

# Visualise the hidden-state trajectory: one line per hidden unit.
# Stacking the per-step (1, hidden_size) arrays gives (seq_length, hidden_size),
# so column i is the history of hidden unit i. This also avoids reusing the
# name `h` (the live hidden-state tensor) as a loop variable.
state_history = np.concatenate(hidden_states, axis=0)
time_steps = range(1, seq_length + 1)

plt.figure(figsize=(10, 6))
for i in range(hidden_size):
    plt.plot(time_steps, state_history[:, i],
             label=f'隐藏单元 {i + 1}')
plt.title('RNN隐藏状态随时间变化')
plt.xlabel('时间步')
plt.ylabel('隐藏状态值')
plt.legend()
plt.grid(True)
plt.show()