35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
|
||
|
|
|
||
|
|
class OrdinaryKriging(nn.Module):
    """Ordinary kriging interpolator built from differentiable torch ops.

    The prediction at a target location ``x0`` is a weighted sum of the known
    values. The weights ``w`` come from solving the ordinary kriging system::

        [ Gamma  1 ] [ w  ]   [ gamma_0 ]
        [ 1^T    0 ] [ mu ] = [    1    ]

    where ``Gamma[i, j] = gamma(|x_i - x_j|)``, ``gamma_0[i] = gamma(|x0 - x_i|)``
    and ``mu`` is the Lagrange multiplier that enforces ``sum(w) == 1``.

    Args:
        variogram_model: Semivariogram model name. Either ``'gaussian'``
            (``1 - exp(-h**2 / scale)``) or ``'exponential'``
            (``1 - exp(-h / scale)``).
        scale: Positive scale parameter of the variogram. The default ``0.5``
            reproduces the originally hard-coded Gaussian model.

    Raises:
        ValueError: If ``variogram_model`` is not supported or ``scale`` is
            not positive.
    """

    _SUPPORTED_MODELS = ('gaussian', 'exponential')

    def __init__(self, variogram_model='gaussian', scale=0.5):
        super().__init__()
        # Validate eagerly so a bad configuration fails at construction time
        # instead of as an unbound-variable error inside forward().
        if variogram_model not in self._SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported variogram_model {variogram_model!r}; "
                f"expected one of {self._SUPPORTED_MODELS}"
            )
        if not scale > 0:  # also rejects NaN
            raise ValueError(f"scale must be positive, got {scale!r}")
        self.variogram_model = variogram_model
        self.scale = scale

    def _variogram(self, h):
        """Evaluate the semivariogram elementwise on a tensor of distances ``h``."""
        if self.variogram_model == 'gaussian':
            return 1.0 - torch.exp(-(h ** 2) / self.scale)
        if self.variogram_model == 'exponential':
            return 1.0 - torch.exp(-h / self.scale)
        # Only reachable if variogram_model was reassigned after construction.
        raise ValueError(f"Unsupported variogram_model {self.variogram_model!r}")

    def forward(self, known_coords, known_values, target_coords):
        """Predict values at ``target_coords`` from observations at ``known_coords``.

        Args:
            known_coords: ``(n, d)`` floating tensor of observation locations.
            known_values: ``(n,)`` or ``(n, k)`` tensor of observed values.
            target_coords: ``(m, d)`` tensor of query locations, or a single
                ``(d,)`` point.

        Returns:
            Predictions of shape ``(m,)`` (or ``(m, k)``). When
            ``target_coords`` is a single 1-D point, the leading dimension is
            dropped, giving a scalar (or ``(k,)``) tensor.

        Raises:
            torch.linalg.LinAlgError: If the kriging matrix is singular, for
                example when ``known_coords`` contains duplicate points.
        """
        single_point = target_coords.dim() == 1
        if single_point:
            target_coords = target_coords.unsqueeze(0)

        n = known_coords.shape[0]
        m = target_coords.shape[0]

        # Semivariances between data points (n, n) and from each target to
        # every data point (m, n).
        gamma = self._variogram(torch.cdist(known_coords, known_coords))
        target_gamma = self._variogram(torch.cdist(target_coords, known_coords))

        # Kriging matrix bordered by the Lagrange-multiplier row/column.
        # new_ones/new_zeros keep the device and dtype of the inputs.
        # Note: the Gaussian model can give ill-conditioned systems when data
        # points are very close together.
        ones_col = gamma.new_ones(n, 1)
        K = torch.cat([
            torch.cat([gamma, ones_col], dim=1),
            torch.cat([ones_col.T, gamma.new_zeros(1, 1)], dim=1),
        ], dim=0)

        # One right-hand-side column per target: [gamma_0; 1].
        rhs = torch.cat([target_gamma.T, target_gamma.new_ones(1, m)], dim=0)

        # Solve for all targets at once, then drop the Lagrange multipliers.
        weights = torch.linalg.solve(K, rhs)[:n]  # (n, m)

        # Match dtypes for the matmul; promotion mirrors elementwise arithmetic.
        out_dtype = torch.promote_types(weights.dtype, known_values.dtype)
        pred = weights.T.to(out_dtype) @ known_values.to(out_dtype)
        return pred[0] if single_point else pred