diff options
-rw-r--r-- | response.cpp | 9 | ||||
-rw-r--r-- | tests/test-webserver.cpp | 200 | ||||
-rw-r--r-- | websocket.h | 38 |
3 files changed, 227 insertions, 20 deletions
diff --git a/response.cpp b/response.cpp index eeda8d0..ed30550 100644 --- a/response.cpp +++ b/response.cpp @@ -40,7 +40,7 @@ public: { } - // GetTarget() == GetPluginPath() + GetRelativePath() + // GetTarget() == GetPluginPath() + "/" + GetRelativePath() const Path& GetPath() const {return m_path;} // GetPluginPath w/ configured params as struct @@ -337,19 +337,14 @@ response_type response::generate_response(request_type& req, Server& server) std::string response::get_websocket_address(request_type& req, Server& server) { try { - std::cout << "DEBUG0" << std::endl; - std::cout << "DEBUG0: " << req.target() << std::endl; RequestContext req_ctx{req, server}; // can throw std::out_of_range - std::cout << "DEBUG1" << std::endl; if (req_ctx.GetPluginName() != "websocket") { std::cout << "Bad plugin configured for websocket request: " << req_ctx.GetPluginName() << std::endl; return {}; } - std::cout << "DEBUG2" << std::endl; - return req_ctx.GetDocRoot(); // Configured "path" in config: host:port for websocket - std::cout << "DEBUG3" << std::endl; + return req_ctx.GetDocRoot() + "/" + req_ctx.GetRelativePath(); // Configured "path" in config: host:port/relative_path for websocket } catch (const std::exception& ex) { std::cout << "No matching configured target websocket found: " << ex.what() << std::endl; diff --git a/tests/test-webserver.cpp b/tests/test-webserver.cpp index 10f6dca..602eb77 100644 --- a/tests/test-webserver.cpp +++ b/tests/test-webserver.cpp @@ -29,6 +29,8 @@ #include <exception> #include <filesystem> #include <iostream> +#include <memory> +#include <mutex> #include <sstream> #include <stdexcept> #include <string> @@ -38,11 +40,14 @@ #include <signal.h> #include <sys/wait.h> #include <unistd.h> +#include <sys/mman.h> +#include <sys/types.h> #include <libreichwein/file.h> #include <libreichwein/process.h> #include "webserver.h" +#include "response.h" using namespace std::string_literals; namespace fs = std::filesystem; @@ -436,9 +441,19 @@ BOOST_DATA_TEST_CASE_F(Fixture, http_get_file_not_found, data::make({false, true // Test server class WebsocketServerProcess { + // shared data between Unix processes + struct shared_data_t { + std::mutex mutex; + char subprotocol[1024]{}; + char target[1024]{}; + }; + public: WebsocketServerProcess() { + m_shared = std::unique_ptr<shared_data_t, std::function<void(shared_data_t*)>>( + (shared_data_t*)mmap(NULL, sizeof(shared_data_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0), + [this](shared_data_t*){munmap(m_shared.get(), sizeof(shared_data_t));}); start(); } @@ -447,7 +462,6 @@ public: stop(); } - // Echoes back all received WebSocket messages void do_session(boost::asio::ip::tcp::socket socket) { try @@ -463,18 +477,27 @@ public: std::string("Reichwein.IT Test Websocket Server")); })); - // Accept the websocket handshake - ws.accept(); + boost::beast::http::request_parser<boost::beast::http::string_body> parser; + request_type req; + boost::beast::flat_buffer buffer; + + boost::beast::http::read(ws.next_layer(), buffer, parser); + req = parser.get(); + { + std::lock_guard lock{m_shared->mutex}; + strncpy(m_shared->subprotocol, std::string{req[http::field::sec_websocket_protocol]}.data(), sizeof(m_shared->subprotocol)); + strncpy(m_shared->target, std::string{req.target()}.data(), sizeof(m_shared->target)); + } + + ws.accept(req); for(;;) { - // This buffer will hold the incoming message boost::beast::flat_buffer buffer; - // Read a message ws.read(buffer); - // Echo the message back + // Reply with <request>: <counter> ws.text(ws.got_text()); std::string data(boost::asio::buffers_begin(buffer.data()), boost::asio::buffers_end(buffer.data())); data += ": " + std::to_string(m_count++); @@ -562,9 +585,22 @@ public: m_pid = 0; } + std::string subprotocol() + { + std::lock_guard lock{m_shared->mutex}; + return m_shared->subprotocol; + } + + std::string target() + { + std::lock_guard lock{m_shared->mutex}; + return m_shared->target; + } + private: int m_pid{}; int m_count{}; + std::unique_ptr<shared_data_t, std::function<void(shared_data_t*)>> m_shared; }; // class WebsocketServerProcess BOOST_FIXTURE_TEST_CASE(websocket, Fixture) @@ -658,7 +694,8 @@ BOOST_FIXTURE_TEST_CASE(websocket, Fixture) // Update the host_ string. This will provide the value of the // Host HTTP header during the WebSocket handshake. // See https://tools.ietf.org/html/rfc7230#section-5.4 - host = "[" + host + "]"; + if (host == "::1") + host = "[" + host + "]"; host += ':' + std::to_string(ep.port()); // Perform the SSL handshake @@ -707,3 +744,152 @@ BOOST_FIXTURE_TEST_CASE(websocket, Fixture) BOOST_REQUIRE(websocketProcess.is_running()); } +BOOST_FIXTURE_TEST_CASE(websocket_subprotocol, Fixture) +{ + std::string webserver_config{R"CONFIG(<webserver> + <user>www-data</user> + <group>www-data</group> + <threads>10</threads> + <statisticspath>stats.db</statisticspath> + <plugin-directory>../plugins</plugin-directory> + <sites> + <site> + <name>localhost</name> + <host>ip6-localhost</host> + <host>localhost</host> + <host>127.0.0.1</host> + <host>[::1]</host> + <path requested="/"> + <plugin>websocket</plugin> + <target>::1:8765</target> + </path> + <certpath>testchain.pem</certpath> + <keypath>testkey.pem</keypath> + </site> + </sites> + <sockets> + <socket> + <address>127.0.0.1</address> + <port>8080</port> + <protocol>http</protocol> + <site>localhost</site> + </socket> + <socket> + <address>::1</address> + <port>8080</port> + <protocol>http</protocol> + <site>localhost</site> + </socket> + <socket> + <address>127.0.0.1</address> + <port>8081</port> + <protocol>https</protocol> + <site>localhost</site> + </socket> + <socket> + <address>::1</address> + <port>8081</port> + <protocol>https</protocol> + <site>localhost</site> + </socket> + </sockets> +</webserver> +)CONFIG"}; + WebserverProcess serverProcess{webserver_config}; + BOOST_REQUIRE(serverProcess.is_running()); + + WebsocketServerProcess websocketProcess; + BOOST_REQUIRE(websocketProcess.is_running()); + + std::string host = "::1"; + auto const port = "8081" ; + auto const text = "request1"; + + // The io_context is required for all I/O + boost::asio::io_context ioc; + + // The SSL context is required, and holds certificates + boost::asio::ssl::context ctx{boost::asio::ssl::context::tlsv13_client}; + + // This holds the root certificate used for verification + load_root_certificates(ctx); + + // These objects perform our I/O + boost::asio::ip::tcp::resolver resolver{ioc}; + boost::beast::websocket::stream<boost::beast::ssl_stream<boost::asio::ip::tcp::socket>> ws{ioc, ctx}; + + // Look up the domain name + auto const results = resolver.resolve(host, port); + + // Make the connection on the IP address we get from a lookup + auto ep = boost::asio::connect(get_lowest_layer(ws), results); + + // Set SNI Hostname (many hosts need this to handshake successfully) + if(! SSL_set_tlsext_host_name(ws.next_layer().native_handle(), host.c_str())) + throw boost::beast::system_error( + boost::beast::error_code( + static_cast<int>(::ERR_get_error()), + boost::asio::error::get_ssl_category()), + "Failed to set SNI Hostname"); + + // Update the host_ string. This will provide the value of the + // Host HTTP header during the WebSocket handshake. + // See https://tools.ietf.org/html/rfc7230#section-5.4 + if (host == "::1") + host = "[" + host + "]"; + host += ':' + std::to_string(ep.port()); + + // Perform the SSL handshake + ws.next_layer().handshake(boost::asio::ssl::stream_base::client); + + // Set a decorator to change the User-Agent of the handshake + ws.set_option(boost::beast::websocket::stream_base::decorator( + [](boost::beast::websocket::request_type& req) + { + req.set(boost::beast::http::field::user_agent, + std::string("Reichwein.IT Test Websocket Client")); + })); + + ws.set_option(boost::beast::websocket::stream_base::decorator( + [](boost::beast::websocket::request_type& req) + { + req.set(boost::beast::http::field::sec_websocket_protocol, "protocol1"); + })); + + // Perform the websocket handshake + ws.handshake(host, "/path1/target1"); + + // Send the message + ws.write(boost::asio::buffer(std::string(text))); + + // This buffer will hold the incoming message + boost::beast::flat_buffer buffer; + + // Read a message into our buffer + ws.read(buffer); + std::string data(boost::asio::buffers_begin(buffer.data()), boost::asio::buffers_end(buffer.data())); + BOOST_CHECK_EQUAL(data, "request1: 0"); + + buffer.consume(buffer.size()); + + ws.write(boost::asio::buffer(std::string(text))); + ws.read(buffer); + data = std::string(boost::asio::buffers_begin(buffer.data()), boost::asio::buffers_end(buffer.data())); + BOOST_CHECK_EQUAL(data, "request1: 1"); + + buffer.consume(buffer.size()); + + ws.write(boost::asio::buffer(std::string(text))); + ws.read(buffer); + data = std::string(boost::asio::buffers_begin(buffer.data()), boost::asio::buffers_end(buffer.data())); + BOOST_CHECK_EQUAL(data, "request1: 2"); + + // Close the WebSocket connection + ws.close(boost::beast::websocket::close_code::normal); + + BOOST_CHECK_EQUAL(websocketProcess.subprotocol(), "protocol1"); + BOOST_CHECK_EQUAL(websocketProcess.target(), "/path1/target1"); + BOOST_REQUIRE(websocketProcess.is_running()); + BOOST_REQUIRE(serverProcess.is_running()); +} + diff --git a/websocket.h b/websocket.h index 85492f2..951155e 100644 --- a/websocket.h +++ b/websocket.h @@ -47,6 +47,8 @@ class websocket_session: public std::enable_shared_from_this<websocket_session> boost::beast::flat_buffer buffer_out_; std::string host_; std::string port_; + std::string subprotocol_; + std::string relative_target_; public: explicit websocket_session(boost::asio::io_context& ioc, beast::ssl_stream<beast::tcp_stream>&& stream, const std::string& websocket_address): @@ -55,17 +57,32 @@ public: ws_in_(std::move(stream)), ws_app_(boost::asio::make_strand(ioc_)), host_{}, - port_{} + port_{}, + subprotocol_{}, + relative_target_{} { // Parse websocket address host:port : - auto pos{websocket_address.find_last_of(':')}; + auto colon_pos{websocket_address.find_last_of(':')}; - if (pos == std::string::npos) + if (colon_pos == std::string::npos) { + std::cerr << "Warning: Bad websocket address (colon missing): " << websocket_address << std::endl; return; + } + + auto slash_pos{websocket_address.find('/')}; + if (slash_pos == std::string::npos) { + std::cerr << "Warning: Bad websocket address (slash missing): " << websocket_address << std::endl; + return; + } + if (slash_pos <= colon_pos) { + std::cerr << "Warning: Bad websocket address: " << websocket_address << std::endl; + return; + } - host_ = websocket_address.substr(0, pos); - port_ = websocket_address.substr(pos + 1); + host_ = websocket_address.substr(0, colon_pos); + port_ = websocket_address.substr(colon_pos + 1, slash_pos - (colon_pos + 1)); + relative_target_ = websocket_address.substr(slash_pos); } // @@ -89,6 +106,9 @@ public: std::string{"Reichwein.IT Webserver"}); })); + // Forward subprotocol from request to target websocket + subprotocol_ = std::string{req[http::field::sec_websocket_protocol]}; + // Accept the websocket handshake ws_in_.async_accept( req, @@ -135,8 +155,14 @@ private: { req.set(boost::beast::http::field::user_agent, "Reichwein.IT Webserver Websocket client"); })); + + ws_app_.set_option(boost::beast::websocket::stream_base::decorator( + [this](boost::beast::websocket::request_type& req) + { + req.set(boost::beast::http::field::sec_websocket_protocol, subprotocol_); + })); - ws_app_.async_handshake(host_, "/", + ws_app_.async_handshake(host_, relative_target_, beast::bind_front_handler(&websocket_session::on_handshake_app, shared_from_this())); } |