diff --git a/src/libmqttd/network/CMakeLists.txt b/src/libmqttd/network/CMakeLists.txt index 9e1fdb8..94a298a 100644 --- a/src/libmqttd/network/CMakeLists.txt +++ b/src/libmqttd/network/CMakeLists.txt @@ -9,5 +9,6 @@ target_sources(libmqttd add_subdirectory(packet_interface) add_subdirectory(connection) add_subdirectory(disconnection) +add_subdirectory(ping) target_include_directories(libmqttd PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/src/libmqttd/network/connection/connect_packet.hpp b/src/libmqttd/network/connection/connect_packet.hpp index 1cce679..c4b97d5 100644 --- a/src/libmqttd/network/connection/connect_packet.hpp +++ b/src/libmqttd/network/connection/connect_packet.hpp @@ -85,6 +85,8 @@ public: */ inline std::string get_client_id() const { return this->client_id.as_string(); }; + inline uint16_t get_keepalive() const { return this->keepalive; }; + private: /** * @brief Parses the variable header of the packet. diff --git a/src/libmqttd/network/packet_interface/property.hpp b/src/libmqttd/network/packet_interface/property.hpp index af1b951..b27f434 100644 --- a/src/libmqttd/network/packet_interface/property.hpp +++ b/src/libmqttd/network/packet_interface/property.hpp @@ -90,7 +90,7 @@ public: * * @return The size of MQTTProperties. */ - uint16_t size() const { return this->length; }; + uint16_t size() const { return static_cast(this->length); }; /** * @brief Gets the property value for the specified property identifier. diff --git a/src/libmqttd/network/ping/CMakeLists.txt b/src/libmqttd/network/ping/CMakeLists.txt new file mode 100644 index 0000000..fa873e9 --- /dev/null +++ b/src/libmqttd/network/ping/CMakeLists.txt @@ -0,0 +1,9 @@ +FILE(GLOB CPP_FILES CONFIGURE_DEPENDS *.cpp) +FILE(GLOB HPP_FILES CONFIGURE_DEPENDS *.hpp) + +target_sources(libmqttd + PRIVATE ${CPP_FILES} + PUBLIC ${HPP_FILES} +) + +target_include_directories(libmqttd PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/src/libmqttd/network/ping/ping_request_packet.cpp b/src/libmqttd/network/ping/ping_request_packet.cpp new file mode 100644 index 0000000..4d857ea --- /dev/null +++ b/src/libmqttd/network/ping/ping_request_packet.cpp @@ -0,0 +1,25 @@ +#include "ping_request_packet.hpp" +#include "fixed_header.hpp" +#include "packet_interface.hpp" +#include +#include + +PingRequestPacket::PingRequestPacket() : IPacket() { + IPacket::fixed_header.packet_type = PacketType::PINGREQ; + IPacket::fixed_header.remaining_length = 0; +}; + +PingRequestPacket::PingRequestPacket(IPacket &packet) : IPacket(packet) { + IPacket::fixed_header.packet_type = PacketType::PINGREQ; + IPacket::fixed_header.remaining_length = 0; +} + +PingRequestPacket::PingRequestPacket(const std::vector &data) : IPacket(data) { + IPacket::fixed_header.packet_type = PacketType::PINGREQ; + IPacket::fixed_header.remaining_length = 0; +} + +std::vector PingRequestPacket::as_bytes() { + std::vector fixed_header_bytes = IPacket::as_bytes(); + return fixed_header_bytes; +} diff --git a/src/libmqttd/network/ping/ping_request_packet.hpp b/src/libmqttd/network/ping/ping_request_packet.hpp new file mode 100644 index 0000000..9c26016 --- /dev/null +++ b/src/libmqttd/network/ping/ping_request_packet.hpp @@ -0,0 +1,27 @@ +#ifndef INCLUDE_PING_PING_REQUEST_PACKET_HPP_ +#define INCLUDE_PING_PING_REQUEST_PACKET_HPP_ + + +#include "packet_interface.hpp" +class PingRequestPacket : public IPacket { +public: + PingRequestPacket(); + + PingRequestPacket(IPacket &other); + + /** + * @brief Constructs a PingRequestPacket object from byte vector data. + * + * @param data The byte vector containing the packet data. + */ + PingRequestPacket(const std::vector &data); + + /** + * @brief Destructor. + */ + ~PingRequestPacket() = default; + + std::vector as_bytes(); +}; + +#endif // INCLUDE_PING_PING_REQUEST_PACKET_HPP_ diff --git a/src/libmqttd/network/ping/ping_response_packet.cpp b/src/libmqttd/network/ping/ping_response_packet.cpp new file mode 100644 index 0000000..649263d --- /dev/null +++ b/src/libmqttd/network/ping/ping_response_packet.cpp @@ -0,0 +1,25 @@ +#include "ping_response_packet.hpp" +#include "fixed_header.hpp" +#include "packet_interface.hpp" +#include +#include + +PingResponsePacket::PingResponsePacket() : IPacket() { + IPacket::fixed_header.packet_type = PacketType::PINGRESP; + IPacket::fixed_header.remaining_length = 0; +}; + +PingResponsePacket::PingResponsePacket(IPacket &packet) : IPacket(packet) { + IPacket::fixed_header.packet_type = PacketType::PINGRESP; + IPacket::fixed_header.remaining_length = 0; +} + +PingResponsePacket::PingResponsePacket(const std::vector &data) : IPacket(data) { + IPacket::fixed_header.packet_type = PacketType::PINGRESP; + IPacket::fixed_header.remaining_length = 0; +} + +std::vector PingResponsePacket::as_bytes() { + std::vector fixed_header_bytes = IPacket::as_bytes(); + return fixed_header_bytes; +} diff --git a/src/libmqttd/network/ping/ping_response_packet.hpp b/src/libmqttd/network/ping/ping_response_packet.hpp new file mode 100644 index 0000000..48dac9c --- /dev/null +++ b/src/libmqttd/network/ping/ping_response_packet.hpp @@ -0,0 +1,27 @@ +#ifndef INCLUDE_PING_PING_RESPONSE_PACKET_HPP_ +#define INCLUDE_PING_PING_RESPONSE_PACKET_HPP_ + +#include "packet_interface.hpp" +class PingResponsePacket : public IPacket { +public: + PingResponsePacket(); + + PingResponsePacket(IPacket &other); + + /** + * @brief Constructs a PingResponsePacket object from byte vector data. + * + * @param data The byte vector containing the packet data. + */ + PingResponsePacket(const std::vector &data); + + /** + * @brief Destructor. + */ + ~PingResponsePacket() = default; + + std::vector as_bytes(); +}; + + +#endif // INCLUDE_PING_PING_RESPONSE_PACKET_HPP_ diff --git a/src/libmqttd/protocol/session/session.cpp b/src/libmqttd/protocol/session/session.cpp index 569e909..7fd84fb 100644 --- a/src/libmqttd/protocol/session/session.cpp +++ b/src/libmqttd/protocol/session/session.cpp @@ -1,9 +1,11 @@ #include "disconnection/disconnect_packet.hpp" #include "packet_interface.hpp" #include "state_disconnect.hpp" +#include "state_ping_request.hpp" #include "state_waiting_connection.hpp" #include #include +#include #include #include @@ -20,12 +22,16 @@ Session::Session(int socket_fd) { this->socket = socket_fd; this->current_state = nullptr; this->current_packet = nullptr; + this->keepalive_sec = UINT16_MAX; } Session::~Session() { // TODO: It's a normal disconnection here? if (this->is_alive() || this->is_connected()) this->close(DisconnectReasonCode::NORMAL_DISCONNECTION); + + if (this->last_seen_thread.joinable()) + this->last_seen_thread.join(); } void Session::close(const DisconnectReasonCode &reason_code) { @@ -42,8 +48,6 @@ void Session::close() { this->on_disconnect(this); this->is_session_alive = false; - if (this->keepalive_thread.joinable()) - this->keepalive_thread.join(); } std::size_t Session::send(const std::vector &buffer) { @@ -74,6 +78,28 @@ void Session::close_if_not_connected(uint timeout_sec) { this->close(DisconnectReasonCode::MAXIMUM_CONNECT_TIMEOUT); } +void Session::process_last_seen() { + while (is_connected()) { + std::time_t current_timestamp = std::time(0); + if (last_seen > current_timestamp) { + spdlog::error("Session last communication was in the future?! Assuming it was now"); + last_seen = current_timestamp; + } + + std::time_t delta_time = current_timestamp - last_seen; + float max_keepalive_wait = static_cast(keepalive_sec) * 1.5; + if (delta_time > max_keepalive_wait) { + std::stringstream msg; + msg << "Session " << client_id << " timedout on keepalive/ping, finishing it now"; + spdlog::error(msg.str()); + this->close(); + } + + auto sleep_sec = static_cast(max_keepalive_wait); + std::this_thread::sleep_for(std::chrono::seconds(sleep_sec)); + } +} + void Session::listen() { std::vector buffer(buffer_size); @@ -81,6 +107,8 @@ void Session::listen() { this->is_session_alive = true; this->is_session_connected = false; + last_seen_thread = std::thread(&Session::process_last_seen, this); + while (this->is_alive()) { buffer.clear(); buffer.resize(buffer_size); @@ -106,10 +134,23 @@ void Session::listen() { switch (this->current_packet->get_packet_type()) { case PacketType::CONNECT: { this->set_state(StateConnect::get_instance()); + last_seen = std::time(nullptr); break; } case PacketType::DISCONNECT: { this->set_state(StateDisconnect::get_instance()); + last_seen = std::time(nullptr); + break; + } + case PacketType::PINGREQ: { + this->set_state(StatePingRequest::get_instance()); + last_seen = std::time(nullptr); + break; + } + case PacketType::PINGRESP: { + std::stringstream msg; + msg << "Server received a PINGRESP from session " << client_id << " it should never happen"; + spdlog::error(msg.str()); break; } default: { diff --git a/src/libmqttd/protocol/session/session.hpp b/src/libmqttd/protocol/session/session.hpp index cc9f218..d510336 100644 --- a/src/libmqttd/protocol/session/session.hpp +++ b/src/libmqttd/protocol/session/session.hpp @@ -116,9 +116,13 @@ private: IPacket *current_packet; /**< Pointer to the current packet being processed. */ int socket; /**< Socket file descriptor for the session. */ - std::thread keepalive_thread; /**< Thread for handling keepalive messages. */ const unsigned int buffer_size = 65535; /**< Max TCP packet bytes accepted in a single receive call. */ + uint16_t keepalive_sec; /**< Keepalive in seconds to wait for messages in session */ + std::atomic last_seen; /**< Last UNIX timestamp from when this session has communicated.*/ + std::thread last_seen_thread; /**< Thread for handling unresponsive sessions. */ + void process_last_seen(); /**< Process last seen information, finishing the session */ + /** * @brief Joins the keepalive thread. */ diff --git a/src/libmqttd/protocol/session/states/state_connect.cpp b/src/libmqttd/protocol/session/states/state_connect.cpp index 028f4b4..f8fd26f 100644 --- a/src/libmqttd/protocol/session/states/state_connect.cpp +++ b/src/libmqttd/protocol/session/states/state_connect.cpp @@ -38,9 +38,11 @@ void StateConnect::process(Session *session) { IPacket *packet_interface = session->get_current_packet(); ConnectPacket packet(*packet_interface); + // TODO: Validate and auth packet // TODO: Generate a client id if none as provided session->client_id = packet.get_client_id(); + session->keepalive_sec = packet.get_keepalive(); session->on_connect(session); ack_packet.set_reason_code(ConnectReasonCode::SUCCESS); diff --git a/src/libmqttd/protocol/session/states/state_ping_request.cpp b/src/libmqttd/protocol/session/states/state_ping_request.cpp new file mode 100644 index 0000000..1675169 --- /dev/null +++ b/src/libmqttd/protocol/session/states/state_ping_request.cpp @@ -0,0 +1,26 @@ +#include +#include +#include + +ISessionState &StatePingRequest::get_instance() { + static StatePingRequest singleton; + return singleton; +} + +void StatePingRequest::enter(Session *session) { + std::ostringstream log_msg; + log_msg << "Session entered PING REQUEST state"; + spdlog::trace(log_msg.str()); + + session->set_state(StatePingResponse::get_instance()); +} + +void StatePingRequest::exit(Session *session) { + std::ostringstream log_msg; + log_msg << "Session exited PING REQUEST state"; + spdlog::trace(log_msg.str()); +} + +void StatePingRequest::process(Session *session) { + return; +} diff --git a/src/libmqttd/protocol/session/states/state_ping_request.hpp b/src/libmqttd/protocol/session/states/state_ping_request.hpp new file mode 100644 index 0000000..83de699 --- /dev/null +++ b/src/libmqttd/protocol/session/states/state_ping_request.hpp @@ -0,0 +1,21 @@ +#ifndef INCLUDE_STATES_STATE_PING_REQUEST_HPP_ +#define INCLUDE_STATES_STATE_PING_REQUEST_HPP_ + +#include "state_interface.hpp" + +class StatePingRequest : public ISessionState { +public: + void enter(Session *session) final; + void process(Session *session) final; + void exit(Session *session) final; + + static ISessionState &get_instance(); + +private: + StatePingRequest() {}; + + StatePingRequest(const StatePingRequest &); /**< Copy constructor. */ + StatePingRequest &operator=(const StatePingRequest &); /**< Assignment operator. */ +}; + +#endif // INCLUDE_STATES_STATE_PING_REQUEST_HPP_ diff --git a/src/libmqttd/protocol/session/states/state_ping_response.cpp b/src/libmqttd/protocol/session/states/state_ping_response.cpp new file mode 100644 index 0000000..b9295cc --- /dev/null +++ b/src/libmqttd/protocol/session/states/state_ping_response.cpp @@ -0,0 +1,25 @@ +#include +#include +#include "ping_response_packet.hpp" + +ISessionState &StatePingResponse::get_instance() { + static StatePingResponse singleton; + return singleton; +} + +void StatePingResponse::enter(Session *session) { + std::ostringstream log_msg; + log_msg << "Session entered PING RESPONSE state"; + spdlog::trace(log_msg.str()); +} + +void StatePingResponse::exit(Session *session) { + std::ostringstream log_msg; + log_msg << "Session exited PING RESPONSE state"; + spdlog::trace(log_msg.str()); +} + +void StatePingResponse::process(Session *session) { + PingResponsePacket response_packet; + session->send(response_packet.as_bytes()); +} diff --git a/src/libmqttd/protocol/session/states/state_ping_response.hpp b/src/libmqttd/protocol/session/states/state_ping_response.hpp new file mode 100644 index 0000000..e449467 --- /dev/null +++ b/src/libmqttd/protocol/session/states/state_ping_response.hpp @@ -0,0 +1,21 @@ +#ifndef INCLUDE_STATES_STATE_PING_RESPONSE_HPP_ +#define INCLUDE_STATES_STATE_PING_RESPONSE_HPP_ + +#include "state_interface.hpp" + +class StatePingResponse : public ISessionState { +public: + void enter(Session *session) final; + void process(Session *session) final; + void exit(Session *session) final; + + static ISessionState &get_instance(); + +private: + StatePingResponse() {}; + + StatePingResponse(const StatePingResponse &); /**< Copy constructor. */ + StatePingResponse &operator=(const StatePingResponse &); /**< Assignment operator. */ +}; + +#endif // INCLUDE_STATES_STATE_PING_RESPONSE_HPP_ diff --git a/src/version.hpp b/src/version.hpp index 87b1294..cf6b727 100644 --- a/src/version.hpp +++ b/src/version.hpp @@ -1,8 +1,8 @@ #define MQTTD_VERSION_MAJOR 0 #define MQTTD_VERSION_MINOR 0 #define MQTTD_VERSION_PATCH 1 -#define MQTTD_COMMIT_HASH "1a2e4ea6eb773b5db42cc3d37cc84c5b93248fb9" -#define MQTTD_BUILD_TIMESTAMP 1725026213 +#define MQTTD_COMMIT_HASH "f8b32e6edc0e7458239a7d103566683c92b76541" +#define MQTTD_BUILD_TIMESTAMP 1725283801 #define STRINGIFY(x) #x #define TOSTRING(x) STRINGIFY(x)