| """ |
| Validation process |
| """ |
| import sys |
| import json |
| import pandas as pd |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import mlflow |
| from matplotlib import rcParams |
| from tableone import TableOne |
|
|
|
|
| |
# Global matplotlib defaults for every figure logged by this module:
# wide panels with the top/right spines hidden.
rcParams['figure.figsize'] = 20, 5
rcParams['axes.spines.top'] = False
rcParams['axes.spines.right'] = False
|
|
|
|
def plot_cluster_size(df, data_type):
    """
    Log a horizontal bar chart of the number of patients per cluster.

    The figure is logged to the active MLflow run under ``fig/``.
    -------
    :param df: dataframe with a 'cluster' column
    :param data_type: type of data - train, test, val, rec, sup
    """
    fig, ax = plt.subplots()
    cluster_sizes = df.groupby('cluster').size()
    cluster_sizes.plot(ax=ax, kind='barh')
    title = "Patient Cohorts"
    ax.set_title(title)
    ax.set_xlabel("Number of Patients", size=20)
    ax.set_ylabel("Cluster")
    plt.tight_layout()
    figure_path = 'fig/' + title.replace(' ', '_') + '_' + data_type + '.png'
    mlflow.log_figure(fig, figure_path)
|
|
|
|
def plot_feature_hist(df, col, data_type):
    """
    Log an overlaid per-cluster histogram of one feature.

    Each cluster's values are drawn as a semi-transparent histogram on a
    shared axis and the figure is logged to the active MLflow run.
    -------
    :param df: dataframe with the feature and a 'cluster' column
    :param col: feature column to plot
    :param data_type: type of data - train, test, val, rec, sup
    """
    fig, ax = plt.subplots()
    grouped = df.groupby('cluster')[col]
    grouped.plot(ax=ax, kind='hist', alpha=0.5)
    ax.set_xlabel(col)
    title = col + ' Histogram'
    ax.set_title(title, size=20)
    ax.legend()
    plt.tight_layout()
    figure_path = 'fig/' + title.replace(' ', '_') + '_' + data_type + '.png'
    mlflow.log_figure(fig, figure_path)
|
|
|
|
def plot_feature_bar(data, col, typ, data_type):
    """
    Produce a horizontal bar plot of one categorical feature per cluster,
    either as raw counts or as within-cluster percentages, and log it to
    the active MLflow run.
    --------
    :param data: dataframe with the feature and a 'cluster' column
    :param col: feature column to plot
    :param typ: 'count' for raw counts; any other value plots percentages
    :param data_type: type of data - train, test, val, rec, sup
    """
    if typ == 'count':
        # number of patients per (cluster, category) pair
        to_plot = data.groupby(['cluster']).apply(
            lambda x: x.groupby(col).size())
        x_label = "Number"
    else:
        # category share within each cluster, in percent
        to_plot = data.groupby(['cluster']).apply(
            lambda x: 100 * x.groupby(col).size() / len(x))
        x_label = "Percentage"
    fig, ax = plt.subplots()
    to_plot.plot(ax=ax, kind='barh')
    title = "Patient Cohorts"
    ax.set_title(title, size=20)
    ax.set_xlabel(x_label + " of patients")
    ax.set_ylabel("Cluster")
    plt.tight_layout()
    mlflow.log_figure(fig, 'fig/' + '_'.join((title.replace(' ', '_'), col, data_type + '.png')))
|
|
|
|
def plot_cluster_bar(data, typ, data_type):
    """
    Produce a vertical bar plot of pre-computed per-cluster percentages
    (y-axis fixed to 0-100) and log it to the active MLflow run.
    --------
    :param data: per-cluster percentage data, indexed by cluster
    :param typ: short label used as the plot title and in the file name
                (e.g. 'events', 'survival', 'therapies')
    :param data_type: type of data - train, test, val, rec, sup
    """
    fig, ax = plt.subplots()
    data.plot(ax=ax, kind='bar')
    ax.set_title(typ, size=20)
    ax.set_xlabel("Cluster")
    ax.set_ylabel("Percentage")
    # values are percentages, so pin the axis to the full 0-100 range
    ax.set_ylim(0, 100)
    plt.tight_layout()
    mlflow.log_figure(fig, 'fig/' + typ + '_' + data_type + '.png')
