# Fit a noisy line y = 2x + 1 with a one-feature PyTorch linear model,
# print the learned equation, and plot the fit against the samples.
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
# Data generation: 100 evenly spaced points on [-3, 3], with targets
# y = 2x + 1 plus Gaussian noise (mean 0, std 0.5). reshape(-1, 1) makes both
# arrays (100, 1): one row per sample, one feature per row.
# A seeded Generator makes the noise (and therefore the final fit)
# reproducible from run to run.
rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 100).reshape(-1, 1)
y = 2 * x + 1 + rng.normal(0, 0.5, x.shape)
# torch.as_tensor with an explicit dtype replaces the legacy torch.FloatTensor
# constructor while producing the same float32 tensors.
x_tensor = torch.as_tensor(x, dtype=torch.float32)
y_tensor = torch.as_tensor(y, dtype=torch.float32)
# 模型定义
|
|
class LinearModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = nn.Linear(1, 1)
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
# Training configuration: the linear model, mean-squared-error loss, and
# plain stochastic gradient descent over the model's parameters.
# Seeding torch's global RNG makes nn.Linear's random initial weights (and
# hence the fitted result) reproducible across runs.
torch.manual_seed(0)
LEARNING_RATE = 0.01
model = LinearModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
# Training loop: 1000 gradient-descent steps, each over the full batch of
# samples in x_tensor / y_tensor.
for _ in range(1000):
    # Clear gradients from the previous step before computing new ones.
    optimizer.zero_grad()
    pred = model(x_tensor)
    loss = criterion(pred, y_tensor)
    loss.backward()
    optimizer.step()
# Result output: read the learned slope and intercept back as Python floats.
# .item() requires a one-element tensor, which holds for the default
# single-input, single-output linear layer.
w = model.linear.weight.item()
b = model.linear.bias.item()
# Render the intercept's sign explicitly so a negative bias prints as
# "y = 2.00x - 0.50" rather than "y = 2.00x + -0.50".
sign = '-' if b < 0 else '+'
print(f'Final equation: y = {w:.2f}x {sign} {abs(b):.2f}')
# Visualization: the noisy samples as a scatter plot, with the learned line
# w * x + b drawn over them in red.
fitted_line = w * x + b
plt.scatter(x, y)
plt.plot(x, fitted_line, 'r-')
plt.show()