#!/usr/bin/python3
# _*_coding:utf-8 _*_
# @Time :2021/2/21 23:14
# @Author :jory.d
# @File :roc_auc.py
# @Software :PyCharm
# @Desc: 绘制多分类的ROC AUC曲线
import matplotlib as mpl
# mpl.use('Agg') # Agg TkAgg
import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from sklearn.preprocessing import label_binarize
import random
from pprint import pprint
# Print numpy floats with 2 decimals in this file's debug output.
# NOTE(review): np.set_printoptions changes numpy's print options for the
# whole process as soon as this module is imported, not only for this file.
# Consider moving it under the `if __name__ == '__main__':` guard if other
# modules import this one.
np.set_printoptions(precision=2)
def get_other_metrics(label_names, y_trues, y_probs):
    """
    Compute per-class Precision, Recall and F1 from predicted scores.

    The predicted class of each sample is the argmax of its score row.

    :param label_names: list of class names; its length defines the number of
        classes, and class index ``i`` corresponds to ``label_names[i]``.
    :param y_trues: list of ground-truth class indices, one per sample
        (assumed to be ints in ``0..len(label_names)-1``, as elsewhere in
        this file).
    :param y_probs: list of per-sample score rows, one score per class.
    :return: tuple ``(Precision, Recall, F1_Score)``; each is a numpy array of
        length ``len(label_names)``, ordered like ``label_names``.
    """
    assert type(label_names) is list
    assert type(y_trues) is list
    assert type(y_probs) is list
    assert len(y_trues) == len(y_probs)
    y_true = np.array(y_trues)
    y_prob = np.array(y_probs)
    y_pred = np.argmax(y_prob, axis=-1)
    # Pin the label set explicitly. With average=None, sklearn otherwise only
    # reports labels present in y_true/y_pred. A class that never occurs would
    # be dropped and the remaining scores would no longer line up with
    # label_names.
    class_ids = list(range(len(label_names)))
    Precision = metrics.precision_score(y_true, y_pred, labels=class_ids, average=None)
    Recall = metrics.recall_score(y_true, y_pred, labels=class_ids, average=None)
    F1_Score = metrics.f1_score(y_true, y_pred, labels=class_ids, average=None)
    return Precision, Recall, F1_Score
def get_cmap(N):
    '''
    Return a function that maps each index in 0, 1, ..., N-1 to a distinct
    RGBA color sampled evenly from matplotlib's 'hsv' colormap.

    :param N: number of distinct colors needed.
    :return: callable ``f(index)`` returning the RGBA tuple produced by
        ``ScalarMappable.to_rgba``.
    '''
    import matplotlib.cm as cmx
    import matplotlib.colors as colors

    # 'hsv' is cyclic: its colors at 0.0 and 1.0 are both red. Normalizing
    # over [0, N] instead of [0, N-1] keeps index N-1 strictly below 1.0, so
    # the first and last indices no longer share a color. max(..., 1) avoids
    # a degenerate/negative range when N <= 0.
    color_norm = colors.Normalize(vmin=0, vmax=max(N, 1))
    scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')

    def map_index_to_rgb_color(index):
        return scalar_map.to_rgba(index)

    return map_index_to_rgb_color
def create_roc_auc(label_names, y_trues, y_probs, png_save_path, is_show=True):
    """
    Compute one-vs-rest ROC curves and AUCs with sklearn and plot them in one figure.

    :param label_names: list of class names; class index ``i`` is ``label_names[i]``.
    :param y_trues: list of ground-truth class indices, one per sample
        (assumed to be ints in ``0..len(label_names)-1``).
    :param y_probs: list of per-sample score rows, one score per class.
    :param png_save_path: path the figure is written to (PNG format).
    :param is_show: if True, also display the figure with ``plt.show()``.
    :return: the created matplotlib Figure.
    """
    assert type(label_names) is list
    assert type(y_trues) is list
    assert type(y_probs) is list
    assert len(y_trues) == len(y_probs)
    n_classes = len(label_names)
    y_true = np.array(y_trues)
    y_prob = np.array(y_probs)
    # One-hot encode the ground truth, one column per class. This replaces
    # label_binarize(y_true, np.arange(n_classes)). Newer sklearn makes
    # `classes` keyword-only, so that call raises TypeError. Also, for two
    # classes label_binarize returns a single column, which made
    # y_true_one_hot[:, 1] fail.
    y_true_one_hot = (y_true.reshape(-1, 1) == np.arange(n_classes)).astype(int)

    # Compute ROC curve and ROC area for each class
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = metrics.roc_curve(y_true_one_hot[:, i], y_prob[:, i])
        roc_auc[i] = metrics.auc(fpr[i], tpr[i])
    pprint(fpr)
    pprint(tpr)
    print('AUC: {}'.format(roc_auc))

    mpl.rcParams['font.sans-serif'] = u'DejaVu Sans'  # DejaVu Sans / SimHei
    mpl.rcParams['axes.unicode_minus'] = False
    fig = plt.figure()
    # Fixed palette for small class counts. 'w' (white) is omitted because a
    # white curve is invisible on the default white background. Beyond the
    # palette size, fall back to evenly spaced hsv colors.
    color = ('b', 'g', 'r', 'c', 'm', 'y', 'k')
    cmap = get_cmap(n_classes)
    for i in range(n_classes):
        # FPR on the x axis, TPR on the y axis.
        _col = cmap(i) if n_classes > len(color) else color[i]
        plt.plot(fpr[i], tpr[i], c=_col, lw=2, alpha=0.7,
                 label=u'%s AUC=%.3f' % (label_names[i], roc_auc[i]))
    # Chance-level diagonal.
    plt.plot((0, 1), (0, 1), c='#808080', lw=1, ls='--', alpha=0.7)
    plt.xlim((-0.01, 1.02))
    plt.ylim((-0.01, 1.02))
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    # Pass the flag positionally: the `b=` keyword was renamed to `visible`
    # in matplotlib 3.5 and later removed.
    plt.grid(True, ls=':')
    plt.legend(loc='lower right', fancybox=True, framealpha=0.8, fontsize=12)
    plt.title(u'ROC curve', fontsize=17)
    plt.savefig(png_save_path, format='png')
    if is_show:
        plt.show()
    return fig