|
|
|
|
def plot_events(df, data_type):
    """
    Plot the percentage of patients per cluster experiencing each event
    in the next 12 months, based on the binary event metric table.
    -------
    :param df: metric table with 'SafeHavenID', 'cluster' and 0/1 event columns
    :param data_type: type of data - train, test, val, rec, sup
    """
    metrics = df.drop('SafeHavenID', axis=1).set_index('cluster')
    # for each cluster, percentage of rows flagged 1 in each event column
    events = metrics.groupby('cluster').apply(
        lambda grp: 100 * grp.apply(lambda col: len(col[col == 1]) / len(col)))
    plot_cluster_bar(events, 'events', data_type)
|
|
|
|
def process_deceased_metrics(col):
    """
    Compute the percentage of patients alive vs deceased within 12 months.

    ``col`` holds time-to-death buckets as strings ('1', '3', '6', '12',
    '12+'), where '12+' means the patient survived beyond the 12-month
    window; any other bucket is a death within it.

    Bug fix: the previous ``col < '12+'`` compared the buckets
    lexicographically, so '3' and '6' (which sort AFTER '12+') were never
    counted as deceased. Comparing against the survivor label directly
    counts every non-'12+' bucket.
    -------
    :param col: Series of time-to-death bucket strings
    :return: single-row DataFrame with 'alive' and 'deceased' percentages
    """
    n_deceased = 100 * (col != '12+').sum() / len(col)
    res = pd.DataFrame({'alive': [100 - n_deceased], 'deceased': [n_deceased]})
    return res
|
|
|
|
def plot_deceased(df, data_type):
    """
    Plot the percentage of patients alive vs deceased within 12 months
    for each cluster, derived from the time-to-death metric table.
    -------
    :param df: metric table with 'cluster' and 'time_to_death' columns
    :param data_type: type of data - train, test, val, rec, sup
    """
    per_cluster = df.groupby('cluster')['time_to_death'].apply(
        process_deceased_metrics)
    # the apply produces a spurious inner index level; drop it and
    # re-index by cluster for plotting
    survival = per_cluster.reset_index().drop(
        'level_1', axis=1).set_index('cluster')
    plot_cluster_bar(survival, 'survival', data_type)
|
|
|
|
def plot_therapies(df_year, results, data_type):
    """
    Plot the percentage of patients on each inhaler therapy per cluster,
    plus the percentage on no inhaler at all.
    -------
    :param df_year: unscaled data for the current year
    :param results: cluster results and safehaven id
    :param data_type: type of data - train, test, val, rec, sup
    """
    inhaler_cols = ['single_inhaler', 'double_inhaler', 'triple_inhaler']

    # attach cluster labels to the inhaler columns on patient id
    merged = pd.merge(df_year[['SafeHavenID'] + inhaler_cols], results,
                      on='SafeHavenID', how='inner')

    # % of each cluster with at least one prescription of each type
    inhals = merged[['cluster'] + inhaler_cols].set_index('cluster')
    pct_on_therapy = inhals.groupby('cluster').apply(
        lambda grp: grp.apply(lambda col: 100 * (col[col > 0].count()) / len(col)))

    # % of each cluster with zeros across every inhaler column
    pct_no_inhaler = merged.groupby('cluster').apply(
        lambda grp: 100 * len(grp[(grp[inhaler_cols] == 0).all(axis=1)]) / len(grp)).values

    # shorten labels: 'single_inhaler' -> 'single', etc.
    pct_on_therapy.columns = pct_on_therapy.columns.str.split('_').str[0]

    pct_on_therapy['no_inhaler'] = pct_no_inhaler

    plot_cluster_bar(pct_on_therapy, 'therapies', data_type)
