From 2357f1382ad5eaf100d9a982e48052a790435665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C5=99emysl=20Janouch?= Date: Fri, 3 Jul 2015 20:32:31 +0200 Subject: [PATCH] degesch: rewrite to use asynchronous I/O --- degesch.c | 730 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 467 insertions(+), 263 deletions(-) diff --git a/degesch.c b/degesch.c index 4934d7d..aee0343 100644 --- a/degesch.c +++ b/degesch.c @@ -1049,12 +1049,42 @@ buffer_destroy (struct buffer *self) // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +enum transport_io_result +{ + TRANSPORT_IO_OK = 0, ///< Completed successfully + TRANSPORT_IO_EOF, ///< Connection shut down by peer + TRANSPORT_IO_ERROR ///< Connection error +}; + +// The only real purpose of this is to abstract away TLS/SSL +struct transport +{ + /// Initialize the transport + bool (*init) (struct server *s, struct error **e); + /// Destroy the user data pointer + void (*cleanup) (struct server *s); + + /// The underlying socket may have become readable, update `read_buffer' + enum transport_io_result (*on_readable) (struct server *s); + /// The underlying socket may have become writeable, flush `write_buffer' + enum transport_io_result (*on_writeable) (struct server *s); + /// Return event mask to use in the poller + int (*get_poll_events) (struct server *s); + + /// Called just before closing the connection from our side + void (*in_before_shutdown) (struct server *s); +}; + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + enum server_state { IRC_DISCONNECTED, ///< Not connected IRC_CONNECTING, ///< Connecting to the server IRC_CONNECTED, ///< Trying to register - IRC_REGISTERED ///< We can chat now + IRC_REGISTERED, ///< We can chat now + IRC_CLOSING, ///< Flushing output before shutdown + IRC_HALF_CLOSED ///< Connection shutdown from our side }; /// Convert an IRC identifier character to lower-case @@ -1079,10 +1109,11 @@ struct server int socket; ///< Socket FD of the server struct str read_buffer; ///< Input yet to be processed - struct poller_fd read_event; ///< We can read from the socket + struct str write_buffer; ///< Outut yet to be be sent out + struct poller_fd socket_event; ///< We can read from the socket - SSL_CTX *ssl_ctx; ///< SSL context - SSL *ssl; ///< SSL connection + struct transport *transport; ///< Transport method + void *transport_data; ///< Transport data // Events: @@ -1177,6 +1208,7 @@ server_init (struct server *self, struct poller *poller) self->socket = -1; str_init (&self->read_buffer); + str_init (&self->write_buffer); self->state = IRC_DISCONNECTED; poller_timer_init (&self->timeout_tmr, poller); @@ -1214,17 +1246,19 @@ server_free (struct server *self) connector_free (self->connector); free (self->connector); } + + if (self->transport + && self->transport->cleanup) + self->transport->cleanup (self); + if (self->socket != -1) { xclose (self->socket); - poller_fd_reset (&self->read_event); + self->socket_event.closed = true; + poller_fd_reset (&self->socket_event); } str_free (&self->read_buffer); - - if (self->ssl) - SSL_free (self->ssl); - if (self->ssl_ctx) - SSL_CTX_free (self->ssl_ctx); + str_free (&self->write_buffer); str_map_free (&self->irc_users); str_map_free (&self->irc_channels); @@ -3080,16 +3114,22 @@ irc_set_casemapping (struct server *s, // --- Core functionality ------------------------------------------------------ -// Most of the core IRC code comes from ZyklonB which is mostly blocking. -// While it's fairly easy to follow, it also stinks. It needs to be rewritten -// to be as asynchronous as possible. See kike.c for reference. - static bool irc_is_connected (struct server *s) { return s->state != IRC_DISCONNECTED && s->state != IRC_CONNECTING; } +static void +irc_update_poller (struct server *s, const struct pollfd *pfd) +{ + int new_events = s->transport->get_poll_events (s); + hard_assert (new_events != 0); + + if (!pfd || pfd->events != new_events) + poller_fd_set (&s->socket_event, new_events); +} + static void irc_cancel_timers (struct server *s) { @@ -3125,115 +3165,6 @@ irc_queue_reconnect (struct server *s) // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -static bool -irc_initialize_ssl_ctx (struct server *s, struct error **e) -{ - // XXX: maybe we should call SSL_CTX_set_options() for some workarounds - - bool verify = get_config_boolean (s->config, "ssl_verify"); - if (!verify) - SSL_CTX_set_verify (s->ssl_ctx, SSL_VERIFY_NONE, NULL); - - const char *ca_file = get_config_string (s->config, "ssl_ca_file"); - const char *ca_path = get_config_string (s->config, "ssl_ca_path"); - - struct error *error = NULL; - if (ca_file || ca_path) - { - if (SSL_CTX_load_verify_locations (s->ssl_ctx, ca_file, ca_path)) - return true; - - error_set (&error, "%s: %s", - "Failed to set locations for the CA certificate bundle", - ERR_reason_error_string (ERR_get_error ())); - goto ca_error; - } - - if (!SSL_CTX_set_default_verify_paths (s->ssl_ctx)) - { - error_set (&error, "%s: %s", - "Couldn't load the default CA certificate bundle", - ERR_reason_error_string (ERR_get_error ())); - goto ca_error; - } - return true; - -ca_error: - if (verify) - { - error_propagate (e, error); - return false; - } - - // Only inform the user if we're not actually verifying - log_server_error (s, s->buffer, "#s", error->message); - error_free (error); - return true; -} - -static bool -irc_initialize_ssl (struct server *s, struct error **e) -{ - const char *error_info = NULL; - s->ssl_ctx = SSL_CTX_new (SSLv23_client_method ()); - if (!s->ssl_ctx) - goto error_ssl_1; - if (!irc_initialize_ssl_ctx (s, e)) - goto error_ssl_2; - - s->ssl = SSL_new (s->ssl_ctx); - if (!s->ssl) - goto error_ssl_2; - - const char *ssl_cert = get_config_string (s->config, "ssl_cert"); - if (ssl_cert) - { - char *path = resolve_config_filename (ssl_cert); - if (!path) - log_server_error (s, s->buffer, - "#s: #s", "Cannot open file", ssl_cert); - // XXX: perhaps we should read the file ourselves for better messages - else if (!SSL_use_certificate_file (s->ssl, path, SSL_FILETYPE_PEM) - || !SSL_use_PrivateKey_file (s->ssl, path, SSL_FILETYPE_PEM)) - log_server_error (s, s->buffer, - "#s: #s", "Setting the SSL client certificate failed", - ERR_error_string (ERR_get_error (), NULL)); - free (path); - } - - SSL_set_connect_state (s->ssl); - if (!SSL_set_fd (s->ssl, s->socket)) - goto error_ssl_3; - // Avoid SSL_write() returning SSL_ERROR_WANT_READ - SSL_set_mode (s->ssl, SSL_MODE_AUTO_RETRY); - - switch (xssl_get_error (s->ssl, SSL_connect (s->ssl), &error_info)) - { - case SSL_ERROR_NONE: - return true; - case SSL_ERROR_ZERO_RETURN: - error_info = "server closed the connection"; - default: - break; - } - -error_ssl_3: - SSL_free (s->ssl); - s->ssl = NULL; -error_ssl_2: - SSL_CTX_free (s->ssl_ctx); - s->ssl_ctx = NULL; -error_ssl_1: - // XXX: these error strings are really nasty; also there could be - // multiple errors on the OpenSSL stack. - if (!error_info) - error_info = ERR_error_string (ERR_get_error (), NULL); - error_set (e, "%s: %s", "could not initialize SSL", error_info); - return false; -} - -// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // As of 2015, everything should be in UTF-8. And if it's not, we'll decode it // as ISO Latin 1. This function should not be called on the whole message. static char * @@ -3272,6 +3203,8 @@ irc_send (struct server *s, const char *format, ...) print_debug ("tried sending a message to a dead server connection"); return; } + if (s->state == IRC_CLOSING) + return; va_list ap; va_start (ap, format); @@ -3293,33 +3226,45 @@ irc_send (struct server *s, const char *format, ...) input_show (&s->ctx->input); } - str_append (&str, "\r\n"); - - if (s->ssl) - { - // TODO: call SSL_get_error() to detect if a clean shutdown has occured - if (SSL_write (s->ssl, str.str, str.len) != (int) str.len) - LOG_FUNC_FAILURE ("SSL_write", - ERR_error_string (ERR_get_error (), NULL)); - } - else if (write (s->socket, str.str, str.len) != (ssize_t) str.len) - LOG_LIBC_FAILURE ("write"); + str_append_str (&s->write_buffer, &str); str_free (&str); + str_append (&s->write_buffer, "\r\n"); + irc_update_poller (s, NULL); } // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +static void +irc_real_shutdown (struct server *s) +{ + hard_assert (irc_is_connected (s) && s->state != IRC_HALF_CLOSED); + + if (s->transport + && s->transport->in_before_shutdown) + s->transport->in_before_shutdown (s); + + while (shutdown (s->socket, SHUT_WR) == -1) + if (!soft_assert (errno == EINTR)) + break; + + s->state = IRC_HALF_CLOSED; +} + static void irc_shutdown (struct server *s) { - // Generally non-critical - if (s->ssl) - soft_assert (SSL_shutdown (s->ssl) != -1); - else - soft_assert (shutdown (s->socket, SHUT_WR) == 0); + if (s->state == IRC_CLOSING + || s->state == IRC_HALF_CLOSED) + return; - // TODO: set a timer after which we cut the connection + // TODO: set a timer to cut the connection if we don't receive an EOF + s->state = IRC_CLOSING; + + // Either there's still some data in the write buffer and we wait + // until they're sent, or we send an EOF to the server right away + if (!s->write_buffer.len) + irc_real_shutdown (s); } static void @@ -3372,7 +3317,6 @@ initiate_quit (struct app_context *ctx) if (irc_is_connected (s)) { - // XXX: when we go async, we'll have to flush output buffers first irc_shutdown (s); s->manual_disconnect = true; } @@ -3390,20 +3334,16 @@ on_irc_disconnected (struct server *s) hard_assert (irc_is_connected (s)); // Get rid of the dead socket - if (s->ssl) - { - SSL_free (s->ssl); - s->ssl = NULL; - SSL_CTX_free (s->ssl_ctx); - s->ssl_ctx = NULL; - } + if (s->transport + && s->transport->cleanup) + s->transport->cleanup (s); xclose (s->socket); s->socket = -1; s->state = IRC_DISCONNECTED; - s->read_event.closed = true; - poller_fd_reset (&s->read_event); + s->socket_event.closed = true; + poller_fd_reset (&s->socket_event); // All of our timers have lost their meaning now irc_cancel_timers (s); @@ -3474,131 +3414,392 @@ on_irc_timeout (void *user_data) irc_send (s, "PING :%" PRIi64, (int64_t) time (NULL)); } -// --- Processing server output ------------------------------------------------ +// --- Server I/O -------------------------------------------------------------- static void irc_process_message (const struct irc_message *msg, const char *raw, void *user_data); -enum irc_read_result -{ - IRC_READ_OK, ///< Some data were read successfully - IRC_READ_EOF, ///< The server has closed connection - IRC_READ_AGAIN, ///< No more data at the moment - IRC_READ_ERROR ///< General connection failure -}; - -static enum irc_read_result -irc_fill_read_buffer_ssl (struct server *s, struct str *buf) -{ - int n_read; -start: - n_read = SSL_read (s->ssl, buf->str + buf->len, - buf->alloc - buf->len - 1 /* null byte */); - - const char *error_info = NULL; - switch (xssl_get_error (s->ssl, n_read, &error_info)) - { - case SSL_ERROR_NONE: - buf->str[buf->len += n_read] = '\0'; - return IRC_READ_OK; - case SSL_ERROR_ZERO_RETURN: - return IRC_READ_EOF; - case SSL_ERROR_WANT_READ: - return IRC_READ_AGAIN; - case SSL_ERROR_WANT_WRITE: - { - // Let it finish the handshake as we don't poll for writability; - // any errors are to be collected by SSL_read() in the next iteration - struct pollfd pfd = { .fd = s->socket, .events = POLLOUT }; - soft_assert (poll (&pfd, 1, 0) > 0); - goto start; - } - case XSSL_ERROR_TRY_AGAIN: - goto start; - default: - LOG_FUNC_FAILURE ("SSL_read", error_info); - return IRC_READ_ERROR; - } -} - -static enum irc_read_result -irc_fill_read_buffer (struct server *s, struct str *buf) -{ - ssize_t n_read; -start: - n_read = recv (s->socket, buf->str + buf->len, - buf->alloc - buf->len - 1 /* null byte */, 0); - - if (n_read > 0) - { - buf->str[buf->len += n_read] = '\0'; - return IRC_READ_OK; - } - if (n_read == 0) - return IRC_READ_EOF; - - if (errno == EAGAIN) - return IRC_READ_AGAIN; - if (errno == EINTR) - goto start; - - LOG_LIBC_FAILURE ("recv"); - return IRC_READ_ERROR; -} - static void -on_irc_readable (const struct pollfd *fd, struct server *s) +on_irc_ready (const struct pollfd *pfd, struct server *s) { - if (fd->revents & ~(POLLIN | POLLHUP | POLLERR)) - print_debug ("fd %d: unexpected revents: %d", fd->fd, fd->revents); + struct transport *transport = s->transport; + enum transport_io_result result; - (void) set_blocking (s->socket, false); + if ((result = transport->on_readable (s)) == TRANSPORT_IO_ERROR) + goto error; + bool read_eof = result == TRANSPORT_IO_EOF; + if (s->read_buffer.len >= (1 << 20)) + { + // XXX: this is stupid; if anything, count it in dependence of time + log_server_error (s, s->buffer, + "The IRC server seems to spew out data frantically"); + goto disconnect; + } + if (s->read_buffer.len) + irc_process_buffer (&s->read_buffer, irc_process_message, s); + + if ((result = transport->on_writeable (s)) == TRANSPORT_IO_ERROR) + goto error; + bool write_eof = result == TRANSPORT_IO_EOF; + + // FIXME: this may probably fire multiple times if we're flushing after it, + // we should probably store this information next to the state + if (read_eof || write_eof) + log_server_error (s, s->buffer, "The IRC server closed the connection"); + + // It makes no sense to flush anything if the write needs to read + // and we receive an EOF -> disconnect right away + if (write_eof) + goto disconnect; + + // If we've been asked to flush the write buffer and our job is complete, + // we send an EOF to the server, changing the state to IRC_HALF_CLOSED + if (s->state == IRC_CLOSING && !s->write_buffer.len) + irc_real_shutdown (s); + + if (read_eof) + { + // Both ends closed, we're done + if (s->state == IRC_HALF_CLOSED) + goto disconnect; + + // Otherwise we want to flush the write buffer + irc_shutdown (s); + + // If that went well, we can disconnect now + if (s->state == IRC_HALF_CLOSED) + goto disconnect; + } + + // XXX: shouldn't we rather wait for PONG messages? + irc_reset_connection_timeouts (s); + irc_update_poller (s, pfd); + return; + +error: + log_server_error (s, s->buffer, "Reading from the IRC server failed"); +disconnect: + on_irc_disconnected (s); +} + +// --- Plain transport --------------------------------------------------------- + +static enum transport_io_result +transport_plain_on_readable (struct server *s) +{ struct str *buf = &s->read_buffer; - enum irc_read_result (*fill_buffer)(struct server *, struct str *) - = s->ssl - ? irc_fill_read_buffer_ssl - : irc_fill_read_buffer; - bool disconnected = false; + ssize_t n_read; + while (true) { str_ensure_space (buf, 512); - switch (fill_buffer (s, buf)) + n_read = recv (s->socket, buf->str + buf->len, + buf->alloc - buf->len - 1 /* null byte */, 0); + + if (n_read > 0) { - case IRC_READ_AGAIN: - goto end; - case IRC_READ_ERROR: - log_server_error (s, s->buffer, - "Reading from the IRC server failed"); - disconnected = true; - goto end; - case IRC_READ_EOF: - log_server_error (s, s->buffer, - "The IRC server closed the connection"); - disconnected = true; - goto end; - case IRC_READ_OK: - break; + buf->str[buf->len += n_read] = '\0'; + continue; + } + if (n_read == 0) + return TRANSPORT_IO_EOF; + + if (errno == EAGAIN) + return TRANSPORT_IO_OK; + if (errno == EINTR) + continue; + + LOG_LIBC_FAILURE ("recv"); + return TRANSPORT_IO_ERROR; + } +} + +static enum transport_io_result +transport_plain_on_writeable (struct server *s) +{ + struct str *buf = &s->write_buffer; + ssize_t n_written; + + while (buf->len) + { + n_written = send (s->socket, buf->str, buf->len, 0); + if (n_written >= 0) + { + str_remove_slice (buf, 0, n_written); + continue; } - if (buf->len >= (1 << 20)) + if (errno == EAGAIN) + return TRANSPORT_IO_OK; + if (errno == EINTR) + continue; + + LOG_LIBC_FAILURE ("send"); + return TRANSPORT_IO_ERROR; + } + return TRANSPORT_IO_OK; +} + +static int +transport_plain_get_poll_events (struct server *s) +{ + int events = POLLIN; + if (s->write_buffer.len) + events |= POLLOUT; + return events; +} + +static struct transport g_transport_plain = +{ + .on_readable = transport_plain_on_readable, + .on_writeable = transport_plain_on_writeable, + .get_poll_events = transport_plain_get_poll_events, +}; + +// --- SSL/TLS transport ------------------------------------------------------- + +struct transport_tls_data +{ + SSL_CTX *ssl_ctx; ///< SSL context + SSL *ssl; ///< SSL/TLS connection + bool ssl_rx_want_tx; ///< SSL_read() wants to write + bool ssl_tx_want_rx; ///< SSL_write() wants to read +}; + +static bool +transport_tls_init_ctx (struct server *s, SSL_CTX *ssl_ctx, struct error **e) +{ + bool verify = get_config_boolean (s->config, "ssl_verify"); + if (!verify) + SSL_CTX_set_verify (ssl_ctx, SSL_VERIFY_NONE, NULL); + + const char *ca_file = get_config_string (s->config, "ssl_ca_file"); + const char *ca_path = get_config_string (s->config, "ssl_ca_path"); + + struct error *error = NULL; + if (ca_file || ca_path) + { + if (SSL_CTX_load_verify_locations (ssl_ctx, ca_file, ca_path)) + return true; + + error_set (&error, "%s: %s", + "Failed to set locations for the CA certificate bundle", + ERR_reason_error_string (ERR_get_error ())); + goto ca_error; + } + + if (!SSL_CTX_set_default_verify_paths (ssl_ctx)) + { + error_set (&error, "%s: %s", + "Couldn't load the default CA certificate bundle", + ERR_reason_error_string (ERR_get_error ())); + goto ca_error; + } + + // XXX: maybe we should call SSL_CTX_set_options() for some workarounds + SSL_CTX_set_mode (ssl_ctx, + SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + return true; + +ca_error: + if (verify) + { + error_propagate (e, error); + return false; + } + + // Only inform the user if we're not actually verifying + log_server_error (s, s->buffer, "#s", error->message); + error_free (error); + return true; +} + +static bool +transport_tls_init_cert (struct server *s, SSL *ssl, struct error **e) +{ + const char *ssl_cert = get_config_string (s->config, "ssl_cert"); + if (!ssl_cert) + return true; + + bool result = false; + char *path = resolve_config_filename (ssl_cert); + if (!path) + error_set (e, "%s: %s", "Cannot open file", ssl_cert); + // XXX: perhaps we should read the file ourselves for better messages + else if (!SSL_use_certificate_file (ssl, path, SSL_FILETYPE_PEM) + || !SSL_use_PrivateKey_file (ssl, path, SSL_FILETYPE_PEM)) + error_set (e, "%s: %s", "Setting the SSL client certificate failed", + ERR_error_string (ERR_get_error (), NULL)); + else + result = true; + free (path); + return result; +} + +static bool +transport_tls_init (struct server *s, struct error **e) +{ + const char *error_info = NULL; + SSL_CTX *ssl_ctx = SSL_CTX_new (SSLv23_client_method ()); + if (!ssl_ctx) + goto error_ssl_1; + if (!transport_tls_init_ctx (s, ssl_ctx, e)) + goto error_ssl_2; + + SSL *ssl = SSL_new (ssl_ctx); + if (!ssl) + goto error_ssl_2; + + struct error *error = NULL; + if (!transport_tls_init_cert (s, ssl, &error)) + { + // XXX: is this a reason to abort the connection? + log_server_error (s, s->buffer, "#s", error->message); + error_free (error); + } + + SSL_set_connect_state (ssl); + if (!SSL_set_fd (ssl, s->socket)) + goto error_ssl_3; + + // XXX: maybe set `ssl_rx_want_tx' to force a handshake? + struct transport_tls_data *data = xcalloc (1, sizeof *data); + data->ssl_ctx = ssl_ctx; + data->ssl = ssl; + + s->transport_data = data; + return true; + +error_ssl_3: + SSL_free (ssl); +error_ssl_2: + SSL_CTX_free (ssl_ctx); +error_ssl_1: + // XXX: these error strings are really nasty; also there could be + // multiple errors on the OpenSSL stack. + if (!error_info) + error_info = ERR_error_string (ERR_get_error (), NULL); + error_set (e, "%s: %s", "could not initialize SSL/TLS", error_info); + return false; +} + +static void +transport_tls_cleanup (struct server *s) +{ + struct transport_tls_data *data = s->transport_data; + if (data->ssl) + SSL_free (data->ssl); + if (data->ssl_ctx) + SSL_CTX_free (data->ssl_ctx); + free (data); +} + +static enum transport_io_result +transport_tls_on_readable (struct server *s) +{ + struct transport_tls_data *data = s->transport_data; + if (data->ssl_tx_want_rx) + return TRANSPORT_IO_OK; + + struct str *buf = &s->read_buffer; + data->ssl_rx_want_tx = false; + while (true) + { + str_ensure_space (buf, 512); + int n_read = SSL_read (data->ssl, buf->str + buf->len, + buf->alloc - buf->len - 1 /* null byte */); + + const char *error_info = NULL; + switch (xssl_get_error (data->ssl, n_read, &error_info)) { - log_server_error (s, s->buffer, - "The IRC server seems to spew out data frantically"); - irc_shutdown (s); - goto end; + case SSL_ERROR_NONE: + buf->str[buf->len += n_read] = '\0'; + continue; + case SSL_ERROR_ZERO_RETURN: + return TRANSPORT_IO_EOF; + case SSL_ERROR_WANT_READ: + return TRANSPORT_IO_OK; + case SSL_ERROR_WANT_WRITE: + data->ssl_rx_want_tx = true; + return TRANSPORT_IO_OK; + case XSSL_ERROR_TRY_AGAIN: + continue; + default: + LOG_FUNC_FAILURE ("SSL_read", error_info); + return TRANSPORT_IO_ERROR; } } -end: - (void) set_blocking (s->socket, true); - irc_process_buffer (buf, irc_process_message, s); - - if (disconnected) - on_irc_disconnected (s); - else - irc_reset_connection_timeouts (s); } +static enum transport_io_result +transport_tls_on_writeable (struct server *s) +{ + struct transport_tls_data *data = s->transport_data; + if (data->ssl_rx_want_tx) + return TRANSPORT_IO_OK; + + struct str *buf = &s->write_buffer; + data->ssl_tx_want_rx = false; + while (buf->len) + { + int n_written = SSL_write (data->ssl, buf->str, buf->len); + + const char *error_info = NULL; + switch (xssl_get_error (data->ssl, n_written, &error_info)) + { + case SSL_ERROR_NONE: + str_remove_slice (buf, 0, n_written); + continue; + case SSL_ERROR_ZERO_RETURN: + return TRANSPORT_IO_EOF; + case SSL_ERROR_WANT_WRITE: + return TRANSPORT_IO_OK; + case SSL_ERROR_WANT_READ: + data->ssl_tx_want_rx = true; + return TRANSPORT_IO_OK; + case XSSL_ERROR_TRY_AGAIN: + continue; + default: + LOG_FUNC_FAILURE ("SSL_write", error_info); + return TRANSPORT_IO_ERROR; + } + } + return TRANSPORT_IO_OK; +} + +static int +transport_tls_get_poll_events (struct server *s) +{ + struct transport_tls_data *data = s->transport_data; + + int events = POLLIN; + if (s->write_buffer.len || data->ssl_rx_want_tx) + events |= POLLOUT; + + // While we're waiting for an opposite event, we ignore the original + if (data->ssl_rx_want_tx) events &= ~POLLIN; + if (data->ssl_tx_want_rx) events &= ~POLLOUT; + return events; +} + +static void +transport_tls_in_before_shutdown (struct server *s) +{ + struct transport_tls_data *data = s->transport_data; + (void) SSL_shutdown (data->ssl); +} + +static struct transport g_transport_tls = +{ + .init = transport_tls_init, + .cleanup = transport_tls_cleanup, + .on_readable = transport_tls_on_readable, + .on_writeable = transport_tls_on_writeable, + .get_poll_events = transport_tls_get_poll_events, + .in_before_shutdown = transport_tls_in_before_shutdown, +}; + // --- Connection establishment ------------------------------------------------ static bool @@ -3667,11 +3868,14 @@ irc_finish_connection (struct server *s, int socket) { struct app_context *ctx = s->ctx; + set_blocking (socket, false); s->socket = socket; + s->transport = get_config_boolean (s->config, "ssl") + ? &g_transport_tls + : &g_transport_plain; struct error *e = NULL; - bool use_ssl = get_config_boolean (s->config, "ssl"); - if (use_ssl && !irc_initialize_ssl (s, &e)) + if (s->transport->init && !s->transport->init (s, &e)) { log_server_error (s, s->buffer, "Connection failed: #s", e->message); error_free (e); @@ -3679,21 +3883,21 @@ irc_finish_connection (struct server *s, int socket) xclose (s->socket); s->socket = -1; - irc_queue_reconnect (s); + s->transport = NULL; return; } log_server_status (s, s->buffer, "Connection established"); s->state = IRC_CONNECTED; - poller_fd_init (&s->read_event, &ctx->poller, s->socket); - s->read_event.dispatcher = (poller_fd_fn) on_irc_readable; - s->read_event.user_data = s; + poller_fd_init (&s->socket_event, &ctx->poller, s->socket); + s->socket_event.dispatcher = (poller_fd_fn) on_irc_ready; + s->socket_event.user_data = s; - poller_fd_set (&s->read_event, POLLIN); + irc_update_poller (s, NULL); irc_reset_connection_timeouts (s); - irc_register (s); + refresh_prompt (s->ctx); }