【笔记】混淆矩阵,精准率和召回率

作者:神秘网友 发布时间:2021-01-25 13:38:05

【笔记】混淆矩阵,精准率和召回率

混淆矩阵,精准率和召回率

评论回归算法的好坏点击这里

评价分类算法是不能单单靠一个分类准确度就可以衡量的,单用一个分类准确度是有问题的

比如说,一个癌症预测系统,输入体检信息,就可以判断是否得了癌症,这个系统的预测准确率有99.9%,但是不能说这个系统就是好的,因为如果患有癌症的概率是0.1%,那么即使预测所有人都是健康的,也可以达到99.9%的准确率,这样就发现,这个系统一点用没有

这种情况可以称为数据极度偏斜,对于极度偏斜的数据,使用分类准确度来评定的话,可以发现其准确度是非常高的,但是有可能其实算法是不太行的,因此对于这种数据,只使用分类准确度是远远不够的

可以使用比较基础的混淆矩阵来做进一步的分析

混淆矩阵

对于二分类问题,混淆矩阵实际上就是一个2*2的矩阵,其还需添加一行一列作为内容的标记,其中行代表真实值,列代表预测值,行相当于是这个二维数组的第一个维度,列相当于第二个维度,一般设为0和1,其中0和1的意思看分析的问题对应设置,设0位阴性,1位阳性,则在0,0的位置为预测阴性正确TN,0,1的位置为预测阳性错误FP,1,0的位置为预测阴性错误FN,1,1的位置为预测阴性正确TP

这就是混淆矩阵,是在分类任务中的一个重要的工具,可以更好的的得到分类算法的好坏,其中,有两个很重要的通过混淆矩阵才能得到的指标,精准率和召回率

精准率和召回率

精准率的公式,其就是预测为1且预测正确的概率

召回率的公式,其就是真实为1中的预测为1的比例,即真实的发生的事件中成功预测的概率

为什么说精准率和召回率是比分类准确度更好的指标,因为对于一些没有意义的模型可以很好的分辨出来

那么可以具体实现一下

实现混淆矩阵,精准率和召回率

(在notebook中)

使用手写识别数据集,由于需要使用的极度偏斜的数据,那么就需要设置内容的条件,即9的时候为1,非9位0,然后对数据集进行分割

  import numpy as np
  from sklearn import datasets

  digits = datasets.load_digits()
  X = digits.data
  y = digits.target.copy()

  y[digits.target==9] = 1
  y[digits.target!=9] = 0

  from sklearn.model_selection import train_test_split
  X_train,X_test,y_train,y_test =      train_test_split(X,y,random_state=666)

使用逻辑回归,并看一下表现如何

  from sklearn.linear_model import LogisticRegression

  log_reg = LogisticRegression()
  log_reg.fit(X_train,y_train)
  log_reg.score(X_test,y_test)

结果如下

由于是极度偏斜的数据,所以要考察一下其他的性能指标,首先得到预测值以后,就开始求TN,FP,TP,FP,求解代码如下

  y_log_predict = log_reg.predict(X_test)

  def TN(y_true,y_predict):
      assert len(y_true) == len(y_predict)
      return np.sum((y_true == 0)&(y_predict == 0))

  TN(y_test,y_log_predict)

结果如下

  def FP(y_true,y_predict):
      assert len(y_true) == len(y_predict)
      return np.sum((y_true == 0)&(y_predict == 1))

  FP(y_test,y_log_predict)

结果如下

  def FN(y_true,y_predict):
      assert len(y_true) == len(y_predict)
      return np.sum((y_true == 1)&(y_predict == 0))

  FN(y_test,y_log_predict)

结果如下

  def TP(y_true,y_predict):
      assert len(y_true) == len(y_predict)
      return np.sum((y_true == 1)&(y_predict == 1))

  TP(y_test,y_log_predict)

结果如下

可以尝试得出混淆矩阵的内容

  def confusion_matrix(y_true,y_predict):
      return np.array([
          [TN(y_true,y_predict),FP(y_true,y_predict)],
          [FN(y_true,y_predict),TP(y_true,y_predict)]
      ])

  confusion_matrix(y_test,y_log_predict)

结果如下

精准率的求解代码,使用先前的公式即可

  def precision_score(y_true,y_predict):
      tp = TP(y_true,y_predict)
      fp = FP(y_true,y_predict)
      try:
          return tp / (tp+fp)
      except:
          return 0.0

  precision_score(y_test,y_log_predict)

结果如下

召回率的求解代码,使用先前的公式即可

  def recall_score(y_true,y_predict):
      tp = TP(y_true,y_predict)
      fn = FN(y_true,y_predict)
      try:
          return tp / (tp+fn)
      except:
          return 0.0

  recall_score(y_test,y_log_predict)

结果如下

在sklearn中的混淆矩阵以及精准率和召回率

使用confusion_matrix即可调用出sklearn中的混淆矩阵,使用和上面一样

  from sklearn.metrics import confusion_matrix

  confusion_matrix(y_test,y_log_predict)