def create_roc_self(label_names, y_trues, y_probs, png_save_path, is_show=True):
    """
    Compute per-class TPR and FPR with numpy at thresholds 0.1, 0.2, ..., 1.0
    and plot both against the threshold, one subplot per class.

    The per-class curves are meant to help choose a decision threshold for
    each class during post-processing.

    :param label_names: list of class names; class index ``i`` is ``label_names[i]``.
    :param y_trues: list of ground-truth class indices, one per sample
        (assumed to be ints in ``0..len(label_names)-1``).
    :param y_probs: list of per-sample score rows, one score per class.
    :param png_save_path: path the figure is written to (PNG format).
    :param is_show: if True, also display the figure with ``plt.show()``.
    :return: the created matplotlib Figure.
    """
    assert type(label_names) is list
    assert type(y_trues) is list
    assert type(y_probs) is list
    assert len(y_trues) == len(y_probs)
    n_classes = len(label_names)
    y_trues = np.array(y_trues)
    y_probs = np.array(y_probs)
    # One-hot encode the ground truth, one column per class. This replaces
    # label_binarize(y_trues, np.arange(n_classes)). Newer sklearn makes
    # `classes` keyword-only (TypeError), and for two classes label_binarize
    # returns a single column (IndexError on [:, 1]).
    y_trues_one_hot = (y_trues.reshape(-1, 1) == np.arange(n_classes)).astype(int)
    print(y_trues)
    print(y_trues_one_hot)

    tpr_dict, fpr_dict = {}, {}
    thresh = [i / 10 for i in range(1, 11)]
    for i in range(n_classes):
        tpr_dict[i] = []
        fpr_dict[i] = []
        y_true = y_trues_one_hot[:, i]  # [n,] binary ground truth for class i
        y_pred_prob = y_probs[:, i]     # [n,] score for class i
        # TPR and FPR at each threshold 0.1 ~ 1.0:
        # tpr = tp / (tp + fn), fpr = fp / (tn + fp).
        # The 1e-5 avoids division by zero when a class has no positives or
        # no negatives.
        for th in thresh:
            y_pred2 = np.where(y_pred_prob >= th, 1, 0)
            tp = np.sum(y_pred2[y_true == 1] == 1)
            fn = np.sum(y_pred2[y_true == 1] == 0)
            fp = np.sum(y_pred2[y_true == 0] == 1)
            tn = np.sum(y_pred2[y_true == 0] == 0)
            tpr = tp / (tp + fn + 1e-5)
            fpr = fp / (tn + fp + 1e-5)
            print(f'thres={th}, tpr={tpr}, fpr={fpr}')
            tpr_dict[i].append(round(tpr, 2))
            fpr_dict[i].append(round(fpr, 2))
    pprint('tpr: {}'.format(tpr_dict))
    pprint('fpr: {}'.format(fpr_dict))

    cols = 2
    # Ceiling division. round(n_classes / cols) under-allocated rows
    # (round(0.5) == 0, round(2.5) == 2 with banker's rounding), which
    # silently dropped classes from the figure.
    rows = max(1, (n_classes + cols - 1) // cols)
    fig = plt.figure(figsize=(12, 12), dpi=150)
    fig.suptitle('per class tpr and fpr', fontsize='xx-large')
    ax = None
    for class_idx in range(n_classes):
        ax = fig.add_subplot(rows, cols, class_idx + 1)
        ax.set_title(str(label_names[class_idx]))
        ax.plot(thresh, tpr_dict[class_idx], c='b', label='tpr')
        ax.plot(thresh, fpr_dict[class_idx], c='r', label='fpr')
        ax.set_xlabel('thres', fontsize='x-large')
        ax.set_ylabel('tpr_fpr', fontsize='x-large')
        ax.set_xticks(np.arange(0, 1.1, 0.2))
        ax.set_yticks(np.arange(0, 1.1, 0.2))
    # All subplots share the same two series, so one figure-level legend
    # taken from the last axes is enough. Skip it when no axes were drawn.
    if ax is not None:
        handles, labels = ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower right', fontsize='x-large')
    plt.savefig(png_save_path, format='png')
    if is_show:
        plt.show()
    return fig
# Demo: compute fpr / tpr for each class on random data.
if __name__ == '__main__':
    # Seed only when run as a script. Seeding at module level reset numpy's
    # global RNG for every module that imported this file.
    np.random.seed(888)
    labels = ['A', 'B', 'C']
    batch_size = 100
    # Random ground-truth class indices and per-class scores.
    y_true = np.random.randint(0, len(labels), [batch_size]).tolist()
    y_prob = np.random.random([batch_size, len(labels)]).tolist()
    # Uncomment to also draw the sklearn-based ROC curves:
    # _ = create_roc_auc(labels, y_true, y_prob, './ss1.png')
    _ = create_roc_self(labels, y_true, y_prob, './ss2.png')