aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/cmd_line.hpp194
-rw-r--r--src/main.cpp2
-rw-r--r--src/process.hpp27
-rw-r--r--src/string.hpp154
4 files changed, 375 insertions, 2 deletions
diff --git a/src/cmd_line.hpp b/src/cmd_line.hpp
new file mode 100644
index 0000000..3a65ec7
--- /dev/null
+++ b/src/cmd_line.hpp
@@ -0,0 +1,194 @@
+#pragma once
+
+#include "error.hpp"
+#include "string.hpp"
+
+#include <Windows.h>
+#include <shellapi.h>
+
+#include <cstddef>
+
+#include <memory>
+#include <stdexcept>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+class CommandLine
+{
+public:
+ static CommandLine query()
+ {
+ return build_from_string(GetCommandLine());
+ }
+
+ static CommandLine build_from_main(int argc, wchar_t* argv[])
+ {
+ if (argc < 1)
+ throw std::range_error(__FUNCTION__ ": invalid argc value");
+
+ std::wstring argv0{argv[0]};
+ --argc;
+ ++argv;
+
+ std::vector<std::wstring> args;
+ args.reserve(argc);
+
+ for (int i = 0; i < argc; ++i)
+ args.emplace_back(argv[i]);
+
+ return {std::move(argv0), std::move(args)};
+ }
+
+ CommandLine() = default;
+
+ bool has_argv0() const
+ {
+ return !argv0.empty();
+ }
+
+ std::wstring get_argv0() const
+ {
+ return argv0;
+ }
+
+ std::wstring escape_argv0() const
+ {
+ return escape(get_argv0());
+ }
+
+ bool has_args() const
+ {
+ return !get_args().empty();
+ }
+
+ const std::vector<std::wstring>& get_args() const
+ {
+ return args;
+ }
+
+ std::vector<std::wstring> escape_args() const
+ {
+ std::vector<std::wstring> safe;
+ safe.reserve(args.size());
+ for (const auto& arg : args)
+ safe.emplace_back(escape(arg));
+ return safe;
+ }
+
+ static constexpr auto sep = L' ';
+
+ std::wstring join_args() const
+ {
+ return string::join(sep, escape_args());
+ }
+
+ std::wstring join() const
+ {
+ if (!has_argv0())
+ throw std::logic_error(__FUNCTION__ ": doesn't have executable path");
+ std::wostringstream oss;
+ oss << escape_argv0();
+ if (has_args())
+ oss << sep << string::join(sep, escape_args());
+ return oss.str();
+ }
+
+private:
+ static CommandLine build_from_string(std::wstring src)
+ {
+ string::trim(src);
+ if (src.empty())
+ return {};
+
+ int argc = 0;
+ std::unique_ptr<wchar_t*, LocalDelete> argv{CommandLineToArgvW(src.c_str(), &argc)};
+
+ if (argv.get() == NULL)
+ error::raise("CommandLineToArgvW");
+
+ if (argc == 0)
+ return {};
+
+ std::wstring argv0{argv.get()[0]};
+
+ std::vector<std::wstring> args;
+ args.reserve(argc - 1);
+
+ for (int i = 1; i < argc; ++i)
+ args.emplace_back(argv.get()[i]);
+
+ return {std::move(argv0), std::move(args)};
+ }
+
+ inline std::wstring escape_for_cmd(const std::wstring& arg)
+ {
+ static constexpr auto escape_symbol = L'^';
+ static constexpr auto dangerous_symbols = L"!\"%&()<>^|";
+
+ auto safe = escape(arg);
+ string::prefix_with(safe, dangerous_symbols, escape_symbol);
+ return safe;
+ }
+
+ static std::wstring escape(const std::wstring& arg)
+ {
+ std::wstring safe;
+ safe.reserve(arg.length() + 2);
+
+ safe.push_back(L'"');
+
+ for (auto it = arg.cbegin(); it != arg.cend(); ++it)
+ {
+ std::size_t numof_backslashes = 0;
+
+ for (; it != arg.cend() && *it == L'\\'; ++it)
+ ++numof_backslashes;
+
+ if (it == arg.cend())
+ {
+ safe.reserve(safe.capacity() + numof_backslashes);
+ safe.append(2 * numof_backslashes, L'\\');
+ break;
+ }
+
+ switch (*it)
+ {
+ case L'"':
+ safe.reserve(safe.capacity() + numof_backslashes + 1);
+ safe.append(2 * numof_backslashes + 1, L'\\');
+ break;
+
+ default:
+ safe.append(numof_backslashes, L'\\');
+ break;
+ }
+
+ safe.push_back(*it);
+ }
+
+ safe.push_back(L'"');
+ return safe;
+ }
+
+ struct LocalDelete
+ {
+ void operator()(wchar_t* argv[]) const
+ {
+ LocalFree(argv);
+ }
+ };
+
+ CommandLine(std::vector<std::wstring>&& args)
+ : args{std::move(args)}
+ { }
+
+ CommandLine(std::wstring&& argv0, std::vector<std::wstring>&& args = {})
+ : argv0{std::move(argv0)}
+ , args{std::move(args)}
+ { }
+
+ const std::wstring argv0;
+ const std::vector<std::wstring> args;
+};
diff --git a/src/main.cpp b/src/main.cpp
index 3ae2bcc..9d897b7 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -169,7 +169,7 @@ void on_button_elevate_click(HWND wnd)
try
{
- process::runas(process::get_executable_path(), wnd);
+ process::runas_self(wnd);
}
catch (const Error& e)
{
diff --git a/src/process.hpp b/src/process.hpp
index 8279c3c..6fbaf8a 100644
--- a/src/process.hpp
+++ b/src/process.hpp
@@ -5,6 +5,7 @@
#pragma once
+#include "cmd_line.hpp"
#include "error.hpp"
#include <Windows.h>
@@ -29,17 +30,41 @@ namespace process
return buf.data();
}
- void runas(const std::wstring& exe_path, HWND hwnd = NULL, int nShow = SW_NORMAL)
+ std::wstring get_command_line()
{
+ return GetCommandLine();
+ }
+
+ void runas(
+ const CommandLine& cmd_line,
+ HWND hwnd = NULL,
+ int nShow = SW_NORMAL)
+ {
+ static constexpr auto sep = L' ';
+
+ const auto exe_path = cmd_line.has_argv0()
+ ? cmd_line.get_argv0()
+ : get_executable_path();
+
SHELLEXECUTEINFOW info;
ZeroMemory(&info, sizeof(info));
info.cbSize = sizeof(info);
info.lpVerb = L"runas";
info.lpFile = exe_path.c_str();
+ const auto args = cmd_line.join_args();
+ if (!args.empty())
+ info.lpParameters = args.c_str();
info.hwnd = hwnd;
info.nShow = nShow;
if (!ShellExecuteExW(&info))
error::raise("ShellExecuteExW");
}
+
+ void runas_self(
+ HWND hwnd = NULL,
+ int nShow = SW_NORMAL)
+ {
+ runas(CommandLine::query(), hwnd, nShow);
+ }
}
diff --git a/src/string.hpp b/src/string.hpp
new file mode 100644
index 0000000..3621ca6
--- /dev/null
+++ b/src/string.hpp
@@ -0,0 +1,154 @@
+#pragma once
+
+#include <cctype>
+#include <cstddef>
+
+#include <algorithm>
+#include <locale>
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+namespace string
+{
+ template <typename Char>
+ inline void ltrim(std::basic_string<Char>& s)
+ {
+ s.erase(s.begin(), std::find_if(s.begin(), s.end(), [] (const Char& c)
+ {
+ return !std::isspace(c);
+ }));
+ }
+
+ template <typename Char>
+ inline void rtrim(std::basic_string<Char>& s)
+ {
+ s.erase(std::find_if(s.rbegin(), s.rend(), [] (const Char& c)
+ {
+ return !std::isspace(c);
+ }).base(), s.end());
+ }
+
+ template <typename Char>
+ inline void trim(std::basic_string<Char>& s)
+ {
+ ltrim(s);
+ rtrim(s);
+ }
+
+ template <typename Char, typename Sep, typename InputIterator>
+ inline std::basic_string<Char> join(
+ const Sep& sep,
+ InputIterator beg,
+ InputIterator end)
+ {
+ std::basic_ostringstream<Char> oss;
+
+ if (beg != end)
+ {
+ oss << *beg;
+ ++beg;
+ }
+
+ for (; beg != end; ++beg)
+ oss << sep << *beg;
+
+ return oss.str();
+ }
+
+ template <typename Char, typename Sep>
+ inline std::basic_string<Char> join(
+ const Sep& sep,
+ const std::vector<std::basic_string<Char>>& args)
+ {
+ return join<Char>(sep, args.cbegin(), args.cend());
+ }
+
+ template <typename Char, typename String, typename = void>
+ struct StringHelper
+ { };
+
+ template <typename Char, typename String>
+ struct StringHelper<Char, String, typename std::enable_if<std::is_same<typename std::decay<typename std::remove_pointer<String>::type>::type, Char>::value>::type>
+ {
+ inline StringHelper(const Char& c)
+ : buf{&c}
+ , len{1}
+ { }
+
+ inline StringHelper(const Char* s)
+ : buf{s}
+ , len{std::char_traits<Char>::length(s)}
+ { }
+
+ inline const Char* buffer() const { return buf; }
+
+ inline std::size_t length() const { return len; }
+
+ private:
+ const Char* const buf;
+ const std::size_t len;
+ };
+
+ template <typename Char, typename String>
+ struct StringHelper<Char, String, typename std::enable_if<std::is_same<String, std::basic_string<Char>>::value>::type>
+ {
+ inline StringHelper(const std::basic_string<Char>& s)
+ : s{s}
+ { }
+
+ inline const Char* buffer() const { return s.c_str(); }
+
+ inline std::size_t length() const { return s.length(); }
+
+ private:
+ const std::basic_string<Char>& s;
+ };
+
+ template <typename Char, typename What, typename By>
+ inline void replace(
+ std::basic_string<Char>& s,
+ const What& what,
+ const By& by)
+ {
+ std::size_t pos = 0;
+
+ const StringHelper<Char, typename std::decay<What>::type> what_helper{what};
+ const StringHelper<Char, typename std::decay<By>::type> by_helper{by};
+
+ const auto what_buf = what_helper.buffer();
+ const auto what_len = what_helper.length();
+
+ const auto by_buf = by_helper.buffer();
+ const auto by_len = by_helper.length();
+
+ while ((pos = s.find(what_buf, pos, what_len)) != std::basic_string<Char>::npos)
+ {
+ s.replace(pos, what_len, by_buf, by_len);
+ pos += by_len;
+ }
+ }
+
+ template <typename Char, typename What>
+ inline void prefix_with(
+ std::basic_string<Char>& s,
+ const What& what,
+ const Char& by)
+ {
+ const StringHelper<Char, typename std::decay<What>::type> what_helper{what};
+
+ const auto what_buf = what_helper.buffer();
+ const auto what_len = what_helper.length();
+
+ std::size_t numof_by = 0;
+
+ for (std::size_t i = 0; i < what_len; ++i)
+ numof_by += std::count(s.cbegin(), s.cend(), what_buf[i]);
+
+ s.reserve(s.capacity() + numof_by);
+
+ for (std::size_t i = 0; i < what_len; ++i)
+ replace(s, what_buf[i], by);
+ }
+}