ID3(来自西瓜书)
ID3算法使用信息增益的大小选择特征。信息增益 = 信息熵 - 条件熵
信息熵
例如西瓜书中的数据集,数据集有两个类别:好瓜
, 坏瓜
,即\( k=2 \)。
则信息熵计算公式为:
条件熵
针对特征A的条件熵定义如下:
公式中\( m \)为特征离散值的种类,比如所特征色泽
为例,特征取值{青绿,乌黑,浅自},即\( m=3 \)。
信息增益
信息增益 = 信息熵 - 条件熵
code
#encoding=utf-8
import math
import json
def compute_ent(data):
"""
计算信息熵
"""
cdict = {}
num = len(data)
for e in data:
if e[-1] in cdict:
cdict[e[-1]] += 1
else:
cdict[e[-1]] = 1
res = 0
for key, value in cdict.items():
res += - (value / num) * math.log2(value / num)
return res
def split_data_by_feat(data, feat, pos):
ndata = []
for e in data:
if e[pos] == feat:
ndata.append(e[:pos] + e[pos+1:])
return ndata
def split_features(data, num_feat):
"""
寻找最佳划分特征,循环便利每一个特征,选取划分后信息增益最大的特征
"""
ent = compute_ent(data)
pos = -1
best = -1
# 遍历所有特征
for i in range(num_feat):
con_ent = 0
feats = set(e[i] for e in data)
# 计算特征条件熵
for sub in feats:
ndata = split_data_by_feat(data, sub, i)
con_ent += len(ndata) / len(data) * compute_ent(ndata)
# 计算信息增益
gain = ent - con_ent
if best < gain:
best = gain
pos = i
return pos
def count(data):
cdict = {}
for e in data:
if e[0] in cdict:
cdict[e[0]] += 1
else:
cdict[e[0]] = 1
cls = ''
ccount = 0
for key, value in cdict.items():
if value > ccount:
cls = key
return cls
def ID3Tree(data, names):
classes = [e[-1] for e in data]
# 如果划分后的数据集中只有一个类别,直接返回
if len(set(classes)) == 1:
return classes[0]
# 如果按照特征值划分后数据集中有多个类别,则按照类别较多的样本定义为划分类别结果
if len(data[0]) == 1 or len(names) == 0:
return count(data)
# 找到最佳的特征
pos = split_features(data, len(names))
feat = names[pos]
tree = {feat:{}}
del(names[pos])
feat_values = set([e[pos] for e in data])
# 按照特征值划分子树
for value in feat_values:
ndata, subnames = split_data_by_feat(data, value, pos), names[:]
tree[feat][value] = ID3Tree(ndata, subnames)
return tree
if __name__ == '__main__':
names = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
data = [\
['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],\
['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],\
['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],\
['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],\
['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],\
['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],\
['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '是'],\
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '是'],\
['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '否'],\
['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '否'],\
['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '否'],\
['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '否'],\
['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '否'],\
['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '否'],\
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否'],\
['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '否'],\
['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '否',]]
print(json.dumps(ID3Tree(data, names), indent=1, ensure_ascii=False))
缺点
- 没有减枝操作(过拟合)
- 无法处理连续特征值,缺失特征值
- 偏好特征值多的特征(西瓜书中的编号特征)
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。