在 seaborn 中绘制回归时如何获得数值拟合结果

新手上路,请多包涵

如果我使用 Python 中的 seaborn 库绘制线性回归的结果,有没有办法找出回归的数值结果?例如,我可能想知道拟合系数或拟合的 R 2 。

我可以使用底层的 statsmodels 界面重新运行相同的拟合,但这似乎是不必要的重复工作,而且无论如何我希望能够比较结果系数以确保数值结果与我的结果相同在情节中看到。

原文由 The Photon 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 1.1k
2 个回答

没有办法做到这一点。

在我看来,要求可视化库为您提供统计建模结果是倒退的。 statsmodels ,一个建模库,可让您拟合模型,然后绘制与您拟合的模型完全对应的图。如果你想要那种精确的对应关系,这种操作顺序对我来说更有意义。

你可能会说“但是 --- 中的情节没有 statsmodels seaborn 多的美学选择”。但我认为这是有道理的 — statsmodels 是一个有时在建模服务中使用可视化的建模库。 seaborn 是一个可视化库,有时在可视化服务中使用建模。专攻是好事,什么都想做是坏事。

幸运的是, seabornstatsmodels 都使用了 整洁的数据。这意味着您确实需要很少的重复工作即可通过适当的工具获得绘图和模型。

原文由 mwaskom 发布,翻译遵循 CC BY-SA 3.0 许可协议

不幸的是,Seaborn 的创建者声明 他不会添加这样的功能。下面是一些选项。 (最后一节包含我最初的建议,这是一个使用 seaborn 的私有实现细节的 hack,并且不是特别灵活。)

regplot

以下函数在散点图上叠加拟合线并返回 statsmodels 的结果。这支持 sns.regplot 的最简单且可能是最常见的用法,但不实现任何更高级的功能。

 import statsmodels.api as sm

def simple_regplot(
    x, y, n_std=2, n_pts=100, ax=None, scatter_kws=None, line_kws=None, ci_kws=None
):
    """ Draw a regression line with error interval. """
    ax = plt.gca() if ax is None else ax

    # calculate best-fit line and interval
    x_fit = sm.add_constant(x)
    fit_results = sm.OLS(y, x_fit).fit()

    eval_x = sm.add_constant(np.linspace(np.min(x), np.max(x), n_pts))
    pred = fit_results.get_prediction(eval_x)

    # draw the fit line and error interval
    ci_kws = {} if ci_kws is None else ci_kws
    ax.fill_between(
        eval_x[:, 1],
        pred.predicted_mean - n_std * pred.se_mean,
        pred.predicted_mean + n_std * pred.se_mean,
        alpha=0.5,
        **ci_kws,
    )
    line_kws = {} if line_kws is None else line_kws
    h = ax.plot(eval_x[:, 1], pred.predicted_mean, **line_kws)

    # draw the scatterplot
    scatter_kws = {} if scatter_kws is None else scatter_kws
    ax.scatter(x, y, c=h[0].get_color(), **scatter_kws)

    return fit_results

statsmodels 的结果包含大量信息, _例如_:

 >>> print(fit_results.summary())

                            OLS Regression Results
==============================================================================
Dep. Variable:                      y   R-squared:                       0.477
Model:                            OLS   Adj. R-squared:                  0.471
Method:                 Least Squares   F-statistic:                     89.23
Date:                Fri, 08 Jan 2021   Prob (F-statistic):           1.93e-15
Time:                        17:56:00   Log-Likelihood:                -137.94
No. Observations:                 100   AIC:                             279.9
Df Residuals:                      98   BIC:                             285.1
Df Model:                           1
Covariance Type:            nonrobust
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const         -0.1417      0.193     -0.735      0.464      -0.524       0.241
x1             3.1456      0.333      9.446      0.000       2.485       3.806
==============================================================================
Omnibus:                        2.200   Durbin-Watson:                   1.777
Prob(Omnibus):                  0.333   Jarque-Bera (JB):                1.518
Skew:                          -0.002   Prob(JB):                        0.468
Kurtosis:                       2.396   Cond. No.                         4.35
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.

sns.regplot 的替代品(几乎)

上述方法优于我在下面的原始答案的优点是很容易将其扩展到更复杂的配合。

无耻的插件:这是我编写的这样一个扩展的 regplot 函数,它实现了 sns.regplot 的大部分功能: https ://github.com/ttesileanu/pydove。

虽然仍然缺少一些功能,但我编写的功能

  • 通过将绘图与统计建模分开来提供灵活性(并且您还可以轻松访问拟合结果)。
  • 对于大型数据集来说要快得多,因为它让 statsmodels 计算置信区间而不是使用自举。
  • 允许稍微多样化的拟合( 例如, 多项式在 log(x) )。
  • 允许稍微更细粒度的绘图选项。

旧答案

不幸的是,Seaborn 的创建者声明 他不会添加这样的功能,所以这里有一个解决方法。

 def regplot(
    *args,
    line_kws=None,
    marker=None,
    scatter_kws=None,
    **kwargs
):
    # this is the class that `sns.regplot` uses
    plotter = sns.regression._RegressionPlotter(*args, **kwargs)

    # this is essentially the code from `sns.regplot`
    ax = kwargs.get("ax", None)
    if ax is None:
        ax = plt.gca()

    scatter_kws = {} if scatter_kws is None else copy.copy(scatter_kws)
    scatter_kws["marker"] = marker
    line_kws = {} if line_kws is None else copy.copy(line_kws)

    plotter.plot(ax, scatter_kws, line_kws)

    # unfortunately the regression results aren't stored, so we rerun
    grid, yhat, err_bands = plotter.fit_regression(plt.gca())

    # also unfortunately, this doesn't return the parameters, so we infer them
    slope = (yhat[-1] - yhat[0]) / (grid[-1] - grid[0])
    intercept = yhat[0] - slope * grid[0]
    return slope, intercept

请注意,这仅适用于线性回归,因为它只是从回归结果中推断出斜率和截距。好的是它使用 seaborn 自己的回归类,因此保证结果与显示的一致。缺点当然是我们在 seaborn 中使用了一个私有实现细节,它随时都可能中断。

原文由 Legendre17 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题