结果如下

使用precision_score即可调用出sklearn中的精准率,使用和上面一样

  from sklearn.metrics import precision_score

  precision_score(y_test,y_log_predict)

结果如下

使用recall_score即可调用出sklearn中的召回率,使用和上面一样

  from sklearn.metrics import recall_score

  recall_score(y_test,y_log_predict)

结果如下

以上就是混淆矩阵,精准率以及召回率的概念公式和实现的过程以及sklearn中的类的调用

【笔记】混淆矩阵,精准率和召回率 相关文章

  1. 替罪羊树学习笔记

    Part 0 引子 我们都知道,有一种东西叫 BST。 我们都知道,BST 在极限数据会卡爆。 我们都知道,为了让 BST 不被卡,有很多种平衡树。 但你知道有一种平衡树好写速度快吗那就是替罪羊树。 Part 1 替罪羊树平衡的原理 替罪羊树是一种平衡树,一种平衡的 BST。

  2. Python深度学习笔记08--处理文本数据的常用方法

    6.1 处理文本数据 6.1.1 单词和字符的one-hot编码 (1)单词级的one-hot编码: 1 # 单词级的one-hot编码 2 import numpy as np 3 4 # 初始数据:每个样本是列表的一个元素(本例中的样本是一个句子,但也可以是一整篇文档) 5 samples = ['The cat sat on the ma

  3. SAS初学者笔记---004---循环结构与判断结构

    关于循环与判断的语句在所有程序设计中十分重要,在SAS程序中也不例外。逻辑清晰的循环与判断结构是日后进行数据清洗、数据构造的必要前提。(反正就是很重要就对了) 循环结构 关于循环结构,常见的有三种类型 DO Index. 索引循环,此语句是DO循环语句中最

  4. Numpy中的矩阵和向量

    1. 使用Numpy构造矩阵 例如:[[1,2,3], [4,5,6]] 我们可以这样做: A = np.array([[1,2,3],[4,5,6]]) 2. 构造向量,向量可以分为行向量和列向量 构建列向量: B = np.array([[2],[1],[3]]) 使用这个方法可以将其转置为行向量 B = np.transpose(np.array([[2,1

  5. JUnit学习笔记

    junit学习笔记 不影响原有类的情况下,生成测试类。 IDEA下载插件 插件名字为JUnitGeneratorV2.0 方法 Setting选项中的Plugins里搜索下载 Setting选项中的JUnnit Generator中的Output Path配置测试类存放位置 ${SOURCEPATH}/../test/${PACKAGE}/${FILENAME}

  6. 对于线性回归通俗理解的笔记

    经常听说线性回归(Linear Regression) 到底什么才是线性,什么才是回归 有学者说,线性回归模型是一切模型之母。所以,我们的机器学习之旅,也将从这个模型开始! 建立回归模型的好处:随便给一个x,就能通过模型算出y,这个y可能和实际值不一样,这个y是

  7. RabbitMQ消息中间件(第二章)第二部分-笔记-快速搭建与控制台介绍

    消息生产与消费 ConnectionFactory: 获取连接工厂 Connection:一个连接 Channel:数据通信信道,可发送和接收消息 Queue:具体的消息存储队列 Producer Consumer 生产和消费者 代码演示 引入maven依赖 dependency groupIdcom.rabbitmq/groupId artifactIdamqp

  8. RabbitMQ消息中间件(第二章)第一部分-笔记

    本章导航 互联网大厂为什么选择RabbitMQ? RabbitMQ的高性能之道是如何做到? 什么是AMQP高级协议? AMQP核心概念是什么? RabbitMQ整体架构模型是什么样子的? RabbitMQ消息是如何流转的? RabbitMQ安装与使用 命令行与管控台 RabbitMQ消息生产与消费 Rabbit

  9. RabbitMQ消息中间件(第二章)第四部分-笔记

    Binging-绑定 Exchange和Exchange、Queue之间的连接关系 Binging可以包含RoutingKey或者参数 Queue-消息队列 消息队列,实际存储消息数据 Durability:是否持久化,Durable:是,Transient:否 Auto delete:如选yes,代表最后一个监听被移除之后,该Queue会

  10. vue2源码-响应式处理(学习笔记)-2

    回顾vue的使用 index.html文件 div id="app"{{name}}/div!-- 对数据进行渲染 -- script src="./dist/vue.js"/script !-- 引入vue(自己准备手写的) -- script //viewModel 数据模型 //典型的MVVM View vm model let vm = new Vue({//vue的使用首先要挂载 //

每天更新java,php,javaScript,go,python,nodejs,vue,android,mysql等相关技术教程,教程由网友分享而来,欢迎大家分享IT技术教程到本站,帮助自己同时也帮助他人!

Copyright 2020, All Rights Reserved. Powered by 跳墙网(www.tqwba.com)|网站地图|关键词