aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/vk/utils/bar_chart.py
diff options
context:
space:
mode:
Diffstat (limited to 'vk/utils/bar_chart.py')
-rw-r--r--vk/utils/bar_chart.py182
1 files changed, 182 insertions, 0 deletions
diff --git a/vk/utils/bar_chart.py b/vk/utils/bar_chart.py
new file mode 100644
index 0000000..f051efc
--- /dev/null
+++ b/vk/utils/bar_chart.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2017 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "VK scripts" project.
+# For details, see https://github.com/egor-tensin/vk-scripts.
+# Distributed under the MIT License.
+
+import matplotlib.pyplot as plt
+from matplotlib import ticker
+import numpy as np
+
+
+class BarChartBuilder:
+ _BAR_HEIGHT = .5
+
+ THICK_BAR_HEIGHT = _BAR_HEIGHT
+ THIN_BAR_HEIGHT = THICK_BAR_HEIGHT / 2
+
+ def __init__(self, labels_align_middle=True):
+ self._fig, self._ax = plt.subplots()
+ self.labels_align_middle = labels_align_middle
+
+ def set_title(self, title):
+ self._ax.set_title(title)
+
+ def _get_categories_axis(self):
+ return self._ax.get_yaxis()
+
+ def _get_values_axis(self):
+ return self._ax.get_xaxis()
+
+ def set_categories_axis_limits(self, start=None, end=None):
+ bottom, top = self._ax.get_ylim()
+ if start is not None:
+ bottom = start
+ if end is not None:
+ top = end
+ self._ax.set_ylim(bottom=bottom, top=top)
+
+ def set_values_axis_limits(self, start=None, end=None):
+ left, right = self._ax.get_xlim()
+ if start is not None:
+ left = start
+ if end is not None:
+ right = end
+ self._ax.set_xlim(left=left, right=right)
+
+ def enable_grid_for_categories(self):
+ self._get_categories_axis().grid()
+
+ def enable_grid_for_values(self):
+ self._get_values_axis().grid()
+
+ def get_categories_labels(self):
+ return self._get_categories_axis().get_ticklabels()
+
+ def get_values_labels(self):
+ return self._get_values_axis().get_ticklabels()
+
+ def hide_categories(self):
+ self._get_categories_axis().set_major_locator(ticker.NullLocator())
+
+ def set_value_label_formatter(self, fn):
+ self._get_values_axis().set_major_formatter(ticker.FuncFormatter(fn))
+
+ def any_values(self):
+ self._get_values_axis().set_major_locator(ticker.AutoLocator())
+
+ def only_integer_values(self):
+ self._get_values_axis().set_major_locator(ticker.MaxNLocator(integer=True))
+
+ @staticmethod
+ def set_property(*args, **kwargs):
+ plt.setp(*args, **kwargs)
+
+ def _set_size(self, inches, dim=0):
+ fig_size = self._fig.get_size_inches()
+ assert len(fig_size) == 2
+ fig_size[dim] = inches
+ self._fig.set_size_inches(fig_size, forward=True)
+
+ def set_width(self, inches):
+ self._set_size(inches)
+
+ def set_height(self, inches):
+ self._set_size(inches, dim=1)
+
+ _DEFAULT_VALUES_AXIS_MAX = 1
+ assert _DEFAULT_VALUES_AXIS_MAX > 0
+
+ def plot_bars(self, categories, values, bar_height=THICK_BAR_HEIGHT):
+ numof_bars = len(categories)
+ inches_per_bar = 2 * bar_height
+ categories_axis_max = inches_per_bar * numof_bars
+
+ if not numof_bars:
+ categories_axis_max += inches_per_bar
+
+ self.set_height(categories_axis_max)
+ self.set_categories_axis_limits(0, categories_axis_max)
+
+ if not numof_bars:
+ self.set_values_axis_limits(0, self._DEFAULT_VALUES_AXIS_MAX)
+ self.hide_categories()
+ return []
+
+ bar_offset = inches_per_bar / 2
+ bar_offsets = inches_per_bar * np.arange(numof_bars) + bar_offset
+
+ if self.labels_align_middle:
+ self._get_categories_axis().set_ticks(bar_offsets)
+ else:
+ self._get_categories_axis().set_ticks(bar_offsets - bar_offset)
+
+ self._get_categories_axis().set_ticklabels(categories)
+
+ bars = self._ax.barh(bar_offsets, values, align='center',
+ height=bar_height)
+
+ if min(values) >= 0:
+ self.set_values_axis_limits(start=0)
+ if np.isclose(max(values), 0.):
+ self.set_values_axis_limits(end=self._DEFAULT_VALUES_AXIS_MAX)
+ elif max(values) < 0:
+ self.set_values_axis_limits(end=0)
+
+ return bars
+
+ @staticmethod
+ def show():
+ plt.show()
+
+ def save(self, path):
+ self._fig.savefig(path, bbox_inches='tight')
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--categories', nargs='*', metavar='LABEL',
+ default=[])
+ parser.add_argument('--values', nargs='*', metavar='N',
+ default=[], type=float)
+
+ parser.add_argument('--output', '-o', help='set output file path')
+
+ parser.add_argument('--align-middle', action='store_true',
+ dest='labels_align_middle',
+ help='align labels to the middle of the bars')
+
+ parser.add_argument('--integer-values', action='store_true',
+ dest='only_integer_values')
+ parser.add_argument('--any-values', action='store_false',
+ dest='only_integer_values')
+
+ parser.add_argument('--grid-categories', action='store_true')
+ parser.add_argument('--grid-values', action='store_true')
+
+ args = parser.parse_args()
+
+ if len(args.categories) < len(args.values):
+ parser.error('too many bar values')
+ if len(args.categories) > len(args.values):
+ args.values.extend([0.] * (len(args.categories) - len(args.values)))
+
+ builder = BarChartBuilder(labels_align_middle=args.labels_align_middle)
+
+ if args.only_integer_values:
+ builder.only_integer_values()
+ else:
+ builder.any_values()
+
+ if args.grid_categories:
+ builder.enable_grid_for_categories()
+ if args.grid_values:
+ builder.enable_grid_for_values()
+
+ builder.plot_bars(args.categories, args.values)
+
+ if args.output is None:
+ builder.show()
+ else:
+ builder.save(args.output)