aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/generate-sql-header.py
blob: eb5e9fef1bf1c005f20da2511dfd5aecca23962e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python3

# Copyright (c) 2023 Egor Tensin <egor@tensin.name>
# This file is part of the "cimple" project.
# For details, see https://github.com/egor-tensin/cimple.
# Distributed under the MIT License.

import argparse
from contextlib import contextmanager
from glob import glob
import os
import sys


class Generator:
    def __init__(self, fd, dir):
        self.fd = fd
        if not os.path.isdir(dir):
            raise RuntimeError('must be a directory: ' + dir)
        self.dir = os.path.abspath(dir)
        self.name = os.path.basename(self.dir)

    def write(self, line):
        self.fd.write(f'{line}\n')

    def do(self):
        self.include_guard_start()
        self.include_sql_files()
        self.include_guard_end()

    def include_guard_start(self):
        self.write(f'#ifndef __{self.name.upper()}_SQL_H__')
        self.write(f'#define __{self.name.upper()}_SQL_H__')
        self.write('')

    def include_guard_end(self):
        self.write('')
        self.write(f'#endif')

    def enum_sql_files(self):
        return [os.path.join(self.dir, path) for path in sorted(glob('*.sql', root_dir=self.dir))]

    @property
    def var_name_prefix(self):
        return f'{self.name}'

    def sql_file_to_var_name(self, path):
        name = os.path.splitext(os.path.basename(path))[0]
        return f'{self.var_name_prefix}_schema_{name}'

    @staticmethod
    def sql_file_to_string_literal(path):
        with open(path) as fd:
            sql = fd.read()
        sql = sql.encode().hex().upper()
        sql = ''.join((f'\\x{sql[i:i + 2]}' for i in range(0, len(sql), 2)))
        return sql

    def include_sql_files(self):
        vars = []
        for path in self.enum_sql_files():
            name = self.sql_file_to_var_name(path)
            vars.append(name)
            value = self.sql_file_to_string_literal(path)
            self.write(f'static const char *const {name} = "{value}";')
        self.write('')
        self.write(f'static const char *const {self.var_name_prefix}_schemas[] = {{')
        for var in vars:
            self.write(f'\t{var},')
        self.write('};')


@contextmanager
def open_output(path):
    """Yield a writable text stream for the generated header.

    When ``path`` is ``None`` the stream is ``sys.stdout`` and is left open.
    Otherwise the file at ``path`` is opened for writing (its parent
    directories are created first if needed) and closed when the ``with``
    block ends.
    """
    if path is None:
        yield sys.stdout
        return

    target = os.path.abspath(path)
    parent = os.path.dirname(target)
    os.makedirs(parent, exist_ok=True)
    with open(target, 'w') as stream:
        yield stream


def parse_args(argv=None):
    """Parse the command line of the header generator.

    Args:
        argv: list of argument strings without the program name; when
            ``None`` the arguments are taken from ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``output`` (path or ``None``) and
        ``dir`` (the input directory).
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output', metavar='PATH',
                        help='set output file path')
    parser.add_argument('dir', metavar='INPUT_DIR',
                        help='input directory')
    # Pass argv through explicitly: calling parse_args() with no argument
    # always reads sys.argv and silently ignores what the caller supplied.
    return parser.parse_args(argv)


def main(argv=None):
    """Script entry point: parse ``argv`` and write the generated header.

    The output destination is opened before the generator is constructed,
    and the header goes to the ``--output`` path or to standard output.
    """
    options = parse_args(argv)
    with open_output(options.output) as stream:
        Generator(stream, options.dir).do()


# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()