不知道数据来自哪里?针对未知坐标的贝叶斯建模
Don't know where your data is from? Bayesian modeling for unknown coordinates

原始链接: https://christopherkrapu.com/blog/2026/dont-know-where-your-data-is-from/

此代码通过模拟位置误差对数据表面的影响,对空间不确定性进行对比可视化。 针对不同程度的位置噪声 ($\sigma_s$),该脚本执行以下操作: 1. **计算高斯过程 (GP) 后验均值:** 使用预先计算的贝叶斯推断模型 (`location_error_idatas`),在已知噪声输入坐标的情况下估计潜在的空间表面 $f(s)$。 2. **生成朴素核密度估计 (KDE) 对照:** 使用标准的核密度估计计算基准表面,并根据高斯过程推导出的长度尺度 ($\ell$) 进行平滑处理。 3. **可视化不确定性:** 生成多面板图形,展示: * **扰动坐标:** 展示真实位置与带误差半径的噪声观测值的散点图。 * **后验密度:** 潜在真实位置 (`X_true`) 与高斯过程后验均值的 KDE 图。 * **空间表面:** 并排的热力图,用于比较不同噪声水平下高斯过程推断的表面与朴素平滑方法。 最终输出是一个高分辨率的多面板网格图,旨在展示与朴素插值方法相比,将位置不确定性纳入高斯过程模型能提供更稳健的空间现象估计。

Hacker News 最新 | 过往 | 评论 | 提问 | 展示 | 招聘 | 提交 登录 不知道数据来源?针对未知坐标的贝叶斯建模 (christopherkrapu.com) 5 点,由 ckrapu 发布于 2 小时前 | 隐藏 | 过往 | 收藏 | 讨论 | 帮助 指南 | 常见问题 | 列表 | API | 安全 | 法律 | 申请 YC | 联系 搜索:
相关文章

原文
location_surface_grids = {}
with location_error_gp_model:
    for multiplier, X_noisy_value in zip(multipliers, noisy_xs):
        pm.set_data({"X_noisy": X_noisy_value, "σ_s": multiplier})
        posterior_mean_point = {
            name: location_error_idatas[multiplier].posterior[name].mean(("chain", "draw")).values
            for name in ["μ", "σ", "", "σ0", "Δs", "X_true"]
        }
        f_mean, _ = gp_location.predict(Xnew, point=posterior_mean_point, diag=True, pred_noise=False)
        location_surface_grids[multiplier] = f_mean.reshape(n_prediction_grid, n_prediction_grid)

naive_kde_grids = {}
kde_bandwidths = {
    multiplier: location_error_idatas[multiplier].posterior[""].mean(("chain", "draw")).item()
    for multiplier in multipliers
}
for multiplier, X_noisy_value in zip(multipliers, noisy_xs):
    squared_distance = (
        (Xnew[:, 0, None] - X_noisy_value[:, 0]) ** 2
        + (Xnew[:, 1, None] - X_noisy_value[:, 1]) ** 2
    )
    weights = np.exp(-0.5 * squared_distance / kde_bandwidths[multiplier]**2)
    naive_kde_grids[multiplier] = (weights @ y_walker / weights.sum(axis=1)).reshape(n_prediction_grid, n_prediction_grid)

location_norm = colors.Normalize(
    vmin=min(y_walker.min(), *(grid.min() for grid in location_surface_grids.values())),
    vmax=max(y_walker.max(), *(grid.max() for grid in location_surface_grids.values())),
)
point_colors = [theme["gunmetal"], theme["sepia"], theme["rust"], theme["steel"]]
arrow_head_width = 0.018 * max(x_range, y_range)

fig = plt.figure(figsize=(5, 11.2 * (5 / 7)))
gs = fig.add_gridspec(4, len(multipliers), hspace=0.08, wspace=0.08)
axes = np.array([[fig.add_subplot(gs[row, col]) for col in range(len(multipliers))] for row in range(4)])
panel_idx = 0
surface_meshes = []

