本文最后更新于-1天前,其中的信息可能已经过时,如有错误请发送邮件到2392862431@qq.com
代码
这里的代码来自黑马b站课程,我这里也只是记录一下,方便我后续自己复习,我也是借助AI帮我解释了。
def knn_iris_gscv():
"""
用KNN算法对鸢尾花进行分类,添加网格搜索和交叉验证
:return:
"""
# 1)获取数据
iris = load_iris()
# 2)划分数据集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)
# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)
# 4)KNN算法预估器
estimator = KNeighborsClassifier()
# 加入网格搜索与交叉验证
# 参数准备
param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10) # 验证集
estimator.fit(x_train, y_train)
# 5)模型评估
# 方法1:直接比对真实值和预测值
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)
# 方法2:计算准确率
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)
# 最佳参数:bestparams
print("最佳参数:\n", estimator.bestparams)
# 最佳结果:bestscore
print("最佳结果:\n", estimator.bestscore)
# 最佳估计器:bestestimator
print("最佳估计器:\n", estimator.bestestimator)
# 交叉验证结果:cvresults
print("交叉验证结果:\n", estimator.cvresults)
return None
整体思路
这个代码就像是在训练一个”植物专家”,让它学会根据花的特征来判断是哪种鸢尾花。
逐步解析
1. 获取数据 – “准备学习材料”
iris = load_iris()
就像给学生准备教科书一样,这里加载了经典的鸢尾花数据集。这个数据集包含150朵花的信息,每朵花有4个特征(花瓣长度、宽度等)和1个标签(花的种类)。
2. 划分数据集 – “分配练习题和考试题”
x_train, x_test, y_train, y_test = train_test_split(...)
就像学习时要把题目分成”练习题”和”考试题”:
- 练习题(训练集):让模型学习用的
- 考试题(测试集):最后检验模型学得怎么样
3. 特征工程 – “统一度量标准”
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)
想象你要比较身高和体重,一个是厘米,一个是公斤,单位不同没法直接比较。标准化就是把所有特征都转换成相同的尺度,让模型能公平地对待每个特征。
4. KNN算法 – “找最相似的邻居”
estimator = KNeighborsClassifier()
KNN就像是”问邻居”的方法:
- 遇到一朵新花,就找最相似的K个邻居
- 看这些邻居大多数是什么类型,新花就判断为什么类型
- 比如K=3,找到3个最相似的花,如果2个是A类,1个是B类,那新花就是A类
5. 网格搜索 – “找最佳设置”
param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)
就像调试相机参数一样:
- 我们不知道K值(邻居数量)设多少最好
- 所以把可能的值都试一遍:1个邻居、3个邻居、5个邻居…
- 看哪个设置拍出的照片最清楚(准确率最高)
6. 交叉验证 – “多次考试求平均”
cv=10
就像一个学生考试:
- 不能只考一次就说他水平如何
- 要考10次,取平均分才靠谱
- 交叉验证就是把训练数据分成10份,轮流做9份训练、1份验证
7. 模型评估 – “检验学习效果”
y_predict = estimator.predict(x_test)
score = estimator.score(x_test, y_test)
就像期末考试:
- 用从没见过的测试数据来检验模型
- 看预测结果和真实答案有多少一致
- 计算准确率
核心原理
KNN的工作原理:
- 计算新样本与所有训练样本的距离
- 找出距离最近的K个邻居
- 统计这K个邻居中哪个类别最多
- 把新样本分类为最多的那个类别
为什么要网格搜索:
- K值太小:容易受噪声影响(比如K=1,只看1个邻居)
- K值太大:可能忽略局部特征
- 网格搜索自动找到最合适的K值
为什么要交叉验证:
- 避免”运气好”的情况
- 确保模型的稳定性
- 更准确地评估模型性能
