From dfeff57244b2d25c135a75e326140794dc4417b4 Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Wed, 29 Nov 2023 09:41:37 +0100 Subject: modernize plotting a bit Use the "object-oriented" interface, fix image size, etc. --- algorithms/params.py | 10 ++++------ algorithms/plotter.py | 42 ++++++++++-------------------------------- 2 files changed, 14 insertions(+), 38 deletions(-) (limited to 'algorithms') diff --git a/algorithms/params.py b/algorithms/params.py index 787cfd1..486cb0e 100644 --- a/algorithms/params.py +++ b/algorithms/params.py @@ -132,12 +132,10 @@ class AlgorithmParameters: ys = [y * units.get_factor() for y in ys] plot_builder = PlotBuilder() - plot_builder.show_grid() - plot_builder.set_xlabel(self._format_plot_xlabel()) - plot_builder.set_ylabel(self._format_plot_ylabel(units)) - #plot_builder.set_yticklabels_scientific() - plot_builder.set_title(self._format_plot_title()) - plot_builder.plot(xs, ys) + title = self._format_plot_title() + xlabel = self._format_plot_xlabel() + ylabel = self._format_plot_ylabel(units) + plot_builder.plot(title, xlabel, ylabel, xs, ys) if output_path is None: plot_builder.show() else: diff --git a/algorithms/plotter.py b/algorithms/plotter.py index 47d2929..d3e4ee4 100644 --- a/algorithms/plotter.py +++ b/algorithms/plotter.py @@ -7,41 +7,19 @@ import matplotlib.pyplot as plt class PlotBuilder: - @staticmethod - def set_xlabel(s): - plt.xlabel(s) - - @staticmethod - def set_ylabel(s): - plt.ylabel(s) + def __init__(self): + self._fig, self._ax = plt.subplots(figsize=(8, 6), dpi=200) + self._ax.grid(alpha=0.8, linestyle=':') - @staticmethod - def set_yticklabels_scientific(): - plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) - - @staticmethod - def show_grid(): - plt.grid() - - @staticmethod - def set_title(s): - plt.title(s) - - @staticmethod - def set_suptitle(s): - plt.suptitle(s) - - @staticmethod - def plot(xs, ys): - plt.plot(xs, ys) + def plot(self, title, xlabel, ylabel, xs, ys): + self._ax.set_title(title) + self._ax.set_xlabel(xlabel) + self._ax.set_ylabel(ylabel) + self._ax.plot(xs, ys) @staticmethod def show(): plt.show() - @staticmethod - def save(output_path, tight=False): - if tight: - plt.savefig(output_path, bbox_inches='tight') - else: - plt.savefig(output_path) + def save(self, output_path): + self._fig.savefig(output_path) -- cgit v1.2.3