for col, (multiplier, X_noisy_value) in enumerate(zip(multipliers, noisy_xs)):
    ax = axes[0, col]
    ax.scatter(X_walker[:, 0], X_walker[:, 1], facecolor=theme["paper"], edgecolor=theme["gunmetal"], s=14, linewidth=0.35, alpha=0.5)
    ax.scatter(X_noisy_value[:, 0], X_noisy_value[:, 1], c=y_walker, cmap=sepia_cmap, norm=location_norm, s=16, edgecolor=theme["paper"], linewidth=0.25, alpha=0.5)
    for point_idx, point_color in zip(selected_location_idx, point_colors):
        circle = Circle(X_noisy_value[point_idx], radius=multiplier, facecolor=colors.to_rgba(theme["steel"], 0.2), edgecolor=theme["gunmetal"], linewidth=0.6)
        ax.add_patch(circle)
        dx, dy = X_noisy_value[point_idx] - X_walker[point_idx]
        ax.arrow(X_walker[point_idx, 0], X_walker[point_idx, 1], dx, dy, color=plot_text_color, linewidth=0.8, length_includes_head=True, head_width=arrow_head_width, head_length=arrow_head_width)
    ax.scatter(X_walker[selected_location_idx, 0], X_walker[selected_location_idx, 1], facecolor=theme["paper"], edgecolor=plot_text_color, s=44, linewidth=0.8)
    ax.scatter(X_noisy_value[selected_location_idx, 0], X_noisy_value[selected_location_idx, 1], c=y_walker[selected_location_idx], cmap=sepia_cmap, norm=location_norm, s=44, edgecolor=plot_text_color, linewidth=0.8)
    ax.xaxis.set_label_position("top"); ax.set_xlabel(rf"$\sigma_s = {multiplier:.0f}$ m")

    ax = axes[2, col]
    mesh = ax.pcolormesh(x_new_mesh, y_new_mesh, location_surface_grids[multiplier], cmap=sepia_cmap, norm=location_norm, shading="auto")
    surface_meshes.append(mesh)

    ax = axes[3, col]
    ax.pcolormesh(x_new_mesh, y_new_mesh, naive_kde_grids[multiplier], cmap=sepia_cmap, norm=location_norm, shading="auto")

    ax = axes[1, col]
    ax.pcolormesh(x_new_mesh, y_new_mesh, location_surface_grids[multiplier], cmap=sepia_cmap, norm=location_norm, shading="auto", alpha=0.25)
    X_true_samples = location_error_idatas[multiplier].posterior["X_true"].stack(sample=("chain", "draw")).transpose("obs", "coord", "sample")
    for point_idx, point_color in zip(selected_location_idx, point_colors):
        x_draws = X_true_samples.isel(obs=point_idx).sel(coord="x").values
        y_draws = X_true_samples.isel(obs=point_idx).sel(coord="y").values
        sns.kdeplot(x=x_draws, y=y_draws, levels=3, color=point_color, linewidths=0.9, fill=False, ax=ax)
        ax.scatter(X_walker[point_idx, 0], X_walker[point_idx, 1], marker="x", color=point_color, s=34, linewidth=1.1)
        ax.scatter(X_noisy_value[point_idx, 0], X_noisy_value[point_idx, 1], marker="o", facecolor=theme["paper"], edgecolor=point_color, s=34, linewidth=1.1)

for row, row_label in enumerate(["Perturbed coordinates", "Posterior location density", r"Posterior mean of $f(s)$", "Naive smoothed KDE"]):
    axes[row, 0].set_ylabel(row_label)

for ax in axes.ravel():
    ax.set_xlim(*x_limits); ax.set_ylim(*y_limits); ax.set_aspect("equal"); ax.grid(False)
    ax.tick_params(axis="both", which="both", bottom=False, left=False, top=False, right=False, labelbottom=False, labelleft=False, labeltop=False, labelright=False)
    panel_idx += 1

cbar = fig.colorbar(surface_meshes[-1], ax=axes.ravel().tolist(), orientation="horizontal", fraction=0.035, pad=0.04)
cbar.set_label(r"Uranium, $\log_{10}(x + 1)$")

figure_dir = Path.home() / "ckrapu.github.io/images/2026-05-24-dont-know-where-your-data-is-from"
figure_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_dir / "error-in-location-grid.png", dpi=220, bbox_inches="tight", facecolor=fig.get_facecolor())
联系我们 contact @ memedata.com