@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class OrdinaryKriging(nn.Module):
|
||||
def __init__(self, variogram_model='gaussian'):
|
||||
super().__init__()
|
||||
self.variogram_model = variogram_model
|
||||
|
||||
def forward(self, known_coords, known_values, target_coords):
|
||||
# 计算距离矩阵
|
||||
dists = torch.cdist(known_coords, known_coords)
|
||||
target_dists = torch.cdist(target_coords, known_coords)
|
||||
|
||||
# 半变异函数(高斯模型)
|
||||
if self.variogram_model == 'gaussian':
|
||||
gamma = 1.0 - torch.exp(-(dists ** 2) / 0.5)
|
||||
target_gamma = 1.0 - torch.exp(-(target_dists ** 2) / 0.5)
|
||||
|
||||
# 构建克里金矩阵(添加拉格朗日乘子)
|
||||
K = torch.cat([
|
||||
torch.cat([gamma, torch.ones(len(known_coords), 1)], dim=1),
|
||||
torch.cat([torch.ones(1, len(known_coords)), torch.zeros(1, 1)], dim=1)
|
||||
])
|
||||
|
||||
# 求解权重
|
||||
weights = torch.linalg.solve(
|
||||
K,
|
||||
torch.cat([known_values, torch.zeros(1)])
|
||||
)
|
||||
|
||||
# 预测值
|
||||
pred = (weights[:-1] * known_values).sum()
|
||||
return pred
|
||||
Reference in New Issue
Block a user