实验内容简介
使用 ID3 算法对数据进行决策树的构建实现分类
算法说明
ID3 算法本质上是使用信息增益作为树分区划分的属性度量条件,信息增益指在系统中有无该属性时带来的信息量的变化,具体计算为系统熵减去某个属性的条件熵。如果某个属性具有最高的信息增益,那么就选择它作为分裂特征。然后递归这个过程,从而生成决策树。
以上面的数据为例。如果我们想通过询问 age/income/student/credits_rating 来确定 buy_computer 的结果,我们首先算出 buy_computer 的信息熵,然后分别计算固定 age 等特征下的条件熵(分条件下的信息熵的加权平均)。通过总的系统熵减去条件熵计算出最大的特征作为这一次的决策特征,如果可以分类出来结果了,就直接放上叶子节点,如果没有的话就递归下去继续分类;如果分不出来结果了,属性也用完了,这个时候一定是出现两个以上其他属性一模一样而分类的变量不一样的情况,取最多的那种。
算法分析与设计
这个文件实现了 ID3 算法(Iterative Dichotomiser 3),用于生成决策树。在该实现中,算法的核心思想是通过计算信息增益(Information Gain),选择最优特征逐步分割数据集,最终生成分类决策树。
以下是对代码中各个部分的分析:
1. load_data_from_file
函数
功能:
- 从 JSON 文件中加载数据集,返回字典形式的数据。
- 这是实现的输入接口,方便用户使用外部文件提供数据。
代码分析:
def load_data_from_file(file: str) -> dict:
with open(file, 'r') as f:
data = json.load(f)
return data
- 通过
json.load
加载数据,文件格式需严格符合 JSON 规范。 - 输出是一个包含数据集的 Python 字典。
2. entropy
函数
功能:
- 计算数据集的熵,用于衡量数据的混乱程度或不确定性。
- 支持两种情况:整个数据集的熵(无条件)和按条件划分后的加权熵。
代码分析:
def entropy(data: list, target: str, condition: str = None) -> float:
if condition:
values = set([d[condition] for d in data])
ent = 0
for v in values:
sub_data = [d for d in data if d[condition] == v]
ent += len(sub_data) / len(data) * entropy(sub_data, target)
return ent
else:
values = set([d[target] for d in data])
ent = 0
for v in values:
p = len([d for d in data if d[target] == v]) / len(data)
ent -= p * math.log2(p)
return ent
-
无条件熵:
- 根据目标值的分布计算整体熵。
- 使用公式 。
-
条件熵:
- 按条件特征(
condition
)对数据分组,然后计算加权熵:
- 按条件特征(
3. select_best_feature
函数
功能:
- 根据信息增益公式,从候选特征中选择最优特征。
- 信息增益公式:。
代码分析:
def select_best_feature(data: list, target: str, features: list) -> str:
ent = entropy(data, target)
gains = {f: ent - entropy(data, target, f) for f in features}
return max(gains, key=gains.get)
- 遍历所有特征,计算每个特征的条件熵和信息增益。
- 返回信息增益最高的特征。
4. build_tree
函数
功能:
- 构建决策树的主入口,使用递归方法逐步添加决策节点。
代码分析:
def build_tree(data: list, target: str, features: list) -> anytree.Node:
root = anytree.Node("root")
add_decision_node(root, data, target, features)
return root
- 决策树的根节点初始化为
"root"
。 - 调用辅助函数
add_decision_node
添加分支和叶子节点。
5. add_decision_node
函数
功能:
- 递归地将节点添加到决策树中,完成分裂或分类。
代码分析:
def add_decision_node(node: anytree.Node, data: list, target: str, features: list):
# if no data left
if not data:
anytree.Node("Unknown", parent=node)
return
# if all data has the same target value
target_list = [d[target] for d in data]
target_set = set(target_list)
if len(target_set) == 1:
anytree.Node(target_list[0], parent=node)
return
# if there is no feature left but the target values are not the same
if len(features) == 0:
# choose the most common target value
target_value = max(set(target_list), key=target_list.count)
anytree.Node(target_value, parent=node)
return
# choose the best feature, node name is the feature_name : feature_value
local_features = features[:]
best_feature = select_best_feature(data, target, local_features)
local_features.remove(best_feature)
def group_by(data, feature):
grouped = defaultdict(list)
for d in data:
grouped[d[feature]].append(d)
return grouped
feature_groups = group_by(data, best_feature)
for feature_value, sub_data in feature_groups.items():
child = anytree.Node(f"{best_feature} : {feature_value}", parent=node)
add_decision_node(child, sub_data, target, local_features)
- 终止条件:
- 数据集中所有目标值相同:创建叶子节点并返回。
- 特征耗尽但目标值不同:选择目标值中出现次数最多的作为叶子节点。
- 递归构建:
- 选择最佳特征,按该特征的每个取值分组。
- 为每组创建子节点并递归调用。
主要算法逻辑
- 计算数据的无条件熵。
- 遍历所有特征,计算条件熵与信息增益。
- 按信息增益选择分裂特征,并对数据按特征值分组。
- 对每个分组递归进行分裂,直到满足终止条件。
测试结果
测试代码
from id3 import *
def test_load_data_features():
table = load_data_from_file('data.json')
assert table['features'] == ["age", "income", "student", "credit_rating"]
data = table['data']
assert data[0]["age"] == "youth"
assert table["target"] == "buys_computer"
def test_entropy():
table = load_data_from_file('data.json')
data = table['data']
target = table['target']
assert abs(entropy(data=data, target=target) - 0.94) < 0.01
def test_entropy_condition():
table = load_data_from_file('data.json')
data = table['data']
target = table['target']
condition = 'age'
assert abs(entropy(data=data, target=target,
condition=condition) - 0.69) < 0.01
def test_select_best_feature():
table = load_data_from_file('data.json')
data = table['data']
target = table['target']
features = table['features']
assert select_best_feature(data, target, features) == 'age'
def test_build_tree():
table = load_data_from_file('data.json')
data = table['data']
target = table['target']
features = table['features']
tree = build_tree(data, target, features)
print(anytree.RenderTree(tree).by_attr())
assert tree.name == 'root'
def get_children(node):
return set([child.name for child in node.children])
assert get_children(tree) == {'age : youth',
'age : middle_aged', 'age : senior'}
➜ id3 pytest
============== test session starts ===============
platform darwin -- Python 3.9.6, pytest-8.3.3, pluggy-1.5.0
rootdir: /Users/cyril/Dev/school/algo/id3
collected 5 items
test_id3.py ..... [100%]
=============== 5 passed in 0.01s ================
➜ id3 python -u id3.py
root
├── age : middle_aged
│ └── yes
├── age : youth
│ ├── student : yes
│ │ └── yes
│ └── student : no
│ └── no
└── age : senior
├── credit_rating : fair
│ └── yes
└── credit_rating : excellent
└── no
分析与探讨
核心技术细节
-
ID3 算法
- 该实现基于信息增益准则,逐层递归地选择最优特征。
- 信息增益的计算与理论一致。
-
递归构建树
- 函数式编程风格下的递归调用简洁清晰。
- 子节点的分组和递归传递特征列表,避免了全局污染。
-
异常情况处理
- 当特征耗尽但目标值不一致时,通过选择最常见的目标值作为结果,保证树的生成。
- 对空数据集或无法确定的子节点赋予默认值(
"Unknown"
),避免报错。
-
可视化
- 借助
anytree
实现了简洁的树形结构渲染,便于调试和结果展示。
- 借助
可改进点
-
效率优化
- 当前代码在计算熵和分组时存在重复操作,可通过缓存子集结果优化效率。
- 可利用 NumPy 等库加速熵和分组计算。
-
功能扩展
- 支持连续变量的处理(如通过信息增益率计算分裂点)。
- 引入剪枝机制(如预剪枝或后剪枝)避免过拟合。
-
输出结果改进
- 决策树的结果可进一步序列化(如导出为 JSON 格式)以便于部署。
-
异常处理
- 增加对文件格式异常的检查,防止 JSON 格式错误导致程序崩溃。
源代码
https://github.com/Satar07/id3
# python implementation of ID3 algorithm
from collections import defaultdict
import json
import math
import anytree # pip install anytree
def load_data_from_file(file: str) -> dict:
with open(file, 'r') as f:
data = json.load(f)
return data
def entropy(data: list, target: str, condition: str = None) -> float:
if condition:
values = set([d[condition] for d in data])
ent = 0
for v in values:
sub_data = [d for d in data if d[condition] == v]
ent += len(sub_data) / len(data) * entropy(sub_data, target)
return ent
else:
values = set([d[target] for d in data])
ent = 0
for v in values:
p = len([d for d in data if d[target] == v]) / len(data)
ent -= p * math.log2(p)
return ent
def select_best_feature(data: list, target: str, features: list) -> str:
ent = entropy(data, target)
gains = {f: ent - entropy(data, target, f) for f in features}
return max(gains, key=gains.get)
def build_tree(data: list, target: str, features: list) -> anytree.Node:
root = anytree.Node("root")
add_decision_node(root, data, target, features)
return root
def add_decision_node(node: anytree.Node, data: list, target: str, features: list):
# if no data left
if not data:
anytree.Node("Unknown", parent=node)
return
# if all data has the same target value
target_list = [d[target] for d in data]
target_set = set(target_list)
if len(target_set) == 1:
anytree.Node(target_list[0], parent=node)
return
# if there is no feature left but the target values are not the same
if len(features) == 0:
# choose the most common target value
target_value = max(set(target_list), key=target_list.count)
anytree.Node(target_value, parent=node)
return
# choose the best feature, node name is the feature_name : feature_value
local_features = features[:]
best_feature = select_best_feature(data, target, local_features)
local_features.remove(best_feature)
def group_by(data, feature):
grouped = defaultdict(list)
for d in data:
grouped[d[feature]].append(d)
return grouped
feature_groups = group_by(data, best_feature)
for feature_value, sub_data in feature_groups.items():
child = anytree.Node(f"{best_feature} : {feature_value}", parent=node)
add_decision_node(child, sub_data, target, local_features)
if __name__ == '__main__':
table = load_data_from_file('data.json')
data = table['data']
target = table['target']
features = table['features']
tree = build_tree(data, target, features)
print(anytree.RenderTree(tree).by_attr())