diff --git a/degesch.c b/degesch.c index 71f0e02..6b629d6 100644 --- a/degesch.c +++ b/degesch.c @@ -8089,8 +8089,36 @@ struct lua_connection struct lua_plugin *plugin; ///< The plugin we belong to struct poller_fd socket_event; ///< Socket is ready int socket_fd; ///< Underlying connected socket + + bool got_eof; ///< Half-closed by remote host + bool closing; ///< We're closing the connection + + struct str read_buffer; ///< Read buffer + struct str write_buffer; ///< Write buffer }; +static void +lua_connection_update_poller (struct lua_connection *self) +{ + poller_fd_set (&self->socket_event, + self->write_buffer.len ? (POLLIN | POLLOUT) : POLLIN); +} + +static int +lua_connection_send (lua_State *L) +{ + struct lua_connection *self = + luaL_checkudata (L, 1, XLUA_CONNECTION_METATABLE); + if (self->socket_fd == -1) + return luaL_error (L, "connection has been closed"); + + size_t len; + const char *s = luaL_checklstring (L, 2, &len); + str_append_data (&self->write_buffer, s, len); + lua_connection_update_poller (self); + return 0; +} + static void lua_connection_discard (struct lua_connection *self) { @@ -8099,6 +8127,9 @@ lua_connection_discard (struct lua_connection *self) poller_fd_reset (&self->socket_event); xclose (self->socket_fd); self->socket_fd = -1; + + str_free (&self->read_buffer); + str_free (&self->write_buffer); } // Connection is dead, we don't need to hold onto any resources anymore @@ -8107,6 +8138,22 @@ lua_connection_discard (struct lua_connection *self) static int lua_connection_close (lua_State *L) +{ + struct lua_connection *self = + luaL_checkudata (L, 1, XLUA_CONNECTION_METATABLE); + if (self->socket_fd != -1) + { + self->closing = true; + (void) shutdown (self->socket_fd, SHUT_RD); + + if (!self->write_buffer.len) + lua_connection_discard (self); + } + return 0; +} + +static int +lua_connection_gc (lua_State *L) { lua_connection_discard (luaL_checkudata (L, 1, XLUA_CONNECTION_METATABLE)); return 0; @@ -8114,19 +8161,147 @@ lua_connection_close (lua_State *L) static luaL_Reg lua_connection_table[] = { + { "send", lua_connection_send }, { "close", lua_connection_close }, - { "__gc", lua_connection_close }, + { "__gc", lua_connection_gc }, { NULL, NULL } }; // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +static int +lua_connection_check_fn (lua_State *L) +{ + lua_plugin_check_field (L, 1, luaL_checkstring (L, 2), LUA_TFUNCTION, true); + return 1; +} + +// We need to run it in a protected environment because of lua_getfield() +static bool +lua_connection_cb_lookup (struct lua_connection *self, const char *name, + struct error **e) +{ + lua_State *L = self->plugin->L; + lua_pushcfunction (L, lua_connection_check_fn); + hard_assert (lua_cache_get (L, self)); + lua_pushstring (L, name); + return lua_plugin_call (self->plugin, 2, 1, e); +} + +// Ideally lua_connection_cb_lookup() would return a ternary value +static bool +lua_connection_eat_nil (struct lua_connection *self) +{ + if (lua_toboolean (self->plugin->L, -1)) + return false; + lua_pop (self->plugin->L, 1); + return true; +} + +static bool +lua_connection_invoke_on_data (struct lua_connection *self, struct error **e) +{ + if (!lua_connection_cb_lookup (self, "on_data", e)) + return false; + if (lua_connection_eat_nil (self)) + return true; + + lua_pushlstring (self->plugin->L, + self->read_buffer.str, self->read_buffer.len); + return lua_plugin_call (self->plugin, 1, 0, e); +} + +static bool +lua_connection_invoke_on_eof (struct lua_connection *self, struct error **e) +{ + if (!lua_connection_cb_lookup (self, "on_eof", e)) + return false; + if (lua_connection_eat_nil (self)) + return true; + return lua_plugin_call (self->plugin, 0, 0, e); +} + +static bool +lua_connection_invoke_on_error (struct lua_connection *self, + struct error *error, struct error **e) +{ + if (!self->closing + && lua_connection_cb_lookup (self, "on_error", e) + && !lua_connection_eat_nil (self)) + { + lua_pushstring (self->plugin->L, error->message); + lua_plugin_call (self->plugin, 1, 0, e); + } + error_free (error); + return false; +} + +static bool +lua_connection_try_read (struct lua_connection *self, struct error **e) +{ + // Avoid the read call when it's obviously not going to return any data + // and would only cause unwanted invocation of callbacks + if (self->closing || self->got_eof) + return true; + + struct error *error = NULL; + enum socket_io_result read_result = + socket_io_try_read (self->socket_fd, &self->read_buffer, &error); + + // Dispatch any data that we got before an EOF or any error + if (self->read_buffer.len) + { + if (!lua_connection_invoke_on_data (self, e)) + { + if (error) + error_free (error); + return false; + } + str_reset (&self->read_buffer); + } + + if (read_result == SOCKET_IO_EOF) + { + if (!lua_connection_invoke_on_eof (self, e)) + return false; + self->got_eof = true; + } + if (read_result == SOCKET_IO_ERROR) + return lua_connection_invoke_on_error (self, error, e); + return true; +} + +static bool +lua_connection_try_write (struct lua_connection *self, struct error **e) +{ + struct error *error = NULL; + enum socket_io_result write_result = + socket_io_try_write (self->socket_fd, &self->write_buffer, &error); + + if (write_result == SOCKET_IO_ERROR) + return lua_connection_invoke_on_error (self, error, e); + return !self->closing || self->write_buffer.len; +} + static void lua_connection_on_ready (const struct pollfd *pfd, struct lua_connection *self) { - // TODO: handle the event, invoke on_data, on_close, on_error from - // our associated uservalue table as needed - lua_cache_invalidate (self->plugin->L, self); + (void) pfd; + + // Hold a reference so that it doesn't get collected on close() + hard_assert (lua_cache_get (self->plugin->L, self)); + + struct error *e = NULL; + bool keep = lua_connection_try_read (self, &e) + && lua_connection_try_write (self, &e); + if (e) + lua_plugin_log_error (self->plugin, "network I/O", e); + if (keep) + lua_connection_update_poller (self); + else + lua_connection_discard (self); + + lua_pop (self->plugin->L, 1); } static struct lua_connection * @@ -8145,6 +8320,9 @@ lua_plugin_push_connection (struct lua_plugin *plugin, int socket_fd) self->socket_event.user_data = self; poller_fd_set (&self->socket_event, POLLIN); + str_init (&self->read_buffer); + str_init (&self->write_buffer); + // Make sure the connection doesn't get garbage collected and return it lua_cache_store (L, self, -1); return self;