# Copyright (c) 2017 Egor Tensin
# This file is part of the "VK scripts" project.
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.

import matplotlib.pyplot as plt
from matplotlib import ticker
import numpy as np


class BarChartBuilder:
    """Thin wrapper around a matplotlib figure for horizontal bar charts.

    Categories are placed along the Y axis, values along the X axis.
    """

    _BAR_HEIGHT = .5

    THICK_BAR_HEIGHT = _BAR_HEIGHT
    THIN_BAR_HEIGHT = THICK_BAR_HEIGHT / 2

    # Right-hand limit of the values axis when there are no bars to plot or
    # when the largest value is (close to) zero.
    _DEFAULT_VALUES_AXIS_MAX = 1
    assert _DEFAULT_VALUES_AXIS_MAX > 0

    def __init__(self, labels_align_middle=True):
        """Create a new figure with a single set of axes.

        labels_align_middle -- if true, category ticks are placed at the
        centers of the bars; otherwise they are shifted down by half of the
        space reserved for a bar.
        """
        figure, axes = plt.subplots()
        self._fig = figure
        self._ax = axes
        self.labels_align_middle = labels_align_middle

    def set_title(self, title):
        """Set the title of the axes."""
        self._ax.set_title(title)

    def _get_categories_axis(self):
        """Return the axis the categories are laid out along (Y)."""
        return self._ax.get_yaxis()

    def _get_values_axis(self):
        """Return the axis the values are laid out along (X)."""
        return self._ax.get_xaxis()

    def set_categories_axis_limits(self, start=None, end=None):
        """Change the Y axis limits; a limit passed as None is kept."""
        current_bottom, current_top = self._ax.get_ylim()
        self._ax.set_ylim(
            bottom=current_bottom if start is None else start,
            top=current_top if end is None else end)

    def set_values_axis_limits(self, start=None, end=None):
        """Change the X axis limits; a limit passed as None is kept."""
        current_left, current_right = self._ax.get_xlim()
        self._ax.set_xlim(
            left=current_left if start is None else start,
            right=current_right if end is None else end)

    def enable_grid_for_categories(self):
        """Call grid() on the categories axis."""
        categories_axis = self._get_categories_axis()
        categories_axis.grid()

    def enable_grid_for_values(self):
        """Call grid() on the values axis."""
        values_axis = self._get_values_axis()
        values_axis.grid()

    def get_categories_labels(self):
        """Return the tick labels of the categories axis."""
        categories_axis = self._get_categories_axis()
        return categories_axis.get_ticklabels()

    def get_values_labels(self):
        """Return the tick labels of the values axis."""
        values_axis = self._get_values_axis()
        return values_axis.get_ticklabels()

    def hide_categories(self):
        """Install a NullLocator as the categories axis major locator."""
        categories_axis = self._get_categories_axis()
        categories_axis.set_major_locator(ticker.NullLocator())

    def set_value_label_formatter(self, fn):
        """Format the values axis major tick labels using fn.

        fn is wrapped into a matplotlib.ticker.FuncFormatter.
        """
        values_axis = self._get_values_axis()
        values_axis.set_major_formatter(ticker.FuncFormatter(fn))

    def any_values(self):
        """Install an AutoLocator as the values axis major locator."""
        values_axis = self._get_values_axis()
        values_axis.set_major_locator(ticker.AutoLocator())

    def only_integer_values(self):
        """Install a MaxNLocator(integer=True) on the values axis."""
        values_axis = self._get_values_axis()
        values_axis.set_major_locator(ticker.MaxNLocator(integer=True))

    @staticmethod
    def set_property(*args, **kwargs):
        """Forward all the arguments to matplotlib.pyplot.setp."""
        plt.setp(*args, **kwargs)

    def _set_size(self, inches, dim=0):
        """Resize the figure along one dimension (0: width, 1: height)."""
        dimensions = self._fig.get_size_inches()
        assert len(dimensions) == 2
        dimensions[dim] = inches
        self._fig.set_size_inches(dimensions, forward=True)

    def set_width(self, inches):
        """Set the figure width, in inches."""
        self._set_size(inches, dim=0)

    def set_height(self, inches):
        """Set the figure height, in inches."""
        self._set_size(inches, dim=1)

    def plot_bars(self, categories, values, bar_height=THICK_BAR_HEIGHT):
        """Plot one horizontal bar per category.

        Each bar gets a slot which is twice as tall as the bar itself; the
        figure height and the categories axis limits are adjusted to fit all
        the slots (a single empty slot is reserved if there are no
        categories).

        Returns an empty list if there are no categories, and the return
        value of Axes.barh otherwise.
        """
        bar_count = len(categories)
        slot_height = 2 * bar_height

        # An empty chart still takes up the space of a single bar.
        categories_top = slot_height * max(bar_count, 1)
        self.set_height(categories_top)
        self.set_categories_axis_limits(0, categories_top)

        if bar_count == 0:
            self.set_values_axis_limits(0, self._DEFAULT_VALUES_AXIS_MAX)
            self.hide_categories()
            return []

        half_slot = slot_height / 2
        bar_centers = slot_height * np.arange(bar_count) + half_slot

        if self.labels_align_middle:
            tick_positions = bar_centers
        else:
            tick_positions = bar_centers - half_slot

        categories_axis = self._get_categories_axis()
        categories_axis.set_ticks(tick_positions)
        categories_axis.set_ticklabels(categories)

        bars = self._ax.barh(bar_centers, values, align='center',
                             height=bar_height)

        lowest = min(values)
        highest = max(values)

        if lowest >= 0:
            # Nothing is negative: start the values axis at zero.
            self.set_values_axis_limits(start=0)
            if np.isclose(highest, 0.):
                # All the values are zero: use the default upper limit.
                self.set_values_axis_limits(end=self._DEFAULT_VALUES_AXIS_MAX)
        elif highest < 0:
            # Everything is negative: end the values axis at zero.
            self.set_values_axis_limits(end=0)

        return bars

    @staticmethod
    def show():
        """Call matplotlib.pyplot.show."""
        plt.show()

    def save(self, path):
        """Save the figure to path using a tight bounding box."""
        self._fig.savefig(path, bbox_inches='tight')


