From 622206cacb166e8ac65b05733582fd404042cc0b Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Tue, 8 Mar 2016 17:43:55 +0300 Subject: refactoring --- plot.py | 121 +++++++++++++++++++++++++++++++++++++++------------------------- 1 file changed, 73 insertions(+), 48 deletions(-) (limited to 'plot.py') diff --git a/plot.py b/plot.py index 33743aa..99452c4 100644 --- a/plot.py +++ b/plot.py @@ -2,52 +2,77 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. -from time import clock +from enum import Enum import gc +from time import clock + +class InputKind(Enum): + SORTED, RANDOMIZED, REVERSED = 'sorted', 'randomized', 'reversed' + + def __str__(self): + return self.value + +class SortingAlgorithm(Enum): + BUBBLE = 'bubble' + BUBBLE_OPTIMIZED = 'bubble_optimized' + HEAP = 'heap' + INSERTION = 'insertion' + MERGE = 'merge' + QUICK_FIRST = 'quick_first' + QUICK_SECOND = 'quick_second' + QUICK_MIDDLE = 'quick_middle' + QUICK_LAST = 'quick_last' + QUICK_RANDOM = 'quick_random' + SELECTION = 'selection' -_ALGORITHMS = ( - 'bubble', - 'bubble_optimized', - 'heap', - 'insertion', - 'merge', - 'quick_first', - 'quick_second', - 'quick_middle', - 'quick_last', - 'quick_random', - 'selection', -) + def __str__(self): + return self.value def _get_context(): + def natural_number(s): + n = int(s) + if n < 0: + raise argparse.ArgumentTypeError('cannot be negative') + return n + def positive_number(s): + n = int(s) + if n < 1: + raise argparse.ArgumentTypeError('must be positive') + return n + def input_kind(s): + try: + return InputKind(s) + except ValueError: + raise argparse.ArgumentError() + def sorting_algorithm(s): + try: + return SortingAlgorithm(s) + except ValueError: + raise argparse.ArgumentError() import argparse parser = argparse.ArgumentParser() - parser.add_argument('--repetitions', '-r', type=int, default=1, + parser.add_argument('--repetitions', '-r', + type=positive_number, default=1, help='set number of sorting repetitions') parser.add_argument('--input', '-i', - default='randomized', metavar='INPUT', - choices=('sorted', 'randomized', 'reversed'), + choices=tuple(x for x in InputKind), + type=input_kind, default=InputKind.RANDOMIZED, help='choose initial input state') - parser.add_argument('--algorithm', '-l', metavar='ALGORITHM', - choices=_ALGORITHMS, required=True, + parser.add_argument('--algorithm', '-l', required=True, + choices=tuple(x for x in SortingAlgorithm), + type=sorting_algorithm, help='select sorting algorithm to use') - parser.add_argument('--min', '-a', type=int, required=True, - help='set min input length', - dest='min_input_length') - parser.add_argument('--max', '-b', type=int, required=True, - help='set max input length', - dest='max_input_length') + parser.add_argument('--min', '-a', type=natural_number, + required=True, dest='min_input_length', + help='set min input length') + parser.add_argument('--max', '-b', type=natural_number, + required=True, dest='max_input_length', + help='set max input length') parser.add_argument('--output', '-o', dest='plot_path', help='set plot output path') args = parser.parse_args() - if args.repetitions < 1: - parser.error('number of repetitions must be > 0') - if args.min_input_length < 0: - parser.error('min sequence length must be >= 0') - if args.max_input_length < 0: - parser.error('max sequence length must be >= 0') if args.max_input_length < args.min_input_length: - parser.error('max sequence length cannot be less than min sequence length') + parser.error('max input length cannot be less than min input length') return args def get_timestamp(): @@ -57,16 +82,16 @@ def init_clock(): get_timestamp() def gen_input(args, n): - if args.input == 'sorted': + if args.input is InputKind.SORTED: return list(range(n)) - elif args.input == 'reversed': + elif args.input is InputKind.REVERSED: return sorted(range(n), reverse=True) - elif args.input == 'randomized': + elif args.input is InputKind.RANDOMIZED: from random import sample return sample(range(n), n) else: raise NotImplementedError( - 'unimplemented initial input state \'{}\''.format(args.input)) + 'invalid initial input state \'{}\''.format(args.input)) def measure_running_time(ctx, algorithm, input_length): xs = gen_input(ctx, input_length) @@ -102,42 +127,42 @@ def plot_algorithm(ctx, algorithm): plt.show() def plot(ctx): - if ctx.algorithm == 'bubble': + if ctx.algorithm is SortingAlgorithm.BUBBLE: from bubble_sort import bubble_sort plot_algorithm(ctx, bubble_sort) - elif ctx.algorithm == 'bubble_optimized': + elif ctx.algorithm is SortingAlgorithm.BUBBLE_OPTIMIZED: from bubble_sort import bubble_sort_optimized plot_algorithm(ctx, bubble_sort_optimized) - elif ctx.algorithm == 'heap': + elif ctx.algorithm is SortingAlgorithm.HEAP: from heapsort import heapsort plot_algorithm(ctx, heapsort) - elif ctx.algorithm == 'insertion': + elif ctx.algorithm is SortingAlgorithm.INSERTION: from insertion_sort import insertion_sort plot_algorithm(ctx, insertion_sort) - elif ctx.algorithm == 'merge': + elif ctx.algorithm is SortingAlgorithm.MERGE: from merge_sort import merge_sort plot_algorithm(ctx, merge_sort) - elif ctx.algorithm == 'quick_first': + elif ctx.algorithm is SortingAlgorithm.QUICK_FIRST: from quicksort import quicksort_first plot_algorithm(ctx, quicksort_first) - elif ctx.algorithm == 'quick_second': + elif ctx.algorithm is SortingAlgorithm.QUICK_SECOND: from quicksort import quicksort_second plot_algorithm(ctx, quicksort_second) - elif ctx.algorithm == 'quick_middle': + elif ctx.algorithm is SortingAlgorithm.QUICK_MIDDLE: from quicksort import quicksort_middle plot_algorithm(ctx, quicksort_middle) - elif ctx.algorithm == 'quick_last': + elif ctx.algorithm is SortingAlgorithm.QUICK_LAST: from quicksort import quicksort_last plot_algorithm(ctx, quicksort_last) - elif ctx.algorithm == 'quick_random': + elif ctx.algorithm is SortingAlgorithm.QUICK_RANDOM: from quicksort import quicksort_random plot_algorithm(ctx, quicksort_random) - elif ctx.algorithm == 'selection': + elif ctx.algorithm is SortingAlgorithm.SELECTION: from selection_sort import selection_sort plot_algorithm(ctx, selection_sort) else: raise NotImplementedError( - 'unknown algorithm \'{}\''.format(ctx.algorithm)) + 'invalid sorting algorithm \'{}\''.format(ctx.algorithm)) if __name__ == '__main__': ctx = _get_context() -- cgit v1.2.3