diff options
author | Egor Tensin <Egor.Tensin@gmail.com> | 2016-02-14 00:32:19 +0300 |
---|---|---|
committer | Egor Tensin <Egor.Tensin@gmail.com> | 2016-02-14 00:32:19 +0300 |
commit | 734aeb3b184cef8c58f947bfa83dfd225ebaa065 (patch) | |
tree | 35fcfbe2a281abd1c47fe828a81391e132c7c80a /test/file | |
parent | test: refactoring (diff) | |
download | aes-tools-734aeb3b184cef8c58f947bfa83dfd225ebaa065.tar.gz aes-tools-734aeb3b184cef8c58f947bfa83dfd225ebaa065.zip |
test: refactoring
Diffstat (limited to '')
-rw-r--r-- | test/file.py | 138 |
1 files changed, 72 insertions, 66 deletions
diff --git a/test/file.py b/test/file.py index b64b828..13355a0 100644 --- a/test/file.py +++ b/test/file.py @@ -2,6 +2,7 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. +from contextlib import contextmanager from datetime import datetime from enum import Enum from glob import iglob as glob @@ -10,7 +11,7 @@ import logging import os import shutil import sys -from tempfile import TemporaryDirectory +from tempfile import NamedTemporaryFile from toolkit import * @@ -33,15 +34,15 @@ def _list_files(root_path, ext): def _list_keys(root_path): return _list_files(root_path, _KEY_EXT) -def _read_line(path): +def _read_first_line(path): with open(path) as f: return f.readline() def _read_key(key_path): - return _read_line(key_path) + return _read_first_line(key_path) def _read_iv(iv_path): - return _read_line(iv_path) + return _read_first_line(iv_path) def _extract_test_name(key_path): return os.path.splitext(os.path.basename(key_path))[0] @@ -52,66 +53,72 @@ def _replace_ext(path, new_ext): def _extract_iv_path(key_path): return _replace_ext(key_path, _IV_EXT) -def _extract_plain_path(key_path): +def _extract_plaintext_path(key_path): return _replace_ext(key_path, _PLAIN_EXT) -def _extract_cipher_path(key_path): +def _extract_ciphertext_path(key_path): return _replace_ext(key_path, _CIPHER_EXT) -class TestCase: - def __init__(self, algorithm, mode, key, plain_path, cipher_path, iv=None): - self.algorithm = algorithm - self.mode = mode - self.key = key - self.plain_path = plain_path - self.cipher_path = cipher_path - self.iv = iv - - def run_encryption_test(self, tools, tmp_dir, force=False): - tmp_dir = os.path.join(tmp_dir, str(self.algorithm), str(self.mode)) - os.makedirs(tmp_dir, 0o777, True) - - logging.info('Running encryption test...') - logging.info('\tPlaintext file path: ' + self.plain_path) - logging.info('\tExpected ciphertext file path: ' + self.cipher_path) - tmp_path = os.path.join(tmp_dir, os.path.basename(self.cipher_path)) +@contextmanager +def _make_output_file(): + with NamedTemporaryFile(delete=False) as tmp_file: + tmp_path = tmp_file.name + yield tmp_path + os.remove(tmp_path) + +def run_encryption_test(tools, algorithm, mode, key, plaintext_path, + ciphertext_path, iv=None, force=False): + logging.info('Running encryption test...') + logging.info('\tPlaintext file path: ' + plaintext_path) + logging.info('\tExpected ciphertext file path: ' + ciphertext_path) + logging.info('\tAlgorithm: ' + str(algorithm)) + logging.info('\tMode: ' + str(mode)) + + with _make_output_file() as tmp_path: logging.info('\tEncrypted file path: ' + tmp_path) - logging.info('\tAlgorithm: {}'.format(self.algorithm)) - logging.info('\tMode: {}'.format(self.mode)) - - tools.run_encrypt_file(self.algorithm, self.mode, self.key, - self.plain_path, tmp_path, self.iv) - if force: - logging.warn('Overwriting expected ciphertext file') - shutil.copy(tmp_path, self.cipher_path) - return TestExitCode.SKIPPED - if filecmp.cmp(self.cipher_path, tmp_path): - return TestExitCode.SUCCESS - else: - logging.error('The encrypted file doesn\'t match the ciphertext file') - return TestExitCode.FAILURE - - def run_decryption_test(self, tools, tmp_dir): - tmp_dir = os.path.join(tmp_dir, str(self.algorithm), str(self.mode)) - os.makedirs(tmp_dir, 0o777, True) - - logging.info('Running decryption test...') - logging.info('\tCiphertext file path: ' + self.cipher_path) - logging.info('\tExpected plaintext file path: ' + self.plain_path) - tmp_path = os.path.join(tmp_dir, os.path.basename(self.plain_path)) + + try: + tools.run_encrypt_file(algorithm, mode, key, plaintext_path, + tmp_path, iv) + if force: + logging.warn('Overwriting expected ciphertext file') + shutil.copy(tmp_path, ciphertext_path) + return TestExitCode.SKIPPED + if filecmp.cmp(ciphertext_path, tmp_path): + return TestExitCode.SUCCESS + else: + logging.error('The encrypted file doesn\'t match the ciphertext file') + return TestExitCode.FAILURE + except Exception as e: + logging.error('Encountered an exception!') + logging.exception(e) + return TestExitCode.ERROR + +def run_decryption_test(tools, algorithm, mode, key, plaintext_path, + ciphertext_path, iv=None): + logging.info('Running decryption test...') + logging.info('\tCiphertext file path: ' + ciphertext_path) + logging.info('\tExpected plaintext file path: ' + plaintext_path) + logging.info('\tAlgorithm: ' + str(algorithm)) + logging.info('\tMode: ' + str(mode)) + + with _make_output_file() as tmp_path: logging.info('\tDecrypted file path: ' + tmp_path) - logging.info('\tAlgorithm: {}'.format(self.algorithm)) - logging.info('\tMode: {}'.format(self.mode)) - - tools.run_decrypt_file(self.algorithm, self.mode, self.key, - self.cipher_path, tmp_path, self.iv) - if filecmp.cmp(tmp_path, self.plain_path): - return TestExitCode.SUCCESS - else: - logging.error('The decrypted file doesn\'t match the plaintext file') - return TestExitCode.FAILURE - -def list_test_cases(suite_dir): + + try: + tools.run_decrypt_file(algorithm, mode, key, ciphertext_path, + tmp_path, iv) + if filecmp.cmp(tmp_path, plaintext_path): + return TestExitCode.SUCCESS + else: + logging.error('The decrypted file doesn\'t match the plaintext file') + return TestExitCode.FAILURE + except Exception as e: + logging.error('Encountered an exception!') + logging.exception(e) + return TestExitCode.ERROR + +def enum_tests(suite_dir): suite_dir = os.path.abspath(suite_dir) logging.info('Suite directory path: ' + suite_dir) for algorithm_dir in _list_dirs(suite_dir): @@ -137,9 +144,9 @@ def list_test_cases(suite_dir): if mode.requires_init_vector(): iv_path = _extract_iv_path(key_path) iv = _read_iv(iv_path) - plain_path = _extract_plain_path(key_path) - cipher_path = _extract_cipher_path(key_path) - yield TestCase(algorithm, mode, key, plain_path, cipher_path, iv) + plaintext_path = _extract_plaintext_path(key_path) + ciphertext_path = _extract_ciphertext_path(key_path) + yield algorithm, mode, key, plaintext_path, ciphertext_path, iv def _build_default_log_path(): return datetime.now().strftime('{}_%Y-%m-%d_%H-%M-%S.log').format( @@ -165,12 +172,11 @@ if __name__ == '__main__': level=logging.DEBUG) tools = Tools(args.path, use_sde=args.sde) - exit_codes = [] - with TemporaryDirectory() as tmp_dir: - for test_case in list_test_cases(args.suite): - exit_codes.append(test_case.run_encryption_test(tools, tmp_dir, args.force)) - exit_codes.append(test_case.run_decryption_test(tools, tmp_dir)) + + for test in enum_tests(args.suite): + exit_codes.append(run_encryption_test(tools, *test, args.force)) + exit_codes.append(run_decryption_test(tools, *test)) logging.info('Test exit codes:') logging.info('\tSkipped: {}'.format(exit_codes.count(TestExitCode.SKIPPED))) |