From 0d2dd291bd2c5ee7e42b2bf240742b516ad4dbda Mon Sep 17 00:00:00 2001 From: Wilson Wang Date: Sun, 6 Jun 2021 00:00:39 -0700 Subject: [PATCH] tools: rw-heatmaps output format bug fix --- tools/rw-heatmaps/plot_data.py | 63 ++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/tools/rw-heatmaps/plot_data.py b/tools/rw-heatmaps/plot_data.py index 497088a18..217eb6f8a 100755 --- a/tools/rw-heatmaps/plot_data.py +++ b/tools/rw-heatmaps/plot_data.py @@ -16,7 +16,8 @@ params = None def parse_args(): - parser = argparse.ArgumentParser(description='plot graph using mixed read/write result file.') + parser = argparse.ArgumentParser( + description='plot graph using mixed read/write result file.') parser.add_argument('input_file_a', type=str, help='first input data files in csv format. (required)') parser.add_argument('input_file_b', type=str, nargs='?', @@ -52,7 +53,8 @@ def load_data_files(*args): param_str = '' if len(param_df) != 0: param_str = param_df['comment'].iloc[0] - new_df = df[df['type'] == 'DATA'][['ratio', 'conn_size', 'value_size']].copy() + new_df = df[df['type'] == 'DATA'][[ + 'ratio', 'conn_size', 'value_size']].copy() cols = [x for x in df.columns if x.find('iter') != -1] tmp = [df[df['type'] == 'DATA'][x].str.split(':') for x in cols] @@ -189,7 +191,8 @@ def plot_data(title, plot_type, cmap_name_default, *args): count += 1 plt.subplot(4, 2, count) plt.tripcolor(df['conn_size'], df['value_size'], df[plot_type]) - plt.title('R/W Ratio {:.4f}'.format(val)) + plt.title('R/W Ratio {:.4f} [{:.2f}, {:.2f}]'.format(val, df[plot_type].min(), + df[plot_type].max())) plt.yscale('log', base=2) plt.ylabel('Value Size') plt.xscale('log', base=2) @@ -204,17 +207,17 @@ def plot_data(title, plot_type, cmap_name_default, *args): df1 = args[1]['dataframe'] df1param = args[1]['param'] fig = plt.figure(figsize=fig_size) - count = 0 + col = 0 delta_df = df1.copy() delta_df[[plot_type]] = ((df1[[plot_type]] - df0[[plot_type]]) / - df0[[plot_type]]) * 100 + df0[[plot_type]]) * 100 for tmp in [df0, df1, delta_df]: - count += 1 - count2 = count + row = 0 for val, df in tmp.groupby('ratio'): - plt.subplot(8, 3, count2) + pos = row * 3 + col + 1 + plt.subplot(8, 3, pos) norm = None - if count2 % 3 == 0: + if col == 2: cmap_name = 'bwr' if params.zero: norm = CenteredNorm() @@ -223,38 +226,45 @@ def plot_data(title, plot_type, cmap_name_default, *args): plt.tripcolor(df['conn_size'], df['value_size'], df[plot_type], norm=norm, cmap=plt.get_cmap(cmap_name)) - if count2 == 1: - plt.title('{}\nR/W Ratio {:.4f}'.format( - os.path.basename(params.input_file_a), - val)) - elif count2 == 2: - plt.title('{}\nR/W Ratio {:.4f}'.format( - os.path.basename(params.input_file_b), - val)) - elif count2 == 3: - plt.title('Gain\nR/W Ratio {:.4f} [{:.2f}%, {:.2f}%]'.format(val, df[plot_type].min(), - df[plot_type].max())) + if row == 0: + if col == 0: + plt.title('{}\nR/W Ratio {:.4f} [{:.1f}, {:.1f}]'.format( + os.path.basename(params.input_file_a), + val, df[plot_type].min(), df[plot_type].max())) + elif col == 1: + plt.title('{}\nR/W Ratio {:.4f} [{:.1f}, {:.1f}]'.format( + os.path.basename(params.input_file_b), + val, df[plot_type].min(), df[plot_type].max())) + elif col == 2: + plt.title('Gain\nR/W Ratio {:.4f} [{:.2f}%, {:.2f}%]'.format(val, df[plot_type].min(), + df[plot_type].max())) else: - plt.title('R/W Ratio {:.4f} [{:.2f}%, {:.2f}%]'.format(val, df[plot_type].min(), - df[plot_type].max())) + if col == 2: + plt.title('R/W Ratio {:.4f} [{:.2f}%, {:.2f}%]'.format(val, df[plot_type].min(), + df[plot_type].max())) + else: + plt.title('R/W Ratio {:.4f} [{:.1f}, {:.1f}]'.format(val, df[plot_type].min(), + df[plot_type].max())) plt.yscale('log', base=2) plt.ylabel('Value Size') plt.xscale('log', base=2) plt.xlabel('Connections Amount') - if count2 % 3 == 0: + if col == 2: plt.colorbar(format='%.2f%%') else: plt.colorbar() plt.tight_layout() - count2 += 3 + row += 1 + col += 1 fig.suptitle('{} [{}]\n{} {}\n{} {}'.format( title, plot_type.upper(), os.path.basename(params.input_file_a), df0param, os.path.basename(params.input_file_b), df1param)) else: raise Exception('invalid plot input data') fig.subplots_adjust(top=0.93) - plt.savefig("{}_{}.{}".format(params.output, plot_type, params.format), format=params.format) + plt.savefig("{}_{}.{}".format(params.output, plot_type, + params.format), format=params.format) def main(): @@ -263,7 +273,8 @@ def main(): params = parse_args() result = load_data_files(params.input_file_a, params.input_file_b) for i in [('read', 'viridis'), ('write', 'plasma')]: - plot_data(params.title, i[0], i[1], *result) + plot_type, cmap_name = i + plot_data(params.title, plot_type, cmap_name, *result) if __name__ == '__main__':