aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/plot.py
diff options
context:
space:
mode:
Diffstat (limited to 'plot.py')
-rw-r--r--plot.py121
1 files changed, 73 insertions, 48 deletions
diff --git a/plot.py b/plot.py
index 33743aa..99452c4 100644
--- a/plot.py
+++ b/plot.py
@@ -2,52 +2,77 @@
# This file is licensed under the terms of the MIT License.
# See LICENSE.txt for details.
-from time import clock
+from enum import Enum
import gc
+from time import clock
+
+class InputKind(Enum):
+ SORTED, RANDOMIZED, REVERSED = 'sorted', 'randomized', 'reversed'
+
+ def __str__(self):
+ return self.value
+
+class SortingAlgorithm(Enum):
+ BUBBLE = 'bubble'
+ BUBBLE_OPTIMIZED = 'bubble_optimized'
+ HEAP = 'heap'
+ INSERTION = 'insertion'
+ MERGE = 'merge'
+ QUICK_FIRST = 'quick_first'
+ QUICK_SECOND = 'quick_second'
+ QUICK_MIDDLE = 'quick_middle'
+ QUICK_LAST = 'quick_last'
+ QUICK_RANDOM = 'quick_random'
+ SELECTION = 'selection'
-_ALGORITHMS = (
- 'bubble',
- 'bubble_optimized',
- 'heap',
- 'insertion',
- 'merge',
- 'quick_first',
- 'quick_second',
- 'quick_middle',
- 'quick_last',
- 'quick_random',
- 'selection',
-)
+ def __str__(self):
+ return self.value
def _get_context():
+ def natural_number(s):
+ n = int(s)
+ if n < 0:
+ raise argparse.ArgumentTypeError('cannot be negative')
+ return n
+ def positive_number(s):
+ n = int(s)
+ if n < 1:
+ raise argparse.ArgumentTypeError('must be positive')
+ return n
+ def input_kind(s):
+ try:
+ return InputKind(s)
+ except ValueError:
+ raise argparse.ArgumentError()
+ def sorting_algorithm(s):
+ try:
+ return SortingAlgorithm(s)
+ except ValueError:
+ raise argparse.ArgumentError()
import argparse
parser = argparse.ArgumentParser()
- parser.add_argument('--repetitions', '-r', type=int, default=1,
+ parser.add_argument('--repetitions', '-r',
+ type=positive_number, default=1,
help='set number of sorting repetitions')
parser.add_argument('--input', '-i',
- default='randomized', metavar='INPUT',
- choices=('sorted', 'randomized', 'reversed'),
+ choices=tuple(x for x in InputKind),
+ type=input_kind, default=InputKind.RANDOMIZED,
help='choose initial input state')
- parser.add_argument('--algorithm', '-l', metavar='ALGORITHM',
- choices=_ALGORITHMS, required=True,
+ parser.add_argument('--algorithm', '-l', required=True,
+ choices=tuple(x for x in SortingAlgorithm),
+ type=sorting_algorithm,
help='select sorting algorithm to use')
- parser.add_argument('--min', '-a', type=int, required=True,
- help='set min input length',
- dest='min_input_length')
- parser.add_argument('--max', '-b', type=int, required=True,
- help='set max input length',
- dest='max_input_length')
+ parser.add_argument('--min', '-a', type=natural_number,
+ required=True, dest='min_input_length',
+ help='set min input length')
+ parser.add_argument('--max', '-b', type=natural_number,
+ required=True, dest='max_input_length',
+ help='set max input length')
parser.add_argument('--output', '-o', dest='plot_path',
help='set plot output path')
args = parser.parse_args()
- if args.repetitions < 1:
- parser.error('number of repetitions must be > 0')
- if args.min_input_length < 0:
- parser.error('min sequence length must be >= 0')
- if args.max_input_length < 0:
- parser.error('max sequence length must be >= 0')
if args.max_input_length < args.min_input_length:
- parser.error('max sequence length cannot be less than min sequence length')
+ parser.error('max input length cannot be less than min input length')
return args
def get_timestamp():
@@ -57,16 +82,16 @@ def init_clock():
get_timestamp()
def gen_input(args, n):
- if args.input == 'sorted':
+ if args.input is InputKind.SORTED:
return list(range(n))
- elif args.input == 'reversed':
+ elif args.input is InputKind.REVERSED:
return sorted(range(n), reverse=True)
- elif args.input == 'randomized':
+ elif args.input is InputKind.RANDOMIZED:
from random import sample
return sample(range(n), n)
else:
raise NotImplementedError(
- 'unimplemented initial input state \'{}\''.format(args.input))
+ 'invalid initial input state \'{}\''.format(args.input))
def measure_running_time(ctx, algorithm, input_length):
xs = gen_input(ctx, input_length)
@@ -102,42 +127,42 @@ def plot_algorithm(ctx, algorithm):
plt.show()
def plot(ctx):
- if ctx.algorithm == 'bubble':
+ if ctx.algorithm is SortingAlgorithm.BUBBLE:
from bubble_sort import bubble_sort
plot_algorithm(ctx, bubble_sort)
- elif ctx.algorithm == 'bubble_optimized':
+ elif ctx.algorithm is SortingAlgorithm.BUBBLE_OPTIMIZED:
from bubble_sort import bubble_sort_optimized
plot_algorithm(ctx, bubble_sort_optimized)
- elif ctx.algorithm == 'heap':
+ elif ctx.algorithm is SortingAlgorithm.HEAP:
from heapsort import heapsort
plot_algorithm(ctx, heapsort)
- elif ctx.algorithm == 'insertion':
+ elif ctx.algorithm is SortingAlgorithm.INSERTION:
from insertion_sort import insertion_sort
plot_algorithm(ctx, insertion_sort)
- elif ctx.algorithm == 'merge':
+ elif ctx.algorithm is SortingAlgorithm.MERGE:
from merge_sort import merge_sort
plot_algorithm(ctx, merge_sort)
- elif ctx.algorithm == 'quick_first':
+ elif ctx.algorithm is SortingAlgorithm.QUICK_FIRST:
from quicksort import quicksort_first
plot_algorithm(ctx, quicksort_first)
- elif ctx.algorithm == 'quick_second':
+ elif ctx.algorithm is SortingAlgorithm.QUICK_SECOND:
from quicksort import quicksort_second
plot_algorithm(ctx, quicksort_second)
- elif ctx.algorithm == 'quick_middle':
+ elif ctx.algorithm is SortingAlgorithm.QUICK_MIDDLE:
from quicksort import quicksort_middle
plot_algorithm(ctx, quicksort_middle)
- elif ctx.algorithm == 'quick_last':
+ elif ctx.algorithm is SortingAlgorithm.QUICK_LAST:
from quicksort import quicksort_last
plot_algorithm(ctx, quicksort_last)
- elif ctx.algorithm == 'quick_random':
+ elif ctx.algorithm is SortingAlgorithm.QUICK_RANDOM:
from quicksort import quicksort_random
plot_algorithm(ctx, quicksort_random)
- elif ctx.algorithm == 'selection':
+ elif ctx.algorithm is SortingAlgorithm.SELECTION:
from selection_sort import selection_sort
plot_algorithm(ctx, selection_sort)
else:
raise NotImplementedError(
- 'unknown algorithm \'{}\''.format(ctx.algorithm))
+ 'invalid sorting algorithm \'{}\''.format(ctx.algorithm))
if __name__ == '__main__':
ctx = _get_context()