import torch
import torch.nn as nn


class OrdinaryKriging(nn.Module):
    """Ordinary kriging interpolator.

    The module has no trainable parameters: ``forward`` builds the ordinary
    kriging linear system from the known samples, solves it for the weights of
    every target point, and returns the interpolated values.
    """

    def __init__(self, variogram_model='gaussian', sill=1.0, scale=0.5):
        """
        Args:
            variogram_model: Name of the semivariogram model. Only
                ``'gaussian'`` is implemented; any other value makes
                ``forward`` raise ``ValueError``.
            sill: Plateau of the semivariogram,
                ``gamma(h) = sill * (1 - exp(-h**2 / scale))``. Defaults to 1.0.
            scale: Denominator applied to the squared distance in the Gaussian
                model. Defaults to 0.5.
        """
        super().__init__()
        self.variogram_model = variogram_model
        self.sill = sill
        self.scale = scale

    def _variogram(self, h):
        """Evaluate the configured semivariogram at distances ``h``."""
        if self.variogram_model == 'gaussian':
            return self.sill * (1.0 - torch.exp(-(h ** 2) / self.scale))
        raise ValueError(
            f"Unsupported variogram_model {self.variogram_model!r}; "
            "only 'gaussian' is implemented"
        )

    def forward(self, known_coords, known_values, target_coords):
        """Interpolate ``known_values`` at ``target_coords``.

        Args:
            known_coords: (n, d) floating-point coordinates of the samples.
            known_values: (n,) or (n, 1) sample values.
            target_coords: (m, d) coordinates to predict at, or a single
                (d,) point.

        Returns:
            A tensor of shape (m,) with one prediction per target, or a 0-d
            tensor when ``target_coords`` is a single 1-D point.

        Raises:
            ValueError: if ``variogram_model`` is not supported.
        """
        single = target_coords.dim() == 1
        if single:
            target_coords = target_coords.unsqueeze(0)

        n = known_coords.shape[0]
        # Pairwise distances among samples and from each target to each sample.
        dists = torch.cdist(known_coords, known_coords)          # (n, n)
        target_dists = torch.cdist(target_coords, known_coords)  # (m, n)

        gamma = self._variogram(dists)
        target_gamma = self._variogram(target_dists)

        # Build the border on the same dtype/device as the data so float64 and
        # GPU inputs work.
        dtype, device = gamma.dtype, gamma.device
        ones_col = torch.ones(n, 1, dtype=dtype, device=device)
        corner = torch.zeros(1, 1, dtype=dtype, device=device)

        # Kriging matrix bordered by the Lagrange-multiplier row/column that
        # enforces sum(weights) == 1 (unbiasedness).
        K = torch.cat([
            torch.cat([gamma, ones_col], dim=1),
            torch.cat([ones_col.T, corner], dim=1),
        ])

        # Right-hand side: one column per target holding its semivariances to
        # every sample, plus a final 1 for the constraint row.
        m = target_gamma.shape[0]
        rhs = torch.cat([
            target_gamma.T,
            torch.ones(1, m, dtype=dtype, device=device),
        ])

        # (n + 1, m); the last row holds the Lagrange multipliers.
        weights = torch.linalg.solve(K, rhs)

        values = known_values.reshape(-1).to(dtype=dtype, device=device)
        pred = values @ weights[:-1]  # (m,)
        return pred[0] if single else pred