|
|
|
|
def main():
    """
    Validation entry point: load the clustered data for one run, log
    TableOne summaries and per-cluster plots to an existing MLflow run.

    Command line: <data_type> <run_name> <run_id>
    """
    # Paths come from a repo-level JSON config (relative to this script).
    with open('../../../config.json') as json_config_file:
        config = json.load(json_config_file)
    data_path = config['model_data_path']

    # data_type selects the split (train, test, val, rec, sup); run_name /
    # run_id identify the artifacts written by the clustering step.
    data_type = sys.argv[1]
    run_name = sys.argv[2]
    run_id = sys.argv[3]

    # Re-attach to the existing MLflow run so figures and artifacts are
    # logged alongside the model that produced the clusters.
    model_type = 'hierarchical'
    experiment_name = 'Model E - Date Specific: ' + model_type
    mlflow.set_tracking_uri('http://127.0.0.1:5000/')
    mlflow.set_experiment(experiment_name)
    mlflow.start_run(run_id=run_id)

    # Reduced frame keeps only the clustering feature columns + labels.
    columns = np.load(data_path + run_name + '_cols.npy', allow_pickle=True)
    df_clusters = pd.read_pickle(data_path + "_".join((run_name, data_type, 'clusters.pkl')))
    df_reduced = df_clusters[list(columns) + ['cluster']]

    plot_cluster_size(df_reduced, data_type)

    # Per-cluster descriptive statistics, logged as an HTML artifact.
    t1_year = TableOne(df_reduced, categorical=[], groupby='cluster', pval=True)
    t1yr_file = data_path + 't1_year_' + run_name + '_' + data_type + '.html'
    t1_year.to_html(t1yr_file)
    mlflow.log_artifact(t1yr_file)

    plot_feature_hist(df_clusters, 'age', data_type)
    plot_feature_hist(df_clusters, 'albumin_med_2yr', data_type)

    # NOTE(review): 'percent' / 'precent' below — plot_feature_bar treats any
    # value other than 'count' as the percentage branch, so the typo is
    # harmless, but 'percentage' is presumably what was intended.
    df_clusters['sex'] = df_clusters['sex_bin'].map({0: 'Male', 1: 'Female'})
    plot_feature_bar(df_clusters, 'sex', 'percent', data_type)
    plot_feature_bar(df_clusters, 'simd_decile', 'precent', data_type)

    # Outcome metric tables: binary events, event counts, time-to-next-event.
    df_events = pd.read_pickle(data_path + 'metric_table_events.pkl')
    df_counts = pd.read_pickle(data_path + 'metric_table_counts.pkl')
    df_next = pd.read_pickle(data_path + 'metric_table_next.pkl')

    # Left-join on patient id; patients with no recorded metrics get 0
    # events / counts, and '12+' (no event within 12 months) for time-to-next.
    clusters = df_clusters[['SafeHavenID', 'cluster']]
    df_events = clusters.merge(df_events, on='SafeHavenID', how='left').fillna(0)
    df_counts = clusters.merge(df_counts, on='SafeHavenID', how='left').fillna(0)
    df_next = clusters.merge(df_next, on='SafeHavenID', how='left').fillna('12+')

    # Binary event table: report category 1 (event occurred) first.
    cat_cols = df_events.columns[2:]
    df_events[cat_cols] = df_events[cat_cols].astype('int')
    event_limit = dict(zip(cat_cols, 5 * [1]))
    event_order = dict(zip(cat_cols, 5 * [[1, 0]]))
    t1_events = TableOne(df_events[df_events.columns[1:]], groupby='cluster',
                         limit=event_limit, order=event_order)
    t1_events_file = data_path + '_'.join(('t1', data_type, 'events', run_name + '.html'))
    t1_events.to_html(t1_events_file)
    mlflow.log_artifact(t1_events_file)

    # Event counts summarised as continuous variables (categorical=[]).
    count_cols = df_counts.columns[2:]
    df_counts[count_cols] = df_counts[count_cols].astype('int')
    t1_counts = TableOne(df_counts[df_counts.columns[1:]], categorical=[], groupby='cluster')
    t1_counts_file = data_path + '_'.join(('t1', data_type, 'counts', run_name + '.html'))
    t1_counts.to_html(t1_counts_file)
    mlflow.log_artifact(t1_counts_file)

    # Time-to-next-event buckets reported in chronological order.
    next_cols = df_next.columns[2:]
    next_event_order = dict(zip(next_cols, 5 * [['1', '3', '6', '12', '12+']]))
    t1_next = TableOne(df_next[df_next.columns[1:]], groupby='cluster',
                       order=next_event_order)
    t1_next_file = data_path + '_'.join(('t1', data_type, 'next', run_name + '.html'))
    t1_next.to_html(t1_next_file)
    mlflow.log_artifact(t1_next_file)

    plot_events(df_events, data_type)
    plot_deceased(df_next, data_type)
    plot_therapies(df_clusters, clusters, data_type)

    mlflow.end_run()
|
|
|
|
# Guard the entry point so importing this module (e.g. for testing or reuse
# of the plotting helpers) does not trigger the full validation run.
if __name__ == '__main__':
    main()
|
|