From 82a674e409fce161299efeb43e1176f869af64af Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Fri, 24 Jun 2016 01:54:13 +0300 Subject: major refactoring With the focus on (re)usability. That includes adding separate modules for plotting, input generation and things like that. --- algorithms/algorithm.py | 26 +++-- algorithms/impl/__init__.py | 26 +++-- algorithms/impl/bubble_sort.py | 28 ++++-- algorithms/impl/heapsort.py | 23 +++-- algorithms/impl/insertion_sort.py | 23 +++-- algorithms/impl/median.py | 32 +++--- algorithms/impl/merge_sort.py | 23 +++-- algorithms/impl/quicksort.py | 31 +++--- algorithms/impl/selection_sort.py | 23 +++-- algorithms/inputgen.py | 30 ++++++ algorithms/params.py | 110 +++++++++++++++++++++ algorithms/plotter.py | 38 ++++++++ algorithms/registry.py | 12 +-- algorithms/timer.py | 23 +++++ plot.bat | 15 +-- plot.py | 199 ++++++++++++++++---------------------- test.py | 109 ++++++++++++--------- 17 files changed, 504 insertions(+), 267 deletions(-) create mode 100644 algorithms/inputgen.py create mode 100644 algorithms/params.py create mode 100644 algorithms/plotter.py create mode 100644 algorithms/timer.py diff --git a/algorithms/algorithm.py b/algorithms/algorithm.py index 322a51f..b65ed04 100644 --- a/algorithms/algorithm.py +++ b/algorithms/algorithm.py @@ -2,24 +2,22 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. +from . import inputgen + class Algorithm: def __init__(self, codename, display_name, f): - self._codename = codename - self._display_name = display_name - self._f = f - - def get_codename(self): - return self._codename - - def get_display_name(self): - return self._display_name + self.codename = codename + self.display_name = display_name + self.function = f - def get_function(self): - return self._f - - def __str__(self): - return self.get_display_name() + @staticmethod + def gen_input(n, case=inputgen.InputKind.AVERAGE): + #raise NotImplementedError('inputgen generation is not defined for generic algorithms') + return inputgen.gen_input_for_sorting(n, case) class SortingAlgorithm(Algorithm): def __init__(self, codename, display_name, f): super().__init__(codename, display_name, f) + + def gen_input(self, n, case=inputgen.InputKind.AVERAGE): + return inputgen.gen_input_for_sorting(n, case) diff --git a/algorithms/impl/__init__.py b/algorithms/impl/__init__.py index 29cd51c..dcb8fc4 100644 --- a/algorithms/impl/__init__.py +++ b/algorithms/impl/__init__.py @@ -2,18 +2,16 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. -_ALL_ALGORITHMS = {} +from importlib import import_module +import os.path +from pkgutil import iter_modules -def _refresh_algorithms(): - _ALGORITHMS_NAME = '_ALGORITHMS' - global _ALL_ALGORITHMS - _ALL_ALGORITHMS = {} +from .. import algorithm - from algorithms.algorithm import Algorithm +_ALGORITHMS_NAME = '_ALGORITHMS' - from importlib import import_module - import os.path - from pkgutil import iter_modules +def refresh_algorithms(): + all_algorithms = {} for _, module_name, is_pkg in iter_modules([os.path.dirname(__file__)]): if is_pkg: @@ -21,9 +19,9 @@ def _refresh_algorithms(): module = import_module('.' + module_name, __package__) if hasattr(module, _ALGORITHMS_NAME): module_algorithms = getattr(module, _ALGORITHMS_NAME) - for algorithm in module_algorithms: - assert isinstance(algorithm, Algorithm) - assert algorithm.get_codename() not in _ALL_ALGORITHMS - _ALL_ALGORITHMS[algorithm.get_codename()] = algorithm + for descr in module_algorithms: + assert isinstance(descr, algorithm.Algorithm) + assert descr.codename not in all_algorithms + all_algorithms[descr.codename] = descr -_refresh_algorithms() + return all_algorithms diff --git a/algorithms/impl/bubble_sort.py b/algorithms/impl/bubble_sort.py index 2abfc43..e6aa645 100644 --- a/algorithms/impl/bubble_sort.py +++ b/algorithms/impl/bubble_sort.py @@ -2,6 +2,10 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. +import sys + +from ..algorithm import SortingAlgorithm + def bubble_sort(xs): while True: swapped = False @@ -14,7 +18,7 @@ def bubble_sort(xs): return xs def bubble_sort_optimized(xs): - n = len(xs) + n = len(xs) while True: new_n = 0 for i in range(1, n): @@ -26,14 +30,18 @@ def bubble_sort_optimized(xs): break return xs -if __name__ == '__main__': - import sys - xs = list(map(int, sys.argv[1:])) +_ALGORITHMS = [ + SortingAlgorithm('bubble_sort', 'Bubble sort', bubble_sort), + SortingAlgorithm('bubble_sort_optimized', 'Bubble sort (optimized)', bubble_sort_optimized), +] + +def _parse_args(args=sys.argv): + return list(map(int, args[1:])) + +def main(args=sys.argv): + xs = _parse_args(args) print(bubble_sort(list(xs))) print(bubble_sort_optimized(list(xs))) -else: - from algorithms.algorithm import SortingAlgorithm - _ALGORITHMS = [ - SortingAlgorithm('bubble_sort', 'Bubble sort', bubble_sort), - SortingAlgorithm('bubble_sort_optimized', 'Bubble sort (optimized)', bubble_sort_optimized), - ] + +if __name__ == '__main__': + main() diff --git a/algorithms/impl/heapsort.py b/algorithms/impl/heapsort.py index db3b6bd..c92a04c 100644 --- a/algorithms/impl/heapsort.py +++ b/algorithms/impl/heapsort.py @@ -2,6 +2,10 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. +import sys + +from ..algorithm import SortingAlgorithm + # Disclaimer: implemented in the most literate way. def heapsort(xs): @@ -49,11 +53,16 @@ def _siftdown(xs, start, end): else: break +_ALGORITHMS = [ + SortingAlgorithm('heapsort', 'Heapsort', heapsort), +] + +def _parse_args(args=sys.argv): + return list(map(int, args[1:])) + +def main(args=sys.argv): + xs = _parse_args(args) + print(heapsort(list(xs))) + if __name__ == '__main__': - import sys - print(heapsort(list(map(int, sys.argv[1:])))) -else: - from algorithms.algorithm import SortingAlgorithm - _ALGORITHMS = [ - SortingAlgorithm('heapsort', 'Heapsort', heapsort), - ] + main() diff --git a/algorithms/impl/insertion_sort.py b/algorithms/impl/insertion_sort.py index f671712..006ab66 100644 --- a/algorithms/impl/insertion_sort.py +++ b/algorithms/impl/insertion_sort.py @@ -2,6 +2,10 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. +import sys + +from ..algorithm import SortingAlgorithm + def insertion_sort(xs): for i in range(1, len(xs)): j = i @@ -10,11 +14,16 @@ def insertion_sort(xs): j -= 1 return xs +_ALGORITHMS = [ + SortingAlgorithm('insertion_sort', 'Insertion sort', insertion_sort), +] + +def _parse_args(args=sys.argv): + return list(map(int, args[1:])) + +def main(args=sys.argv): + xs = _parse_args(args) + print(insertion_sort(list(xs))) + if __name__ == '__main__': - import sys - print(insertion_sort(list(map(int, sys.argv[1:])))) -else: - from algorithms.algorithm import SortingAlgorithm - _ALGORITHMS = [ - SortingAlgorithm('insertion_sort', 'Insertion sort', insertion_sort), - ] + main() diff --git a/algorithms/impl/median.py b/algorithms/impl/median.py index ba51c71..e6b9901 100644 --- a/algorithms/impl/median.py +++ b/algorithms/impl/median.py @@ -2,9 +2,11 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. -from algorithms.impl.quicksort import quicksort_random +from heapq import heappush, heappop +import sys -from heapq import * +from ..algorithm import Algorithm +from .quicksort import quicksort_random def calc_median_heaps(xs): cur_median = 0.0 @@ -30,7 +32,7 @@ def calc_median_heaps(xs): cur_median = min_heap[0] return cur_median -def calc_median_sort_first(xs): +def calc_median_sorting(xs): if not xs: return 0.0 quicksort_random(xs) @@ -39,14 +41,18 @@ def calc_median_sort_first(xs): else: return xs[len(xs) // 2 - 1] / 2 + xs[len(xs) // 2] / 2 -if __name__ == '__main__': - import sys - xs = list(map(int, sys.argv[1:])) - print(calc_median_sort_first(list(xs))) +_ALGORITHMS = [ + Algorithm('median_sorting', 'Median value (using explicit sorting)', calc_median_sorting), + Algorithm('median_heaps', 'Median value (using heaps)', calc_median_heaps), +] + +def _parse_args(args=sys.argv): + return list(map(int, args[1:])) + +def main(args=sys.argv): + xs = _parse_args(args) + print(calc_median_sorting(list(xs))) print(calc_median_heaps(list(xs))) -else: - from algorithms.algorithm import Algorithm - _ALGORITHMS = [ - Algorithm('median_sort_first', 'Median (input is sorted first)', calc_median_sort_first), - Algorithm('median_heaps', 'Median (using heaps)', calc_median_heaps), - ] + +if __name__ == '__main__': + main() diff --git a/algorithms/impl/merge_sort.py b/algorithms/impl/merge_sort.py index 9fa96ec..8d0b573 100644 --- a/algorithms/impl/merge_sort.py +++ b/algorithms/impl/merge_sort.py @@ -2,6 +2,10 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. +import sys + +from ..algorithm import SortingAlgorithm + def merge(left, right): result = [] l, r = 0, 0 @@ -24,11 +28,16 @@ def merge_sort(xs): mid = len(xs) // 2 return merge(merge_sort(xs[:mid]), merge_sort(xs[mid:])) +_ALGORITHMS = [ + SortingAlgorithm('merge_sort', 'Merge sort', merge_sort), +] + +def _parse_args(args=sys.argv): + return list(map(int, args[1:])) + +def main(args=sys.argv): + xs = _parse_args(args) + print(merge_sort(list(xs))) + if __name__ == '__main__': - import sys - print(merge_sort(list(map(int, sys.argv[1:])))) -else: - from algorithms.algorithm import SortingAlgorithm - _ALGORITHMS = [ - SortingAlgorithm('merge_sort', 'Merge sort', merge_sort), - ] + main() diff --git a/algorithms/impl/quicksort.py b/algorithms/impl/quicksort.py index 32100b0..ddc8269 100644 --- a/algorithms/impl/quicksort.py +++ b/algorithms/impl/quicksort.py @@ -3,6 +3,9 @@ # See LICENSE.txt for details. from random import randrange +import sys + +from ..algorithm import SortingAlgorithm def _partition(xs, beg, end, select_pivot): pivot = select_pivot(xs, beg, end) @@ -55,20 +58,24 @@ def quicksort_random(xs): _quicksort(xs, 0, len(xs) - 1, _select_random) return xs -if __name__ == '__main__': - import sys - xs = list(map(int, sys.argv[1:])) +_ALGORITHMS = [ + SortingAlgorithm('quicksort_first', 'Quicksort (first element as pivot)', quicksort_first), + SortingAlgorithm('quicksort_second', 'Quicksort (second element as pivot)', quicksort_second), + SortingAlgorithm('quicksort_middle', 'Quicksort (middle element as pivot)', quicksort_middle), + SortingAlgorithm('quicksort_last', 'Quicksort (last element as pivot)', quicksort_last), + SortingAlgorithm('quicksort_random', 'Quicksort (random element as pivot)', quicksort_random), +] + +def _parse_args(args=sys.argv): + return list(map(int, args[1:])) + +def main(args=sys.argv): + xs = _parse_args(args) print(quicksort_first(list(xs))) print(quicksort_second(list(xs))) print(quicksort_middle(list(xs))) print(quicksort_last(list(xs))) print(quicksort_random(list(xs))) -else: - from algorithms.algorithm import SortingAlgorithm - _ALGORITHMS = [ - SortingAlgorithm('quicksort_first', 'Quicksort (first element as pivot)', quicksort_first), - SortingAlgorithm('quicksort_second', 'Quicksort (second element as pivot)', quicksort_second), - SortingAlgorithm('quicksort_middle', 'Quicksort (middle element as pivot)', quicksort_middle), - SortingAlgorithm('quicksort_last', 'Quicksort (last element as pivot)', quicksort_last), - SortingAlgorithm('quicksort_random', 'Quicksort (random element as pivot)', quicksort_random), - ] + +if __name__ == '__main__': + main() diff --git a/algorithms/impl/selection_sort.py b/algorithms/impl/selection_sort.py index 27b319f..d5d11d2 100644 --- a/algorithms/impl/selection_sort.py +++ b/algorithms/impl/selection_sort.py @@ -2,6 +2,10 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. +import sys + +from ..algorithm import SortingAlgorithm + def selection_sort(xs): for i in range(len(xs) - 1): min_i = i @@ -12,11 +16,16 @@ def selection_sort(xs): xs[i], xs[min_i] = xs[min_i], xs[i] return xs +_ALGORITHMS = [ + SortingAlgorithm('selection_sort', 'Selection sort', selection_sort), +] + +def _parse_args(args=sys.argv): + return list(map(int, args[1:])) + +def main(args=sys.argv): + xs = _parse_args(args) + print(selection_sort(list(xs))) + if __name__ == '__main__': - import sys - print(selection_sort(list(map(int, sys.argv[1:])))) -else: - from algorithms.algorithm import SortingAlgorithm - _ALGORITHMS = [ - SortingAlgorithm('selection_sort', 'Selection sort', selection_sort), - ] + main() diff --git a/algorithms/inputgen.py b/algorithms/inputgen.py new file mode 100644 index 0000000..2659ffc --- /dev/null +++ b/algorithms/inputgen.py @@ -0,0 +1,30 @@ +# Copyright 2016 Egor Tensin +# This file is licensed under the terms of the MIT License. +# See LICENSE.txt for details. + +from array import array +from enum import Enum +from random import seed, sample + +seed() + +class InputKind(Enum): + BEST, AVERAGE, WORST = 'best', 'average', 'worst' + + def __str__(self): + return self.value + +def _gen_input_from(xs): + return array('l', xs) + +def gen_input_for_sorting(n, case=InputKind.AVERAGE): + if n < 0: + raise ValueError('input length must not be a negative number') + if case is InputKind.BEST: + return _gen_input_from(range(n)) + elif case is InputKind.AVERAGE: + return _gen_input_from(sample(range(n), n)) + elif case is InputKind.WORST: + return _gen_input_from(range(n - 1, -1, -1)) + else: + raise NotImplementedError('invalid input kind: ' + str(case)) diff --git a/algorithms/params.py b/algorithms/params.py new file mode 100644 index 0000000..a66f7a2 --- /dev/null +++ b/algorithms/params.py @@ -0,0 +1,110 @@ +# Copyright 2016 Egor Tensin +# This file is licensed under the terms of the MIT License. +# See LICENSE.txt for details. + +from numbers import Integral + +from .inputgen import InputKind +from .plotter import PlotBuilder +from . import registry +from .timer import Timer + +class AlgorithmParameters: + def __init__(self, algorithm, min_len, max_len, + input_kind=InputKind.AVERAGE, iterations=1): + + if isinstance(algorithm, str): + algorithm = registry.get(algorithm) + self.algorithm = algorithm + + self.input_kind = input_kind + + self._min_len = None + self._max_len = None + self.min_len = min_len + self.max_len = max_len + + self._iterations = None + self.iterations = iterations + + @property + def min_len(self): + return self._min_len + + @min_len.setter + def min_len(self, val): + if not isinstance(val, Integral): + raise TypeError('must be an integral value') + val = int(val) + if val < 0: + raise ValueError('must not be a negative number') + if self.max_len is not None and self.max_len < val: + raise ValueError('must not be greater than the maximum length') + self._min_len = val + + @property + def max_len(self): + return self._max_len + + @max_len.setter + def max_len(self, val): + if not isinstance(val, Integral): + raise TypeError('must be an integral value') + val = int(val) + if val < 0: + raise ValueError('must not be a negative number') + if self.min_len is not None and self.min_len > val: + raise ValueError('must not be lesser than the minimum length') + self._max_len = val + + @property + def iterations(self): + return self._iterations + + @iterations.setter + def iterations(self, val): + if not isinstance(val, Integral): + raise TypeError('must be an integral value') + val = int(val) + if val < 1: + raise ValueError('must be a positive number') + self._iterations = val + + def measure_running_time(self): + input_len_range = list(range(self.min_len, self.max_len + 1)) + running_time = [] + for input_len in input_len_range: + input_sample = self.algorithm.gen_input(input_len, self.input_kind) + input_copies = [list(input_sample) for _ in range(self.iterations)] + with Timer(running_time, self.iterations): + for i in range(self.iterations): + self.algorithm.function(input_copies[i]) + return input_len_range, running_time + + @staticmethod + def _format_plot_xlabel(): + return 'Input length' + + @staticmethod + def _format_plot_ylabel(): + return 'Running time (sec)' + + def _format_plot_title(self): + return '{}, {} case'.format( + self.algorithm.display_name, self.input_kind) + + def _format_plot_suptitle(self): + return self.algorithm.display_name + + def plot_running_time(self, output_path=None): + plot_builder = PlotBuilder() + plot_builder.show_grid() + plot_builder.set_xlabel(self._format_plot_xlabel()) + plot_builder.set_ylabel(self._format_plot_ylabel()) + plot_builder.set_title(self._format_plot_title()) + xs, ys = self.measure_running_time() + plot_builder.plot(xs, ys) + if output_path is None: + plot_builder.show() + else: + plot_builder.save(output_path) diff --git a/algorithms/plotter.py b/algorithms/plotter.py new file mode 100644 index 0000000..048894f --- /dev/null +++ b/algorithms/plotter.py @@ -0,0 +1,38 @@ +# Copyright 2016 Egor Tensin +# This file is licensed under the terms of the MIT License. +# See LICENSE.txt for details. + +import matplotlib.pyplot as plt + +class PlotBuilder: + @staticmethod + def set_xlabel(s): + plt.xlabel(s) + + @staticmethod + def set_ylabel(s): + plt.ylabel(s) + + @staticmethod + def show_grid(): + plt.grid() + + @staticmethod + def set_title(s): + plt.title(s) + + @staticmethod + def set_suptitle(s): + plt.suptitle(s) + + @staticmethod + def plot(xs, ys): + plt.plot(xs, ys) + + @staticmethod + def show(): + plt.show() + + @staticmethod + def save(output_path): + plt.savefig(output_path)#, bbox_inches='tight') diff --git a/algorithms/registry.py b/algorithms/registry.py index 8f0469d..0f75dce 100644 --- a/algorithms/registry.py +++ b/algorithms/registry.py @@ -2,16 +2,12 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. -import algorithms.impl +from . import impl -def refresh_algorithms(): - algorithms.impl._refresh_algorithms() +_ALL_ALGORITHMS = impl.refresh_algorithms() def get_codenames(): - return algorithms.impl._ALL_ALGORITHMS.keys() - -def iter_algorithms(): - return iter(algorithms.impl._ALL_ALGORITHMS.values()) + return _ALL_ALGORITHMS.keys() def get(codename): - return algorithms.impl._ALL_ALGORITHMS[codename] + return _ALL_ALGORITHMS[codename] diff --git a/algorithms/timer.py b/algorithms/timer.py new file mode 100644 index 0000000..d334b94 --- /dev/null +++ b/algorithms/timer.py @@ -0,0 +1,23 @@ +# Copyright 2016 Egor Tensin +# This file is licensed under the terms of the MIT License. +# See LICENSE.txt for details. + +import gc, time + +def get_timestamp(): + return time.perf_counter() + +class Timer: + def __init__(self, dest, iterations=1): + self._dest = dest + self._iterations = iterations + + def __enter__(self): + gc.disable() + self._start = get_timestamp() + return self + + def __exit__(self, *args): + end = get_timestamp() + gc.enable() + self._dest.append((end - self._start) / self._iterations) diff --git a/plot.bat b/plot.bat index ad7c4e9..f046087 100644 --- a/plot.bat +++ b/plot.bat @@ -27,14 +27,17 @@ set max=%DEFAULT_MAX% ) -plot.py -l "%algorithm%" -a "%min%" -b "%max%" -r "%iterations%" ^ - -i ascending -o "%algorithm%_%iterations%_ascending_%min%_%max%.png" ^ +plot.py "%algorithm%" --min "%min%" --max "%max%" ^ + --iterations "%iterations%" --input best ^ + --output "%algorithm%_%iterations%_best_%min%_%max%.png" ^ || exit /b !errorlevel! -plot.py -l "%algorithm%" -a "%min%" -b "%max%" -r "%iterations%" ^ - -i random -o "%algorithm%_%iterations%_random_%min%_%max%.png" ^ +plot.py "%algorithm%" --min "%min%" --max "%max%" ^ + --iterations "%iterations%" --input average ^ + --output "%algorithm%_%iterations%_average_%min%_%max%.png" ^ || exit /b !errorlevel! -plot.py -l "%algorithm%" -a "%min%" -b "%max%" -r "%iterations%" ^ - -i descending -o "%algorithm%_%iterations%_descending_%min%_%max%.png" ^ +plot.py "%algorithm%" --min "%min%" --max "%max%" ^ + --iterations "%iterations%" --input worst ^ + --output "%algorithm%_%iterations%_worst_%min%_%max%.png" ^ || exit /b !errorlevel! @exit /b diff --git a/plot.py b/plot.py index 6163363..45140f6 100644 --- a/plot.py +++ b/plot.py @@ -2,122 +2,91 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. -from enum import Enum -import gc -import pylab -from time import clock - -class OrderType(Enum): - ASCENDING, RANDOM, DESCENDING = 'ascending', 'random', 'descending' - - def __str__(self): - return self.value - - def get_case(self): - if self is OrderType.ASCENDING: - return 'best' - elif self is OrderType.DESCENDING: - return 'worst' - elif self is OrderType.RANDOM: - return 'average' - else: - raise NotImplementedError( - 'unknown "case" for input ordering: \'{}\''.format(self)) - -def get_timestamp(): - return clock() - -def init_clock(): - get_timestamp() - -def gen_input(order, n): - if order is OrderType.ASCENDING: - return list(range(n)) - elif order is OrderType.DESCENDING: - return sorted(range(n), reverse=True) - elif order is OrderType.RANDOM: - from random import sample - return sample(range(n), n) - else: - raise NotImplementedError( - 'invalid input ordering: \'{}\''.format(order)) - -def measure_running_time(algorithm, order, xs_len, iterations): - xs = gen_input(order, xs_len) - xss = [list(xs) for _ in range(iterations)] - algorithm = algorithm.get_function() - gc.disable() - started_at = get_timestamp() - for i in range(iterations): - algorithm(xss[i]) - finished_at = get_timestamp() - gc.enable() - return finished_at - started_at - -def _decorate_plot(algorithm, iterations, order): - pylab.grid() - pylab.xlabel("Input length") - pylab.ylabel('Running time (sec), {} iteration(s)'.format(iterations)) - pylab.title("{}, {} case".format( - algorithm.get_display_name(), order.get_case())) - -def plot_algorithm(algorithm, iterations, order, min_len, max_len, output_path=None): - _decorate_plot(algorithm, iterations, order) - xs_lengths = range(min_len, max_len + 1) - running_time = [] - for xs_len in xs_lengths: - running_time.append(measure_running_time(algorithm, order, xs_len, iterations)) - pylab.plot(xs_lengths, running_time) - if output_path is None: - pylab.show() - else: - pylab.savefig(output_path) +import argparse +import sys -if __name__ == '__main__': - import algorithms.registry - - def natural_number(s): - n = int(s) - if n < 0: - raise argparse.ArgumentTypeError('cannot be negative') - return n - def positive_number(s): - n = int(s) - if n < 1: - raise argparse.ArgumentTypeError('must be positive') - return n - def input_kind(s): - try: - return OrderType(s) - except ValueError: - raise argparse.ArgumentError() - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--algorithm', '-l', required=True, - choices=algorithms.registry.get_codenames(), - help='specify sorting algorithm to use') - parser.add_argument('--iterations', '-r', - type=positive_number, default=1, +from algorithms.inputgen import InputKind +from algorithms.params import AlgorithmParameters +import algorithms.registry as registry + +_DEFAULT_ITERATIONS = 100 +_DEFAULT_INPUT_KIND = InputKind.AVERAGE +_DEFAULT_MIN_LENGTH = 0 +_DEFAULT_MAX_LENGTH = 200 + +def plot_algorithm(algorithm, input_kind=_DEFAULT_INPUT_KIND, + min_len=_DEFAULT_MIN_LENGTH, + max_len=_DEFAULT_MAX_LENGTH, + iterations=_DEFAULT_ITERATIONS, + output_path=None): + + if isinstance(algorithm, str): + algorithm = registry.get(algorithm) + + params = AlgorithmParameters(algorithm, min_len, max_len, + input_kind=input_kind, + iterations=iterations) + params.plot_running_time(output_path) + +def _parse_natural_number(s): + n = int(s) + if n < 0: + raise argparse.ArgumentTypeError('must not be a negative number') + return n + +def _parse_positive_number(s): + n = int(s) + if n < 1: + raise argparse.ArgumentTypeError('must be positive') + return n + +def _parse_input_kind(s): + try: + return InputKind(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid input_kind: ' + str(s)) + +def _format_algorithm(codename): + return '* {}: {}'.format(codename, registry.get(codename).display_name) + +def _format_available_algorithms(): + descr = 'available algorithms (in the CODENAME: DISPLAY_NAME format):\n' + return descr + '\n'.join(map(_format_algorithm, registry.get_codenames())) + +def _format_description(): + return _format_available_algorithms() + +def _parse_args(args=sys.argv): + parser = argparse.ArgumentParser( + description=_format_description(), + formatter_class=argparse.RawDescriptionHelpFormatter) + + parser.add_argument('algorithm', metavar='CODENAME', + choices=registry.get_codenames(), + help='algorithm codename') + parser.add_argument('--iterations', '-r', metavar='N', + type=_parse_positive_number, + default=_DEFAULT_ITERATIONS, help='set number of algorithm iterations') - parser.add_argument('--order', '-i', - choices=tuple(x for x in OrderType), - type=input_kind, default=OrderType.RANDOM, - help='specify input order') - parser.add_argument('--min', '-a', type=natural_number, - required=True, dest='min_len', + parser.add_argument('--input', '-i', dest='input_kind', + choices=InputKind, + type=_parse_input_kind, default=_DEFAULT_INPUT_KIND, + help='specify input kind') + parser.add_argument('--min', '-a', metavar='N', dest='min_len', + type=_parse_natural_number, + default=_DEFAULT_MIN_LENGTH, help='set min input length') - parser.add_argument('--max', '-b', type=natural_number, - required=True, dest='max_len', + parser.add_argument('--max', '-b', metavar='N', dest='max_len', + type=_parse_natural_number, + default=_DEFAULT_MAX_LENGTH, help='set max input length') - parser.add_argument('--output', '-o', dest='output_path', - help='set plot output path') - args = parser.parse_args() - if args.max_len < args.min_len: - parser.error('max input length cannot be less than min input length') - - init_clock() - plot_algorithm(algorithms.registry.get(args.algorithm), - args.iterations, args.order, - args.min_len, args.max_len, - args.output_path) + parser.add_argument('--output', '-o', metavar='PATH', dest='output_path', + help='set plot file path') + + return parser.parse_args(args[1:]) + +def main(args=sys.argv): + plot_algorithm(**vars(_parse_args(args))) + +if __name__ == '__main__': + main() diff --git a/test.py b/test.py index 7acc4bf..6e5379c 100644 --- a/test.py +++ b/test.py @@ -2,52 +2,67 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. -from enum import Enum - -class OrderType(Enum): - ASCENDING, RANDOM, DESCENDING = 'ascending', 'random', 'descending' - - def __str__(self): - return self.value - -def gen_input(kind, n): - if kind is OrderType.ASCENDING: - return list(range(n)) - elif kind is OrderType.DESCENDING: - return sorted(range(n), reverse=True) - elif kind is OrderType.RANDOM: - from random import sample - return sample(range(n), n) - else: - raise NotImplementedError( - 'invalid input ordering: \'{}\''.format(kind)) +from array import array +import argparse +import sys -if __name__ == '__main__': - import algorithms.registry - - def natural_number(s): - n = int(s) - if n < 0: - raise argparse.ArgumentTypeError('cannot be negative') - return n - def order(s): - try: - return OrderType(s) - except ValueError: - raise argparse.ArgumentError() - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--algorithm', '-l', required=True, - choices=algorithms.registry.get_codenames(), - help='specify algorithm codename') - parser.add_argument('--order', '-i', - choices=tuple(x for x in OrderType), - type=order, default=OrderType.RANDOM, - help='specify input order') - parser.add_argument('--length', '-n', - type=natural_number, default=100, +from algorithms.inputgen import InputKind +import algorithms.registry as registry + +_DEFAULT_INPUT_KIND = InputKind.AVERAGE +_DEFAULT_LENGTH = 100 + +def test(algorithm, input_kind=_DEFAULT_INPUT_KIND, length=_DEFAULT_LENGTH): + if isinstance(algorithm, str): + algorithm = registry.get(algorithm) + xs = algorithm.gen_input(length, input_kind) + output = algorithm.function(xs) + if isinstance(output, array): + output = output.tolist() + print(output) + +def _parse_natural_number(s): + n = int(s) + if n < 0: + raise argparse.ArgumentTypeError('must not be a negative number') + return n + +def _parse_input_kind(s): + try: + return InputKind(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid input kind: ' + str(s)) + +def _format_algorithm(codename): + return '* {}: {}'.format(codename, registry.get(codename).display_name) + +def _format_available_algorithms(): + descr = 'available algorithms (in the CODENAME: DISPLAY_NAME format):\n' + return descr + '\n'.join(map(_format_algorithm, registry.get_codenames())) + +def _format_description(): + return _format_available_algorithms() + +def _parse_args(args=sys.argv): + parser = argparse.ArgumentParser( + description=_format_description(), + formatter_class=argparse.RawDescriptionHelpFormatter) + + parser.add_argument('algorithm', metavar='CODENAME', + choices=registry.get_codenames(), + help='algorithm codename') + parser.add_argument('--input', '-i', dest='input_kind', + choices=InputKind, + type=_parse_input_kind, default=_DEFAULT_INPUT_KIND, + help='specify input kind') + parser.add_argument('--length', '-l', '-n', metavar='N', + type=_parse_natural_number, default=_DEFAULT_LENGTH, help='set input length') - args = parser.parse_args() - xs = gen_input(args.order, args.length) - print(algorithms.registry.get(args.algorithm).get_function()(xs)) + + return parser.parse_args(args[1:]) + +def main(args=sys.argv): + test(**vars(_parse_args(args))) + +if __name__ == '__main__': + main() -- cgit v1.2.3