机器学习KNN算法——鸢尾花案例
本文最后更新于-1天前,其中的信息可能已经过时,如有错误请发送邮件到2392862431@qq.com

代码

这里的代码来自黑马b站课程,我这里也只是记录一下,方便我后续自己复习,我也是借助AI帮我解释了。

def knn_iris_gscv():
    """
    用KNN算法对鸢尾花进行分类,添加网格搜索和交叉验证
    :return:
    """
    # 1)获取数据
    iris = load_iris()

    # 2)划分数据集
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)

    # 3)特征工程:标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)

    # 4)KNN算法预估器
    estimator = KNeighborsClassifier()

    # 加入网格搜索与交叉验证
    # 参数准备
    param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}
    estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10) # 验证集
    estimator.fit(x_train, y_train)

    # 5)模型评估
    # 方法1:直接比对真实值和预测值
    y_predict = estimator.predict(x_test)
    print("y_predict:\n", y_predict)
    print("直接比对真实值和预测值:\n", y_test == y_predict)

    # 方法2:计算准确率
    score = estimator.score(x_test, y_test)
    print("准确率为:\n", score)

    # 最佳参数:bestparams
    print("最佳参数:\n", estimator.bestparams)
    # 最佳结果:bestscore
    print("最佳结果:\n", estimator.bestscore)
    # 最佳估计器:bestestimator
    print("最佳估计器:\n", estimator.bestestimator)
    # 交叉验证结果:cvresults
    print("交叉验证结果:\n", estimator.cvresults)

    return None

整体思路

这个代码就像是在训练一个”植物专家”,让它学会根据花的特征来判断是哪种鸢尾花。

逐步解析

1. 获取数据 – “准备学习材料”

iris = load_iris()

就像给学生准备教科书一样,这里加载了经典的鸢尾花数据集。这个数据集包含150朵花的信息,每朵花有4个特征(花瓣长度、宽度等)和1个标签(花的种类)。

2. 划分数据集 – “分配练习题和考试题”

x_train, x_test, y_train, y_test = train_test_split(...)

就像学习时要把题目分成”练习题”和”考试题”:

  • 练习题(训练集):让模型学习用的
  • 考试题(测试集):最后检验模型学得怎么样

3. 特征工程 – “统一度量标准”

transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

想象你要比较身高和体重,一个是厘米,一个是公斤,单位不同没法直接比较。标准化就是把所有特征都转换成相同的尺度,让模型能公平地对待每个特征。

4. KNN算法 – “找最相似的邻居”

estimator = KNeighborsClassifier()

KNN就像是”问邻居”的方法:

  • 遇到一朵新花,就找最相似的K个邻居
  • 看这些邻居大多数是什么类型,新花就判断为什么类型
  • 比如K=3,找到3个最相似的花,如果2个是A类,1个是B类,那新花就是A类

5. 网格搜索 – “找最佳设置”

param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)

就像调试相机参数一样:

  • 我们不知道K值(邻居数量)设多少最好
  • 所以把可能的值都试一遍:1个邻居、3个邻居、5个邻居…
  • 看哪个设置拍出的照片最清楚(准确率最高)

6. 交叉验证 – “多次考试求平均”

cv=10

就像一个学生考试:

  • 不能只考一次就说他水平如何
  • 要考10次,取平均分才靠谱
  • 交叉验证就是把训练数据分成10份,轮流做9份训练、1份验证

7. 模型评估 – “检验学习效果”

y_predict = estimator.predict(x_test)
score = estimator.score(x_test, y_test)

就像期末考试:

  • 用从没见过的测试数据来检验模型
  • 看预测结果和真实答案有多少一致
  • 计算准确率

核心原理

KNN的工作原理

  1. 计算新样本与所有训练样本的距离
  2. 找出距离最近的K个邻居
  3. 统计这K个邻居中哪个类别最多
  4. 把新样本分类为最多的那个类别

为什么要网格搜索

  • K值太小:容易受噪声影响(比如K=1,只看1个邻居)
  • K值太大:可能忽略局部特征
  • 网格搜索自动找到最合适的K值

为什么要交叉验证

  • 避免”运气好”的情况
  • 确保模型的稳定性
  • 更准确地评估模型性能

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