9 lines
345 B
Python
9 lines
345 B
Python
import torch
|
||
|
||
# 构建超定方程组 Ax=b (3方程2未知数)
|
||
A = torch.tensor([[2.0, 3], [1, 4], [3, 1]]) # 3x2矩阵
|
||
b = torch.tensor([5.0, 6, 4]).reshape(-1,1) # 3x1向量
|
||
|
||
# 解法1:正规方程 (A^T A)^-1 A^T b
|
||
solution = torch.linalg.lstsq(A, b).solution # PyTorch内置最小二乘
|
||
print(f"最小二乘解:\n{solution.numpy()}") |