aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/algorithms/params.py
diff options
context:
space:
mode:
Diffstat (limited to 'algorithms/params.py')
-rw-r--r--algorithms/params.py142
1 files changed, 142 insertions, 0 deletions
diff --git a/algorithms/params.py b/algorithms/params.py
new file mode 100644
index 0000000..486cb0e
--- /dev/null
+++ b/algorithms/params.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2016 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "Sorting algorithms" project.
+# For details, see https://github.com/egor-tensin/sorting-algorithms.
+# Distributed under the MIT License.
+
+from enum import Enum
+from numbers import Integral
+
+from .input_kind import InputKind
+from .plotter import PlotBuilder
+from . import registry
+from .timer import Timer
+
+
+class TimeUnits(Enum):
+ SECONDS = 'seconds'
+ MILLISECONDS = 'milliseconds'
+ MICROSECONDS = 'microseconds'
+
+ def get_factor(self):
+ if self is TimeUnits.SECONDS:
+ return 1.
+ if self is TimeUnits.MILLISECONDS:
+ return 1000.
+ if self is TimeUnits.MICROSECONDS:
+ return 1000000.
+ raise NotImplementedError('invalid time units: ' + str(self))
+
+ def __str__(self):
+ return self.value
+
+
+class AlgorithmParameters:
+ def __init__(self, algorithm, min_len, max_len,
+ input_kind=InputKind.AVERAGE, iterations=1):
+
+ if isinstance(algorithm, str):
+ algorithm = registry.get(algorithm)
+ self.algorithm = algorithm
+
+ self.input_kind = input_kind
+
+ self._min_len = None
+ self._max_len = None
+ self.min_len = min_len
+ self.max_len = max_len
+
+ self._iterations = None
+ self.iterations = iterations
+
+ @property
+ def min_len(self):
+ return self._min_len
+
+ @min_len.setter
+ def min_len(self, val):
+ if not isinstance(val, Integral):
+ raise TypeError('must be an integral value')
+ val = int(val)
+ if val < 0:
+ raise ValueError('must be non-negative')
+ if self.max_len is not None and self.max_len < val:
+ raise ValueError('must not be greater than the maximum length')
+ self._min_len = val
+
+ @property
+ def max_len(self):
+ return self._max_len
+
+ @max_len.setter
+ def max_len(self, val):
+ if not isinstance(val, Integral):
+ raise TypeError('must be an integral value')
+ val = int(val)
+ if val < 0:
+ raise ValueError('must be non-negative')
+ if self.min_len is not None and self.min_len > val:
+ raise ValueError('must not be lesser than the minimum length')
+ self._max_len = val
+
+ @property
+ def iterations(self):
+ return self._iterations
+
+ @iterations.setter
+ def iterations(self, val):
+ if not isinstance(val, Integral):
+ raise TypeError('must be an integral value')
+ val = int(val)
+ if val < 1:
+ raise ValueError('must be positive')
+ self._iterations = val
+
+ def measure_running_time(self):
+ input_len_range = list(range(self.min_len, self.max_len + 1))
+ running_time = []
+ for input_len in input_len_range:
+ input_sample = self.algorithm.gen_input(input_len, self.input_kind)
+ input_copies = [list(input_sample) for _ in range(self.iterations)]
+ with Timer(running_time, self.iterations):
+ for i in range(self.iterations):
+ self.algorithm.function(input_copies[i])
+ return input_len_range, running_time
+
+ @staticmethod
+ def _format_plot_xlabel():
+ return 'Input length'
+
+ @staticmethod
+ def _format_plot_ylabel(units):
+ return 'Running time ({})'.format(units)
+
+ def _format_plot_title(self):
+ return '{}, {} case'.format(
+ self.algorithm.display_name, self.input_kind)
+
+ def _format_plot_suptitle(self):
+ return self.algorithm.display_name
+
+ @staticmethod
+ def _derive_time_units(ys):
+ max_y = max(ys)
+ if max_y > 0.1:
+ return TimeUnits.SECONDS
+ if max_y > 0.0001:
+ return TimeUnits.MILLISECONDS
+ return TimeUnits.MICROSECONDS
+
+ def plot_running_time(self, output_path=None):
+ xs, ys = self.measure_running_time()
+ units = self._derive_time_units(ys)
+ ys = [y * units.get_factor() for y in ys]
+
+ plot_builder = PlotBuilder()
+ title = self._format_plot_title()
+ xlabel = self._format_plot_xlabel()
+ ylabel = self._format_plot_ylabel(units)
+ plot_builder.plot(title, xlabel, ylabel, xs, ys)
+ if output_path is None:
+ plot_builder.show()
+ else:
+ plot_builder.save(output_path)