From 82a674e409fce161299efeb43e1176f869af64af Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Fri, 24 Jun 2016 01:54:13 +0300 Subject: major refactoring With the focus on (re)usability. That includes adding separate modules for plotting, input generation and things like that. --- plot.py | 199 +++++++++++++++++++++++++++------------------------------------- 1 file changed, 84 insertions(+), 115 deletions(-) (limited to 'plot.py') diff --git a/plot.py b/plot.py index 6163363..45140f6 100644 --- a/plot.py +++ b/plot.py @@ -2,122 +2,91 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. -from enum import Enum -import gc -import pylab -from time import clock - -class OrderType(Enum): - ASCENDING, RANDOM, DESCENDING = 'ascending', 'random', 'descending' - - def __str__(self): - return self.value - - def get_case(self): - if self is OrderType.ASCENDING: - return 'best' - elif self is OrderType.DESCENDING: - return 'worst' - elif self is OrderType.RANDOM: - return 'average' - else: - raise NotImplementedError( - 'unknown "case" for input ordering: \'{}\''.format(self)) - -def get_timestamp(): - return clock() - -def init_clock(): - get_timestamp() - -def gen_input(order, n): - if order is OrderType.ASCENDING: - return list(range(n)) - elif order is OrderType.DESCENDING: - return sorted(range(n), reverse=True) - elif order is OrderType.RANDOM: - from random import sample - return sample(range(n), n) - else: - raise NotImplementedError( - 'invalid input ordering: \'{}\''.format(order)) - -def measure_running_time(algorithm, order, xs_len, iterations): - xs = gen_input(order, xs_len) - xss = [list(xs) for _ in range(iterations)] - algorithm = algorithm.get_function() - gc.disable() - started_at = get_timestamp() - for i in range(iterations): - algorithm(xss[i]) - finished_at = get_timestamp() - gc.enable() - return finished_at - started_at - -def _decorate_plot(algorithm, iterations, order): - pylab.grid() - pylab.xlabel("Input length") - pylab.ylabel('Running time (sec), {} iteration(s)'.format(iterations)) - pylab.title("{}, {} case".format( - algorithm.get_display_name(), order.get_case())) - -def plot_algorithm(algorithm, iterations, order, min_len, max_len, output_path=None): - _decorate_plot(algorithm, iterations, order) - xs_lengths = range(min_len, max_len + 1) - running_time = [] - for xs_len in xs_lengths: - running_time.append(measure_running_time(algorithm, order, xs_len, iterations)) - pylab.plot(xs_lengths, running_time) - if output_path is None: - pylab.show() - else: - pylab.savefig(output_path) +import argparse +import sys -if __name__ == '__main__': - import algorithms.registry - - def natural_number(s): - n = int(s) - if n < 0: - raise argparse.ArgumentTypeError('cannot be negative') - return n - def positive_number(s): - n = int(s) - if n < 1: - raise argparse.ArgumentTypeError('must be positive') - return n - def input_kind(s): - try: - return OrderType(s) - except ValueError: - raise argparse.ArgumentError() - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--algorithm', '-l', required=True, - choices=algorithms.registry.get_codenames(), - help='specify sorting algorithm to use') - parser.add_argument('--iterations', '-r', - type=positive_number, default=1, +from algorithms.inputgen import InputKind +from algorithms.params import AlgorithmParameters +import algorithms.registry as registry + +_DEFAULT_ITERATIONS = 100 +_DEFAULT_INPUT_KIND = InputKind.AVERAGE +_DEFAULT_MIN_LENGTH = 0 +_DEFAULT_MAX_LENGTH = 200 + +def plot_algorithm(algorithm, input_kind=_DEFAULT_INPUT_KIND, + min_len=_DEFAULT_MIN_LENGTH, + max_len=_DEFAULT_MAX_LENGTH, + iterations=_DEFAULT_ITERATIONS, + output_path=None): + + if isinstance(algorithm, str): + algorithm = registry.get(algorithm) + + params = AlgorithmParameters(algorithm, min_len, max_len, + input_kind=input_kind, + iterations=iterations) + params.plot_running_time(output_path) + +def _parse_natural_number(s): + n = int(s) + if n < 0: + raise argparse.ArgumentTypeError('must not be a negative number') + return n + +def _parse_positive_number(s): + n = int(s) + if n < 1: + raise argparse.ArgumentTypeError('must be positive') + return n + +def _parse_input_kind(s): + try: + return InputKind(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid input_kind: ' + str(s)) + +def _format_algorithm(codename): + return '* {}: {}'.format(codename, registry.get(codename).display_name) + +def _format_available_algorithms(): + descr = 'available algorithms (in the CODENAME: DISPLAY_NAME format):\n' + return descr + '\n'.join(map(_format_algorithm, registry.get_codenames())) + +def _format_description(): + return _format_available_algorithms() + +def _parse_args(args=sys.argv): + parser = argparse.ArgumentParser( + description=_format_description(), + formatter_class=argparse.RawDescriptionHelpFormatter) + + parser.add_argument('algorithm', metavar='CODENAME', + choices=registry.get_codenames(), + help='algorithm codename') + parser.add_argument('--iterations', '-r', metavar='N', + type=_parse_positive_number, + default=_DEFAULT_ITERATIONS, help='set number of algorithm iterations') - parser.add_argument('--order', '-i', - choices=tuple(x for x in OrderType), - type=input_kind, default=OrderType.RANDOM, - help='specify input order') - parser.add_argument('--min', '-a', type=natural_number, - required=True, dest='min_len', + parser.add_argument('--input', '-i', dest='input_kind', + choices=InputKind, + type=_parse_input_kind, default=_DEFAULT_INPUT_KIND, + help='specify input kind') + parser.add_argument('--min', '-a', metavar='N', dest='min_len', + type=_parse_natural_number, + default=_DEFAULT_MIN_LENGTH, help='set min input length') - parser.add_argument('--max', '-b', type=natural_number, - required=True, dest='max_len', + parser.add_argument('--max', '-b', metavar='N', dest='max_len', + type=_parse_natural_number, + default=_DEFAULT_MAX_LENGTH, help='set max input length') - parser.add_argument('--output', '-o', dest='output_path', - help='set plot output path') - args = parser.parse_args() - if args.max_len < args.min_len: - parser.error('max input length cannot be less than min input length') - - init_clock() - plot_algorithm(algorithms.registry.get(args.algorithm), - args.iterations, args.order, - args.min_len, args.max_len, - args.output_path) + parser.add_argument('--output', '-o', metavar='PATH', dest='output_path', + help='set plot file path') + + return parser.parse_args(args[1:]) + +def main(args=sys.argv): + plot_algorithm(**vars(_parse_args(args))) + +if __name__ == '__main__': + main() -- cgit v1.2.3