ROC曲线和AUC

概念

ROC全称是“受试者工作特征”(Receiver Operation Chracteristic),用来评判分类结果的好坏。

AUC(Area Under Curve)是ROC曲线下的面积。

计算方法

混淆矩阵

首先要介绍混淆矩阵的概念。在二分类中,我们用TP(true positive),FP(false positive),TN(true negative),FN(false negative)分别表示真实值为正且预测为正,真实值为负且预测为正,真实值为负且预测为负,真实值为正且预测值为负,详见下表。

预测值为1 预测值为0
真实值为1 TP FN
真实值为0 FP TN

定义

  • 真正率TPR(True Positive Rate) 即真正的正样本有多少预测为正

    $$ TPR=\frac{TP}{TP+FN} $$

  • 假正率FPR(False Positive Rate)即真正的负样本有多少预测为正

    $$ FPR=\frac{FP}{FP+TN} $$

当全部预测为0时,TP=FP=0,此时TPR=FPR=0。

当全部预测为1时,FN=TN=0,此时TPR=FPR=1。

在做二分类任务时,假设样本的实际值为

$$ y_1, y_2...y_n $$

我们的预测分值为

$$ s_1,s_2...s_n $$

设定阈值为

$$ t $$

则其对应的预测值是

$$ \{\hat{y_i}|\hat{y_i}=s_i>t,i=1,2...n\} $$

考虑到样本数量是有限的,我们只需要把样本的分值作为阈值,而不需要对所有的阈值可能值进行遍历。

因此,阈值的可能值为

$$ \{s_i|i=1,2...n\} $$

把阈值可能值从大到小排序,按照这个顺序不断调整阈值,TPR和FPR的值会从(0, 0)不断递增,直到(1, 1),这些点构成了ROC曲线,而AUC就是这个曲线下面的面积。

Python实现

下面的python的代码实现,测试与skleran的roc结果一致。

from sklearn.metrics import roc_curve, auc
from collections import Counter


def auc_function(score_list, label_list):
    a = [(s, l) for s, l in zip(score_list, label_list)]
    a = sorted(a, key=lambda x: -x[0])
    thresholds = []
    pc, nc = Counter(), Counter()
    for s, l in a:
        if len(thresholds) == 0 or thresholds[-1] != s:
            thresholds.append(s)
        if l == 1:
            pc[s] += 1
        else:
            nc[s] += 1
    n = len(score_list)
    fn = sum(label_list)
    tn = n - fn
    tp, fp = 0, 0
    fpr_list, tpr_list = [0], [0]
    fpr, tpr = 0, 0
    area = 0
    for t in thresholds:
        tp += pc[t]
        fn -= pc[t]
        fp += nc[t]
        tn -= nc[t]
        _fpr, _tpr = fp / (fp + tn), tp / (tp + fn)
        area += (_fpr - fpr) * (_tpr + tpr) / 2
        fpr, tpr = _fpr, _tpr
        fpr_list.append(fpr)
        tpr_list.append(tpr)
    return fpr_list, tpr_list, thresholds, area


if __name__ == "__main__":
    score = [0.8, 0.7, 0.5, 0.5, 0.5, 0.5, 0.3]
    label = [1, 1, 0, 0, 1, 1, 0]
    fpr, tpr, thres = roc_curve(label, score)
    area = auc(fpr, tpr)
    print(fpr, tpr, thres, area)
    print(auc_function(score, label))

cppowboy
20 声望1 粉丝

« 上一篇
字符串解析