def _main():
    """Plot a bar chart described by the command line arguments."""
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--categories', nargs='*', metavar='LABEL',
                        default=[])
    parser.add_argument('--values', nargs='*', metavar='N',
                        default=[], type=float)

    parser.add_argument('--output', '-o', help='set output file path')

    parser.add_argument('--align-middle', action='store_true',
                        dest='labels_align_middle',
                        help='align labels to the middle of the bars')

    parser.add_argument('--integer-values', action='store_true',
                        dest='only_integer_values')
    parser.add_argument('--any-values', action='store_false',
                        dest='only_integer_values')

    parser.add_argument('--grid-categories', action='store_true')
    parser.add_argument('--grid-values', action='store_true')

    args = parser.parse_args()

    # Categories without a value get a zero-length bar; values without a
    # category are an error.
    missing_values = len(args.categories) - len(args.values)
    if missing_values < 0:
        parser.error('too many bar values')
    if missing_values > 0:
        args.values.extend([0.] * missing_values)

    builder = BarChartBuilder(labels_align_middle=args.labels_align_middle)

    if args.only_integer_values:
        configure_values_axis = builder.only_integer_values
    else:
        configure_values_axis = builder.any_values
    configure_values_axis()

    if args.grid_categories:
        builder.enable_grid_for_categories()
    if args.grid_values:
        builder.enable_grid_for_values()

    builder.plot_bars(args.categories, args.values)

    if args.output is not None:
        builder.save(args.output)
        return
    builder.show()


if __name__ == '__main__':
    _main()