决策树
决策树学习采用的是自顶向下的递归方法,其基本思想是以信息熵为度量构造一颗熵值下降最快的树,到叶子节点处,熵值为0
具有非常好的可解释性、分类速度快的优点,是一种有监督学习
最早提及决策树思想的是Quinlan在1986年提出的ID3算法和1993年提出的C4.5算法,以及Breiman等人在1984年提出的CART算法
工作原理
一般的,一颗决策树包含一个根结点、若干个内部节点和若干个叶节点
构造
构造就是生成一棵完整的决策树。简单来说,构造的过程就是选择什么属性作为节点的过程
叶结点对应于决策结果, 其他每个结点则对应于一个属性测试,每个结点包含的样本集合根据属性测试的结果被划分到子结点中;
根结点包含样本全集,从根结点到每个叶结点的路径对应了一个判定测试序列. 决策树学习的目的是为了产生一棵泛化能力强, 即处理未见示例能力强的决策树,其基本流程遵循简单且直观的分而治之策略
显然, 决策树的生成是一个递归过程. 在决策树基本算法中, 有三种情形会导致递归返回:
- 当前结点包含的样本全属于同一类别, 无需划分
- 当前属性集为空, 或是所有样本在所有属性上取值相同, 无法划分
- 当前结点包含的样本集合为空, 不能划分
划分选择
决策树学习的关键是如何选择最优划分属性
随着划分过程不断进行, 我们希望决策树的分支结点所包含的样本尽可能属于同一类别, 即结点的“纯度”越来越高
信息熵
信息熵在信息论中代表随机变量不确定度的度量
熵越大,数据的不确定性越高,纯度越低
熵越小,数据的不确定性越低,纯度越高
假定当前样本集合 $D$ 中第 $k$ 类样本所占的比例为 $p_{k}(k=1,2, \ldots,|\mathcal{Y}|)$ 则$D$的信息嫡定义为
$$\operatorname{Ent}(D)=-\sum_{k=1}^{|\mathcal{Y}|} p_{k} \log _{2} p_{k}$$
信息增益
假定离散属性 $a$ 有 $V$ 个可能的取值 $\left\{a^{1}, a^{2}, \ldots, a^{V}\right\},$ 若使用 $a$ 来对样本集 $D$ 进行划分, 则会产生 $V$ 个分支结点, 其中第 $v$ 个分支结点包含了 $D$ 中所有在 属性 $a$ 上取值为 $a^{v}$ 的样本, 记为 $D^{v} .$ 我们可根据式 (4.1) 计算出 $D^{v}$ 的信息嫡, 再考虑到不同的分支结点所包含的样本数不同, 给分支结点赋予权重 $\left|D^{v}\right| /|D|,$ 即样本数越多的分支结点的影响越大, 于是可计算出用属性 $a$ 对样本集 $D$ 进行 划分所获得的“信息增益”
$$\operatorname{Gain}(D, a)=\operatorname{Ent}(D)-\sum_{v=1}^{V} \frac{\left|D^{v}\right|}{|D|} \operatorname{Ent}\left(D^{v}\right)$$
一般而言, 信息增益越大, 则意味着使用属性 $a$ 来进行划分所获得的“纯度提升”越大. 因此, 我们可用信息增益来进行决策树的划分属性选择, 著名的 ID3 决策树学习算法就是以信息增益为准则来选择划分属性
ID3 算法的优点是方法简单,缺点是对噪声敏感。训练数据如果有少量错误,可能会产生决策树分类错误
信息增益率
$$ \text {GainRatio}(D, a)=\frac{\operatorname{Gain}(D, a)}{Ent(D)} $$
C4.5 在 IID3 的基础上,用信息增益率代替了信息增益,解决了噪声敏感的问题,并且可以对构造树进行剪枝、处理连续数值以及数值缺失等情况,但是由于 C4.5 需要对数据集进行多次扫描,算法效率相对较低
基尼指数
基尼指数是经典决策树CART用于分类问题时选择最优特征的指标
假设有$K$个类,样本点属于第$k$类的概率为$p_k$,则概率分布的基尼指数定义为
$$ G(p)=\sum_{k=1}^{K} p_{k}\left(1-p_{k}\right)=1-\sum_{k=1}^{K} p_{k}^{2} $$
在信息增益、增益率、基尼指数之外, 人们还设计了许多其他的准则用于决策树划分选择
然而有实验研究表明这些准则虽然对决策树的尺寸有较大影响, 但对泛化性能的影响很有限
剪枝
决策树的缺点包括对未知的测试数据未必有好的分类、泛化能力,即可能发生过拟合现象,此时可采用剪枝或随机森林
剪枝是决策树学习算法对付“过拟合”的主要手段
在决策树学习中, 为了尽可能正确分类训练样本, 结点划分过程将不断重复, 有时会造成决 策树分支过多, 这时就可能因训练样本学得太好了, 以致于把训练集自身 的一些特点当作所有数据都具有的一般性质而导致过拟合
ID3 决策树代码
参考 Machine Learning in Action by Peter Harrington
# coding = utf-8
from math import log
import numpy as np
from collections import Counter
class DecisionTree:
"""ID3 DecisionTree
"""
def __init__(self):
self.decisionTree = None
self._X = None
self._y = None
# 计算信息熵
def calcShannonEnt(self,y):
lablesCounter = Counter(y)
shannonEnt = 0.0
for num in lablesCounter.values():
p = num / len(y)
shannonEnt += -p * log(p,2)
return shannonEnt
def fit(self, X, y):
self._X = X
self._y = y
self.decisionTree = self.createTree(self._X,self._y)
return self
def splitDataset(self,X,y,d,value):
features = X[X[:,d]==value]
labels = y[X[:,d]==value]
return np.concatenate((features[:,:d],features[:,d+1:]),axis=1), labels
def chooseBestFeatureToSplit(self,X,y):
numFeatures = X.shape[1]
baseEntropy = self.calcShannonEnt(y)
bestInfoGain, bestFeature = 0.0, -1
for i in range(numFeatures):
# 创建唯一的分类标签列表
uniqueVals = np.unique(X[:,i])
newEntropy =0.0
# 计算每种划分方式的信息熵
for value in uniqueVals:
_x, _y = self.splitDataset(X,y,i,value)
prob = len(_x)/len(X)
newEntropy += prob * self.calcShannonEnt(_y)
infoGain = baseEntropy - newEntropy
if infoGain>bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
def majorityCnt(self,y):
lablesCounter = Counter(y)
return lablesCounter.most_common(1)[0]
def createTree(self,X,y):
# 类别完全相同则停止继续划分
if y[y == y[0]].size == y.size :
return y[0]
# 遍历完所有特征时返回出现次数最多的类别
if X.shape[1] == 0:
return self.majorityCnt(y)
bestFeat = self.chooseBestFeatureToSplit(X,y)
decisionTree = {bestFeat: {}}
for value in np.unique(X[:,bestFeat]):
decisionTree[bestFeat][value] = self.createTree(*self.splitDataset(X,y,bestFeat, value))
return decisionTree
if __name__ == '__main__':
dataSet = np.array([[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']])
labels = ['no surfacing', 'flippers']
dt = DecisionTree()
X = dataSet[:, :2]
X = X.astype(np.int)
y = dataSet[:,-1]
dt.fit(X,y)
print(dt.decisionTree)
参考
机器学习-周志华
Machine Learning in Action by Peter Harrington
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。