# Source code for crosspredict.report_binary._report_binary

from typing import List, Tuple, Dict
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
from ._curves import PrecisionRecallCurve, RocAucCurve, MeanProbCurve, GenGINICurve, DistributionCurve


class ReportBinary:
    """
    Makes report for binary classification problem

    Curves are precalculated by :meth:`fit` for every (score column, target
    column) pair and stored in ``self.stats``; :meth:`plot_report` then draws
    the requested subset of them onto a ``subplot2grid`` layout of the current
    matplotlib figure.

    :param cols_score: List[str]
        List of column names with model probabilities
    :param cols_target: List[str]
        List of column names with true binary labels
    :param col_generation_apps: str
        Column name with month of event date (for PSI calculation, does not
        need true labels in all rows). In this class it is only checked for
        presence in the DataFrame passed to ``fit``.
    :param col_generation_deals: str
        Column name with month of event date (for metric calculation, only
        for data with true labels). Required for the ``'gen-gini'`` report.
    """

    # Reports drawn on ONE shared axis for all (score, target) pairs.
    # The value is used both as the plot title and inside the assertion text.
    _SINGLE_AXIS_REPORTS = {
        'roc-auc': 'Roc-Auc',
        'precision-recall': 'Precision-Recall',
        'gen-gini': 'GINI by generations',
    }
    # Reports that need one axis PER (score, target) pair.
    # The value is the name used inside the assertion text.
    _PER_PAIR_REPORTS = {
        'mean-prob': 'Mean-Probability',
        'distribution': 'Distribution',
    }

    def __init__(self,
                 cols_score: List[str],
                 cols_target: List[str],
                 col_generation_apps: str = None,
                 col_generation_deals: str = None
                 ):
        self._report_shape = None
        self._report = None
        self._product_plots = None
        self.fig = None
        self._col_generation_deals = col_generation_deals
        self._col_generation_apps = col_generation_apps
        self._cols_score = cols_score
        self._cols_target = cols_target

        # One row of `stats` per (score, target) pair; the curve columns are
        # added by `fit`.
        index = pd.MultiIndex.from_tuples(
            list(product(self._cols_score, self._cols_target)),
            names=['cols_score', 'cols_target'])
        self.stats = pd.DataFrame(index=index)

    def fit(self, df: pd.DataFrame) -> 'ReportBinary':
        """
        Precalculates metrics and statistics for given pd.DataFrame

        Rows with a null in any score column are dropped for every pair; rows
        with a null target are additionally dropped per target column.

        :param df: pd.DataFrame
            Must contain all `cols_score`, `cols_target` and, when they were
            passed to the constructor, the generation columns.
        :return: self class
        :raises AssertionError: if a required column is missing
        """
        missing_score = [c for c in self._cols_score if c not in df.columns]
        assert not missing_score, \
            f"DataFrame does not have columns cols_score={missing_score}"
        missing_target = [c for c in self._cols_target if c not in df.columns]
        assert not missing_target, \
            f"DataFrame does not have columns cols_target={missing_target}"
        if self._col_generation_deals:
            assert self._col_generation_deals in df.columns, \
                f"DataFrame does not have column col_generation_deals=\"{self._col_generation_deals}\""
        if self._col_generation_apps:
            assert self._col_generation_apps in df.columns, \
                f"DataFrame does not have column _col_generation_apps=\"{self._col_generation_apps}\""

        mask_score = ~(df[self._cols_score].isnull().any(axis=1))
        for col_target in self._cols_target:
            mask_target = ~df[col_target].isnull()
            df_deal = df[mask_score & mask_target]
            for col_score in self._cols_score:
                key = (col_score, col_target)
                self.stats.at[key, 'distribution'] = DistributionCurve(
                    col_score, col_target).fit(df_deal)
                # GINI by generation needs a generation column to group by.
                if self._col_generation_deals is not None:
                    self.stats.at[key, 'gen-gini'] = GenGINICurve(
                        col_score, col_target,
                        self._col_generation_deals).fit(df_deal)
                self.stats.at[key, 'mean-prob'] = MeanProbCurve(
                    col_score, col_target).fit(df_deal)
                self.stats.at[key, 'roc-auc'] = RocAucCurve(
                    col_score, col_target).fit(df_deal)
                self.stats.at[key, 'precision-recall'] = PrecisionRecallCurve(
                    col_score, col_target).fit(df_deal)
        return self

    @staticmethod
    def _as_location_list(locations) -> list:
        """Normalise a location spec (one kwargs dict, or a list/tuple of
        them) to a list of kwargs dicts."""
        if isinstance(locations, (tuple, list)):
            return list(locations)
        return [locations]

    def _draw_template(self):
        """
        Draws the template of the plots.

        Sizes the current figure from ``self._report_shape`` and creates the
        axes for every entry of ``self._report``. Afterwards ``self._report``
        maps each report name to its axes: a ``{'ax': ...}`` dict for
        single-axis reports (plus ``'ax_twinx'`` for ``'gen-gini'``), a list
        of such dicts for per-pair reports, and an empty list for unknown
        report names.

        The axes are collected in a new dict, so the layout dict passed by
        the caller is left untouched and can be reused for another figure.
        """
        height = self._report_shape[0]
        width = self._report_shape[1]
        grid = (height, width)

        self.fig = plt.gcf()
        self.fig.set_size_inches(width * 6, height * 5)
        plt.subplots_adjust(left=0.125, right=0.9, bottom=0.1, top=0.95,
                            wspace=0.35, hspace=0.6)

        axes = {}
        for report, locations in self._report.items():
            locs = self._as_location_list(locations)

            if report in self._SINGLE_AXIS_REPORTS:
                if report == 'gen-gini':
                    assert self._col_generation_deals is not None, \
                        f'To plot GINI by generations you need to pass `col_generation_deals`'
                label = self._SINGLE_AXIS_REPORTS[report]
                assert len(locs) == 1, \
                    f'Location of {label} plot should have `len`==1, passed `len`={len(locs)}'
                ax = plt.subplot2grid(grid, **locs[0])
                entry = {'ax': ax}
                if report == 'gen-gini':
                    entry['ax_twinx'] = ax.twinx()
                axes[report] = entry

            elif report in self._PER_PAIR_REPORTS:
                label = self._PER_PAIR_REPORTS[report]
                # Checked before creating axes so a bad layout leaves no
                # stray subplots on the figure.
                assert len(locs) == self._product_plots, \
                    f'Location of {label} plot should have `len`=={self._product_plots}, passed `len`={len(locs)}'
                axes[report] = [{'ax': plt.subplot2grid(grid, **loc)}
                                for loc in locs]

            else:
                # Unknown report names get no axes and are skipped when
                # plotting.
                axes[report] = []

        self._report = axes

    def plot_report(self,
                    report_shape: List[int],
                    report: Dict = None,
                    cols_score: List[str] = None,
                    cols_target: List[str] = None):
        """
        Plots report of given configuration.

        :param report_shape: List[int]
            Shape of subplot axes. Read more https://matplotlib.org/3.1.1/gallery/userdemo/demo_gridspec01.html#sphx-glr-gallery-userdemo-demo-gridspec01-py
        :param report: Dict
            Dict with reports and their location. Read more https://matplotlib.org/3.1.1/gallery/userdemo/demo_gridspec01.html#sphx-glr-gallery-userdemo-demo-gridspec01-py
            Keys are report names (``'roc-auc'``, ``'precision-recall'``,
            ``'gen-gini'``, ``'mean-prob'``, ``'distribution'``); values are
            ``plt.subplot2grid`` keyword dicts, or a list/tuple of them.
            Required; the dict itself is not modified.
        :param cols_score: List[str]
            SubList of column names with model probabilities
        :param cols_target: List[str]
            SubList of column names with true binary labels
        :return: self class
        :raises AssertionError: on unknown columns, a missing `report`, or a
            layout whose number of locations does not fit the report
        """
        if cols_score is None:
            cols_score = self._cols_score
        else:
            cols = [col for col in cols_score if col not in self._cols_score]
            assert len(cols) == 0, \
                f"Columns {cols} did not calculated in `fit` method. `cols_score`={self._cols_score} were passed in ReportBinary class"

        if cols_target is None:
            cols_target = self._cols_target
        else:
            cols = [col for col in cols_target if col not in self._cols_target]
            assert len(cols) == 0, \
                f"Columns {cols} did not calculated in `fit` method. `cols_target`={self._cols_target} were passed in ReportBinary class"

        # Fail before the current figure is touched instead of crashing later
        # on `None.items()`.
        assert report is not None, \
            "`report` must be a dict mapping report names to subplot locations, passed `report`=None"

        self._product_plots = len(cols_score) * len(cols_target)
        self._report_shape = report_shape
        self._report = report
        self._draw_template()  # replaces self._report with the created axes

        pairs = list(product(cols_score, cols_target))
        for report_name, axes in self._report.items():
            if report_name in self._SINGLE_AXIS_REPORTS:
                # Every pair is drawn on the same axis.
                title = self._SINGLE_AXIS_REPORTS[report_name]
                extra = {}
                if report_name == 'gen-gini':
                    extra['ax_twinx'] = axes['ax_twinx']
                for pair in pairs:
                    self.stats.at[pair, report_name].plot(
                        axes['ax'], title=title, **extra)
            elif report_name == 'mean-prob':
                for i, (col_score, col_target) in enumerate(pairs):
                    self.stats.at[(col_score, col_target), report_name].plot(
                        axes[i]['ax'], title=col_score + ' | ' + col_target)
            elif report_name == 'distribution':
                for i, pair in enumerate(pairs):
                    self.stats.at[pair, report_name].plot(
                        axes[i]['ax'], title='Predict Distribution')
        return self