diff --git a/common.c b/common.c index ae92ecc..530a367 100644 --- a/common.c +++ b/common.c @@ -326,7 +326,7 @@ connector_add_target (struct connector *self, // --- SOCKS 5/4a -------------------------------------------------------------- -// Asynchronous SOCKS connector. Adds more stuff on top of the original. +// Asynchronous SOCKS connector. Adds more stuff on top of the regular one. // Note that the `username' is used differently in SOCKS 4a and 5. In the // former version, it is the username that you can get ident'ed against. @@ -396,7 +396,9 @@ struct socks_connector uint16_t bound_port; ///< Bound port at the server /// Process incoming data if there's enough of it available - bool (*on_data) (struct socks_connector *); + bool (*on_data) (struct socks_connector *, struct msg_unpacker *); + + size_t data_needed; ///< How much data the callback needs // Configuration: @@ -436,76 +438,16 @@ struct socks_connector return false; \ BLOCK_END -#define SOCKS_READ_START(n) \ - if (!socks_try_fill_read_buffer (self, (n))) \ - return false; \ - if (self->read_buffer.len < n) \ - return true; \ - struct msg_unpacker unpacker; \ - msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len) - -#define SOCKS_READ_END \ - str_remove_slice (&self->read_buffer, 0, unpacker.offset) - -static bool -socks_try_fill_read_buffer (struct socks_connector *self, size_t n) -{ - ssize_t remains = (ssize_t) n - (ssize_t) self->read_buffer.len; - if (remains <= 0) - return true; - - ssize_t received; - str_ensure_space (&self->read_buffer, remains); - do - received = recv (self->socket_fd, - self->read_buffer.str + self->read_buffer.len, remains, 0); - while ((received == -1) && errno == EINTR); - - if (received == 0) - SOCKS_FAIL ("%s: %s", "protocol error", "unexpected EOF"); - if (received == -1 && errno != EAGAIN) - SOCKS_FAIL ("%s: %s", "recv", strerror (errno)); - if (received > 0) - self->read_buffer.len += received; - return true; -} - -static bool -socks_try_flush_write_buffer (struct socks_connector *self) -{ - struct str *wb = &self->write_buffer; - ssize_t n_written; - - while (wb->len) - { - n_written = send (self->socket_fd, wb->str, wb->len, 0); - if (n_written >= 0) - { - str_remove_slice (wb, 0, n_written); - continue; - } - - if (errno == EAGAIN) - break; - if (errno == EINTR) - continue; - - SOCKS_FAIL ("%s: %s", "send", strerror (errno)); - return false; - } - return true; -} +#define SOCKS_DATA_CB(name) static bool name \ + (struct socks_connector *self, struct msg_unpacker *unpacker) // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -static bool -socks_4a_finish (struct socks_connector *self) +SOCKS_DATA_CB (socks_4a_finish) { - SOCKS_READ_START (8); uint8_t null, status; - hard_assert (msg_unpacker_u8 (&unpacker, &null)); - hard_assert (msg_unpacker_u8 (&unpacker, &status)); - SOCKS_READ_END; + hard_assert (msg_unpacker_u8 (unpacker, &null)); + hard_assert (msg_unpacker_u8 (unpacker, &status)); if (null != 0) SOCKS_FAIL ("protocol error"); @@ -571,79 +513,65 @@ socks_4a_start (struct socks_connector *self) } self->on_data = socks_4a_finish; + self->data_needed = 8; return true; } // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -static bool -socks_5_request_port (struct socks_connector *self) +SOCKS_DATA_CB (socks_5_request_port) { - SOCKS_READ_START (2); - hard_assert (msg_unpacker_u16 (&unpacker, &self->bound_port)); - SOCKS_READ_END; - + hard_assert (msg_unpacker_u16 (unpacker, &self->bound_port)); self->done = true; return false; } -static bool -socks_5_request_ipv4 (struct socks_connector *self) +SOCKS_DATA_CB (socks_5_request_ipv4) { - size_t len = sizeof self->bound_address.data.ipv4; - SOCKS_READ_START (len); - memcpy (self->bound_address.data.ipv4, self->read_buffer.str, len); - str_remove_slice (&self->read_buffer, 0, len); + memcpy (self->bound_address.data.ipv4, + self->read_buffer.str, self->data_needed); self->on_data = socks_5_request_port; + self->data_needed = 2; return true; } -static bool -socks_5_request_ipv6 (struct socks_connector *self) +SOCKS_DATA_CB (socks_5_request_ipv6) { - size_t len = sizeof self->bound_address.data.ipv6; - SOCKS_READ_START (len); - memcpy (self->bound_address.data.ipv6, self->read_buffer.str, len); - str_remove_slice (&self->read_buffer, 0, len); + memcpy (self->bound_address.data.ipv6, + self->read_buffer.str, self->data_needed); self->on_data = socks_5_request_port; + self->data_needed = 2; return true; } -static bool -socks_5_request_domain_data (struct socks_connector *self) +SOCKS_DATA_CB (socks_5_request_domain_data) { - size_t len = self->bound_address_len; - SOCKS_READ_START (len); - self->bound_address.data.domain = xstrndup (self->read_buffer.str, len); - str_remove_slice (&self->read_buffer, 0, len); + self->bound_address.data.domain = + xstrndup (self->read_buffer.str, self->data_needed); self->on_data = socks_5_request_port; + self->data_needed = 2; return true; } -static bool -socks_5_request_domain (struct socks_connector *self) +SOCKS_DATA_CB (socks_5_request_domain) { - SOCKS_READ_START (1); - hard_assert (msg_unpacker_u8 (&unpacker, &self->bound_address_len)); - SOCKS_READ_END; + hard_assert (msg_unpacker_u8 (unpacker, &self->bound_address_len)); self->on_data = socks_5_request_domain_data; + self->data_needed = self->bound_address_len; return true; } -static bool -socks_5_request_finish (struct socks_connector *self) +SOCKS_DATA_CB (socks_5_request_finish) { - SOCKS_READ_START (4); uint8_t version, status, reserved, type; - hard_assert (msg_unpacker_u8 (&unpacker, &version)); - hard_assert (msg_unpacker_u8 (&unpacker, &status)); - hard_assert (msg_unpacker_u8 (&unpacker, &reserved)); - hard_assert (msg_unpacker_u8 (&unpacker, &type)); - SOCKS_READ_END; + hard_assert (msg_unpacker_u8 (unpacker, &version)); + hard_assert (msg_unpacker_u8 (unpacker, &status)); + hard_assert (msg_unpacker_u8 (unpacker, &reserved)); + hard_assert (msg_unpacker_u8 (unpacker, &type)); if (version != 0x05) SOCKS_FAIL ("protocol error"); @@ -665,10 +593,20 @@ socks_5_request_finish (struct socks_connector *self) switch ((self->bound_address.type = type)) { - case SOCKS_IPV4: self->on_data = socks_5_request_ipv4; return true; - case SOCKS_IPV6: self->on_data = socks_5_request_ipv6; return true; - case SOCKS_DOMAIN: self->on_data = socks_5_request_domain; return true; - default: SOCKS_FAIL ("protocol error"); + case SOCKS_IPV4: + self->on_data = socks_5_request_ipv4; + self->data_needed = sizeof self->bound_address.data.ipv4; + return true; + case SOCKS_IPV6: + self->data_needed = sizeof self->bound_address.data.ipv6; + self->on_data = socks_5_request_ipv6; + return true; + case SOCKS_DOMAIN: + self->on_data = socks_5_request_domain; + self->data_needed = 1; + return true; + default: + SOCKS_FAIL ("protocol error"); } } @@ -707,17 +645,17 @@ socks_5_request_start (struct socks_connector *self) str_append_c (wb, target->port); self->on_data = socks_5_request_finish; + self->data_needed = 4; return true; } -static bool -socks_5_userpass_finish (struct socks_connector *self) +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +SOCKS_DATA_CB (socks_5_userpass_finish) { - SOCKS_READ_START (2); uint8_t version, status; - hard_assert (msg_unpacker_u8 (&unpacker, &version)); - hard_assert (msg_unpacker_u8 (&unpacker, &status)); - SOCKS_READ_END; + hard_assert (msg_unpacker_u8 (unpacker, &version)); + hard_assert (msg_unpacker_u8 (unpacker, &status)); if (version != 0x01) SOCKS_FAIL ("protocol error"); @@ -745,17 +683,15 @@ socks_5_userpass_start (struct socks_connector *self) str_append_data (wb, self->password, plen); self->on_data = socks_5_userpass_finish; + self->data_needed = 2; return true; } -static bool -socks_5_auth_finish (struct socks_connector *self) +SOCKS_DATA_CB (socks_5_auth_finish) { - SOCKS_READ_START (2); uint8_t version, method; - hard_assert (msg_unpacker_u8 (&unpacker, &version)); - hard_assert (msg_unpacker_u8 (&unpacker, &method)); - SOCKS_READ_END; + hard_assert (msg_unpacker_u8 (unpacker, &version)); + hard_assert (msg_unpacker_u8 (unpacker, &method)); if (version != 0x05) SOCKS_FAIL ("protocol error"); @@ -791,12 +727,72 @@ socks_5_auth_start (struct socks_connector *self) str_append_c (wb, 0x02); // username/password self->on_data = socks_5_auth_finish; + self->data_needed = 2; return true; } // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -static void socks_connector_step (struct socks_connector *self); +static void socks_connector_start (struct socks_connector *self); + +static void +socks_connector_fail (struct socks_connector *self) +{ + poller_fd_reset (&self->socket_event); + self->on_failure (self->user_data); +} + +static bool +socks_connector_step_iterators (struct socks_connector *self) +{ + // At the lowest level we iterate over all addresses for the SOCKS server; + // this is done automatically by the connector + + // Then we iterate over available protocols + if (++self->protocol_iter != SOCKS_MAX) + return true; + + // At the highest level we iterate over possible targets + self->protocol_iter = 0; + if (self->targets_iter && (self->targets_iter = self->targets_iter->next)) + return true; + + return false; +} + +static void +socks_connector_step (struct socks_connector *self) +{ + if (self->socket_fd != -1) + { + poller_fd_reset (&self->socket_event); + xclose (self->socket_fd); + self->socket_fd = -1; + } + + if (self->connector) + { + connector_free (self->connector); + free (self->connector); + self->connector = NULL; + } + + if (socks_connector_step_iterators (self)) + socks_connector_start (self); + else + socks_connector_fail (self); +} + +static void +socks_connector_on_timeout (struct socks_connector *self) +{ + if (self->on_error) + self->on_error (self->user_data, "timeout"); + + socks_connector_fail (self); +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - static void socks_connector_on_connected (void *user_data, int socket_fd) @@ -814,7 +810,7 @@ socks_connector_on_connected (void *user_data, int socket_fd) || (self->protocol_iter == SOCKS_4A && socks_4a_start (self))) return; - self->on_failure (self->user_data); + socks_connector_fail (self); } static void @@ -870,45 +866,71 @@ socks_connector_start (struct socks_connector *self) self->done = false; } -static void -socks_connector_step (struct socks_connector *self) +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +static bool +socks_try_fill_read_buffer (struct socks_connector *self, size_t n) { - // Close the socket if needed - if (self->socket_fd != -1) + ssize_t remains = (ssize_t) n - (ssize_t) self->read_buffer.len; + if (remains <= 0) + return true; + + ssize_t received; + str_ensure_space (&self->read_buffer, remains); + do + received = recv (self->socket_fd, + self->read_buffer.str + self->read_buffer.len, remains, 0); + while ((received == -1) && errno == EINTR); + + if (received == 0) + SOCKS_FAIL ("%s: %s", "protocol error", "unexpected EOF"); + if (received == -1 && errno != EAGAIN) + SOCKS_FAIL ("%s: %s", "recv", strerror (errno)); + if (received > 0) + self->read_buffer.len += received; + return true; +} + +static bool +socks_call_on_data (struct socks_connector *self) +{ + size_t to_consume = self->data_needed; + if (!socks_try_fill_read_buffer (self, to_consume)) + return false; + if (self->read_buffer.len < to_consume) + return true; + + struct msg_unpacker unpacker; + msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); + bool result = self->on_data (self, &unpacker); + str_remove_slice (&self->read_buffer, 0, to_consume); + return result; +} + +static bool +socks_try_flush_write_buffer (struct socks_connector *self) +{ + struct str *wb = &self->write_buffer; + ssize_t n_written; + + while (wb->len) { - poller_fd_reset (&self->socket_event); - xclose (self->socket_fd); - self->socket_fd = -1; + n_written = send (self->socket_fd, wb->str, wb->len, 0); + if (n_written >= 0) + { + str_remove_slice (wb, 0, n_written); + continue; + } + + if (errno == EAGAIN) + break; + if (errno == EINTR) + continue; + + SOCKS_FAIL ("%s: %s", "send", strerror (errno)); + return false; } - - // Destroy current connector if needed - if (self->connector) - { - connector_free (self->connector); - free (self->connector); - self->connector = NULL; - } - - // At the lowest level we iterate over all addresses for the SOCKS server; - // this is done automatically by the connector - - // Then we iterate over available protocols - if (++self->protocol_iter != SOCKS_MAX) - { - socks_connector_start (self); - return; - } - - // At the highest level we iterate over possible targets - self->protocol_iter = 0; - if (self->targets_iter && (self->targets_iter = self->targets_iter->next)) - { - socks_connector_start (self); - return; - } - - // FIXME: we need to cancel all events - self->on_failure (self->user_data); + return true; } static void @@ -917,7 +939,7 @@ socks_connector_on_ready { (void) pfd; - if (self->on_data (self) && socks_try_flush_write_buffer (self)) + if (socks_call_on_data (self) && socks_try_flush_write_buffer (self)) { poller_fd_set (&self->socket_event, self->write_buffer.len ? (POLLIN | POLLOUT) : POLLIN); @@ -935,16 +957,6 @@ socks_connector_on_ready socks_connector_step (self); } -static void -socks_connector_on_timeout (struct socks_connector *self) -{ - if (self->on_error) - self->on_error (self->user_data, "timeout"); - - // FIXME: we need to cancel all events - self->on_failure (self->user_data); -} - // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - static void