From 94308ce35b636c6941a4a11985aec2189cb4b9e8 Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Sun, 14 May 2023 17:09:53 +0200 Subject: msg: enforce at least one word Also, move some stuff to net.c where it belongs. --- src/command.c | 4 +-- src/msg.c | 90 ++++++++++++----------------------------------------------- src/msg.h | 11 ++++---- src/net.c | 82 +++++++++++++++++++++++++++++++++++++++++++++++++++++ src/net.h | 3 ++ 5 files changed, 111 insertions(+), 79 deletions(-) diff --git a/src/command.c b/src/command.c index 8e24e7f..5661edc 100644 --- a/src/command.c +++ b/src/command.c @@ -113,12 +113,12 @@ int cmd_dispatcher_handle_msg(const struct cmd_dispatcher *dispatcher, int conn_ if (numof_words == 0) goto unknown; - const char **words = msg_get_words(request); + const char *actual_cmd = msg_get_first_word(request); for (size_t i = 0; i < dispatcher->numof_cmds; ++i) { struct cmd_desc *cmd = &dispatcher->cmds[i]; - if (strcmp(cmd->name, words[0])) + if (strcmp(cmd->name, actual_cmd)) continue; ret = cmd->handler(conn_fd, request, dispatcher->ctx, &response); diff --git a/src/msg.c b/src/msg.c index 583c5e5..4da0154 100644 --- a/src/msg.c +++ b/src/msg.c @@ -28,6 +28,11 @@ const char **msg_get_words(const struct msg *msg) return msg->argv; } +const char *msg_get_first_word(const struct msg *msg) +{ + return msg->argv[0]; +} + int msg_success(struct msg **msg) { static const char *argv[] = {"success", NULL}; @@ -127,6 +132,11 @@ int msg_from_argv(struct msg **_msg, const char **argv) for (const char **s = argv; *s; ++s) ++msg->argc; + if (!msg->argc) { + log_err("A message must contain at least one string\n"); + goto free; + } + ret = msg_copy_argv(msg, argv); if (ret < 0) goto free; @@ -140,78 +150,14 @@ free: return -1; } -static uint32_t calc_buf_size(const struct msg *msg) -{ - uint32_t len = 0; - for (size_t i = 0; i < msg->argc; ++i) - len += strlen(msg->argv[i]) + 1; - return len; -} - -static size_t calc_argv_len(const void *buf, size_t len) -{ - size_t argc = 0; - for (const char *it = buf; it < (const char *)buf + len; it += strlen(it) + 1) - ++argc; - return argc; -} - -static void argv_pack(char *dest, const struct msg *msg) -{ - for (size_t i = 0; i < msg->argc; ++i) { - strcpy(dest, msg->argv[i]); - dest += strlen(msg->argv[i]) + 1; - } -} - -static int argv_unpack(struct msg *msg, const char *src) -{ - size_t copied = 0; - - msg->argv = calloc(msg->argc + 1, sizeof(const char *)); - if (!msg->argv) { - log_errno("calloc"); - return -1; - } - - for (copied = 0; copied < msg->argc; ++copied) { - msg->argv[copied] = strdup(src); - if (!msg->argv[copied]) { - log_errno("strdup"); - goto free; - } - - src += strlen(msg->argv[copied]) + 1; - } - - return 0; - -free: - for (size_t i = 0; i < copied; ++i) { - free((char *)msg->argv[i]); - } - - msg_free(msg); - - return -1; -} - int msg_send(int fd, const struct msg *msg) { struct buf *buf = NULL; int ret = 0; - uint32_t size = calc_buf_size(msg); - char *data = malloc(size); - if (!data) { - log_errno("malloc"); - return -1; - } - argv_pack(data, msg); - - ret = buf_create(&buf, data, size); + ret = buf_pack_strings(&buf, msg->argc, msg->argv); if (ret < 0) - goto free_data; + return ret; ret = net_send_buf(fd, buf); if (ret < 0) @@ -220,9 +166,6 @@ int msg_send(int fd, const struct msg *msg) destroy_buf: buf_destroy(buf); -free_data: - free(data); - return ret; } @@ -257,12 +200,15 @@ int msg_recv(int fd, struct msg **_msg) goto destroy_buf; } - msg->argc = calc_argv_len(buf_get_data(buf), buf_get_size(buf)); - - ret = argv_unpack(msg, buf_get_data(buf)); + ret = buf_unpack_strings(buf, &msg->argc, &msg->argv); if (ret < 0) goto free_msg; + if (!msg->argc) { + log_err("A message must contain at least one string\n"); + goto free_msg; + } + *_msg = msg; goto destroy_buf; diff --git a/src/msg.h b/src/msg.h index e6541f4..158c5e6 100644 --- a/src/msg.h +++ b/src/msg.h @@ -12,8 +12,14 @@ struct msg; +int msg_from_argv(struct msg **, const char **argv); +void msg_free(struct msg *); + +int msg_copy(struct msg **, const struct msg *); + size_t msg_get_length(const struct msg *); const char **msg_get_words(const struct msg *); +const char *msg_get_first_word(const struct msg *); int msg_success(struct msg **); int msg_error(struct msg **); @@ -21,11 +27,6 @@ int msg_error(struct msg **); int msg_is_success(const struct msg *); int msg_is_error(const struct msg *); -int msg_copy(struct msg **, const struct msg *); -void msg_free(struct msg *); - -int msg_from_argv(struct msg **, const char **argv); - int msg_recv(int fd, struct msg **); int msg_send(int fd, const struct msg *); diff --git a/src/net.c b/src/net.c index 0a6ef85..c3a4b5e 100644 --- a/src/net.c +++ b/src/net.c @@ -247,6 +247,88 @@ void *buf_get_data(const struct buf *buf) return buf->data; } +static size_t count_strings(const void *_data, size_t size) +{ + const unsigned char *data = (const unsigned char *)_data; + const unsigned char *it = memchr(data, '\0', size); + + size_t numof_strings = 0; + while (it) { + it = memchr(it + 1, '\0', size - (it - data) - 1); + ++numof_strings; + } + + return numof_strings; +} + +int buf_pack_strings(struct buf **_buf, size_t argc, const char **argv) +{ + struct buf *buf = malloc(sizeof(struct buf)); + if (!buf) { + log_errno("malloc"); + return -1; + } + + buf->size = 0; + for (size_t i = 0; i < argc; ++i) + buf->size += strlen(argv[i]) + 1; + + buf->data = malloc(buf->size); + if (!buf->data) { + log_errno("malloc"); + goto free_buf; + } + + char *it = (char *)buf->data; + for (size_t i = 0; i < argc; ++i) { + it = stpcpy(it, argv[i]) + 1; + } + + *_buf = buf; + return 0; + +free_buf: + free(buf); + + return -1; +} + +int buf_unpack_strings(const struct buf *buf, size_t *_argc, const char ***_argv) +{ + size_t argc = count_strings(buf->data, buf->size); + size_t copied = 0; + + const char **argv = calloc(argc + 1, sizeof(const char *)); + if (!argv) { + log_errno("calloc"); + return -1; + } + + const char *it = (const char *)buf->data; + for (copied = 0; copied < argc; ++copied) { + argv[copied] = strdup(it); + if (!argv[copied]) { + log_errno("strdup"); + goto free; + } + + it += strlen(argv[copied]) + 1; + } + + *_argc = argc; + *_argv = argv; + return 0; + +free: + for (size_t i = 0; i < copied; ++i) { + free((char *)argv[i]); + } + + free(argv); + + return -1; +} + int net_send_buf(int fd, const struct buf *buf) { int ret = 0; diff --git a/src/net.h b/src/net.h index 84968cb..a8acb57 100644 --- a/src/net.h +++ b/src/net.h @@ -26,6 +26,9 @@ void buf_destroy(struct buf *); uint32_t buf_get_size(const struct buf *); void *buf_get_data(const struct buf *); +int buf_pack_strings(struct buf **, size_t argc, const char **argv); +int buf_unpack_strings(const struct buf *, size_t *argc, const char ***argv); + int net_send_buf(int fd, const struct buf *); int net_recv_buf(int fd, struct buf **); -- cgit v1.2.3