diff --git a/common.c b/common.c index 21842ad..cae9248 100644 --- a/common.c +++ b/common.c @@ -229,7 +229,7 @@ connector_step (struct connector *self) self->on_connected (self->user_data, fd); return; } - else if (errno != EINPROGRESS) + if (errno != EINPROGRESS) { connector_notify_error (self, strerror (errno)); xclose (fd); @@ -324,15 +324,14 @@ connector_add_target (struct connector *self, return true; } -// --- SOCKS 5/4a (blocking implementation) ------------------------------------ +// --- SOCKS 5/4a -------------------------------------------------------------- -// These are awkward protocols. Note that the `username' is used differently -// in SOCKS 4a and 5. In the former version, it is the username that you can -// get ident'ed against. In the latter version, it forms a pair with the -// password field and doesn't need to be an actual user on your machine. +// Asynchronous SOCKS connector. Adds more stuff on top of the original. -// TODO: make a non-blocking poller-based version of this; -// either use c-ares or (even better) start another thread to do resolution +// Note that the `username' is used differently in SOCKS 4a and 5. In the +// former version, it is the username that you can get ident'ed against. +// In the latter version, it forms a pair with the password field and doesn't +// need to be an actual user on your machine. struct socks_addr { @@ -346,125 +345,183 @@ struct socks_addr union { uint8_t ipv4[4]; ///< IPv4 address, network octet order - const char *domain; ///< Domain name + char *domain; ///< Domain name uint8_t ipv6[16]; ///< IPv6 address, network octet order } data; ///< The address itself }; -struct socks_data +static void +socks_addr_free (struct socks_addr *self) { - struct socks_addr address; ///< Target address - uint16_t port; ///< Target port - const char *username; ///< Authentication username - const char *password; ///< Authentication password + if (self->type == SOCKS_DOMAIN) + free (self->data.domain); +} - struct socks_addr bound_address; ///< Bound address at the server - uint16_t bound_port; ///< Bound port at the server +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +struct socks_target +{ + LIST_HEADER (struct socks_target) + + struct socks_addr address; ///< Target address + uint16_t port; ///< Target service port }; -static bool -socks_get_socket (struct addrinfo *addresses, int *fd, struct error **e) +enum socks_protocol { - int sockfd; - for (; addresses; addresses = addresses->ai_next) - { - sockfd = socket (addresses->ai_family, - addresses->ai_socktype, addresses->ai_protocol); - if (sockfd == -1) - continue; - set_cloexec (sockfd); + SOCKS_5, ///< SOCKS5 + SOCKS_4A, ///< SOCKS4A + SOCKS_MAX ///< End of protocol +}; - int yes = 1; - soft_assert (setsockopt (sockfd, SOL_SOCKET, SO_KEEPALIVE, - &yes, sizeof yes) != -1); +struct socks_connector +{ + struct connector *connector; ///< Proxy server iterator (effectively) + enum socks_protocol protocol_iter; ///< Protocol iterator + struct socks_target *targets_iter; ///< Targets iterator - if (!connect (sockfd, addresses->ai_addr, addresses->ai_addrlen)) - break; - xclose (sockfd); - } - if (!addresses) - { - error_set (e, "couldn't connect to the SOCKS server"); - return false; - } - *fd = sockfd; - return true; -} + // Negotiation: + + int socket_fd; ///< Current socket file descriptor + struct poller_fd socket_event; ///< Socket can be read from/written to + struct str read_buffer; ///< Read buffer + struct str write_buffer; ///< Write buffer + + struct poller_timer timeout; ///< Timeout timer + + uint8_t bound_address_len; ///< Length of domain name + struct socks_addr bound_address; ///< Bound address at the server + uint16_t bound_port; ///< Bound port at the server + + /// Process incoming data if there's enough of it available + bool (*on_data) (struct socks_connector *); + + // Configuration: + + const char *hostname; ///< SOCKS server hostname + const char *service; ///< SOCKS server service name or port + + const char *username; ///< Username for authentication + const char *password; ///< Password for authentication + + struct socks_target *targets; ///< Targets + struct socks_target *targets_tail; ///< Tail of targets + + void *user_data; ///< User data for callbacks + + // You may destroy the connector object in these two main callbacks: + + /// Connection has been successfully established + void (*on_connected) (void *user_data, int socket); + /// Failed to establish a connection to either target + void (*on_failure) (void *user_data); + + // Optional: + + /// Connecting to a new address + void (*on_connecting) (void *user_data, + const char *address, const char *via, const char *version); + /// Connecting to the last address has failed + void (*on_error) (void *user_data, const char *error); +}; #define SOCKS_FAIL(...) \ BLOCK_START \ - error_set (e, __VA_ARGS__); \ - goto fail; \ + char *error = xstrdup_printf (__VA_ARGS__); \ + if (self->on_error) \ + self->on_error (self->user_data, error); \ + free (error); \ + return false; \ BLOCK_END -#define SOCKS_RECV(buf, len) \ + +// FIXME: we need to cancel all events +#define SOCKS_DONE() \ BLOCK_START \ - if ((n = recv (sockfd, (buf), (len), 0)) == -1) \ - SOCKS_FAIL ("%s: %s", "recv", strerror (errno)); \ - if (n != (len)) \ - SOCKS_FAIL ("%s: %s", "protocol error", "unexpected EOF"); \ + int fd = self->socket_fd; \ + set_blocking (fd, true); \ + self->socket_fd = -1; \ + self->on_connected (self->user_data, fd); \ + return true; \ BLOCK_END +#define SOCKS_NEED_DATA(n) \ + if (!socks_try_fill_read_buffer (self, (n))) \ + return false; \ + if (self->read_buffer.len < n) \ + return true + static bool -socks_4a_connect (struct addrinfo *addresses, struct socks_data *data, - int *fd, struct error **e) +socks_try_fill_read_buffer (struct socks_connector *self, size_t n) { - int sockfd; - if (!socks_get_socket (addresses, &sockfd, e)) - return false; + ssize_t remains = (ssize_t) n - (ssize_t) self->read_buffer.len; + if (remains <= 0) + return true; - const void *dest_ipv4 = "\x00\x00\x00\x01"; - const char *dest_domain = NULL; + ssize_t received; + str_ensure_space (&self->read_buffer, remains); + do + received = recv (self->socket_fd, + self->read_buffer.str + self->read_buffer.len, remains, 0); + while ((received == -1) && errno == EINTR); - char buf[INET6_ADDRSTRLEN]; - switch (data->address.type) + if (received == 0) + SOCKS_FAIL ("%s: %s", "protocol error", "unexpected EOF"); + if (received == -1 && errno != EAGAIN) + SOCKS_FAIL ("%s: %s", "recv", strerror (errno)); + if (received > 0) + self->read_buffer.len += received; + return true; +} + +static bool +socks_try_flush_write_buffer (struct socks_connector *self) +{ + struct str *wb = &self->write_buffer; + ssize_t n_written; + + while (wb->len) { - case SOCKS_IPV4: - dest_ipv4 = data->address.data.ipv4; - break; - case SOCKS_IPV6: - // About the best thing we can do, not sure if it works anywhere at all - if (!inet_ntop (AF_INET6, &data->address.data.ipv6, buf, sizeof buf)) - SOCKS_FAIL ("%s: %s", "inet_ntop", strerror (errno)); - dest_domain = buf; - break; - case SOCKS_DOMAIN: - dest_domain = data->address.data.domain; - } + n_written = send (self->socket_fd, wb->str, wb->len, 0); + if (n_written >= 0) + { + str_remove_slice (wb, 0, n_written); + continue; + } - struct str req; - str_init (&req); - str_append_c (&req, 4); // version - str_append_c (&req, 1); // connect + if (errno == EAGAIN) + break; + if (errno == EINTR) + continue; - str_append_c (&req, data->port >> 8); // higher bits of port - str_append_c (&req, data->port); // lower bits of port - str_append_data (&req, dest_ipv4, 4); // destination address - - if (data->username) - str_append (&req, data->username); - str_append_c (&req, '\0'); - - if (dest_domain) - { - str_append (&req, dest_domain); - str_append_c (&req, '\0'); - } - - ssize_t n = send (sockfd, req.str, req.len, 0); - str_free (&req); - if (n == -1) SOCKS_FAIL ("%s: %s", "send", strerror (errno)); + return false; + } + return true; +} - uint8_t resp[8]; - SOCKS_RECV (resp, sizeof resp); - if (resp[0] != 0) +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +static bool +socks_4a_finish (struct socks_connector *self) +{ + SOCKS_NEED_DATA (8); + + struct msg_unpacker unpacker; + msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); + + uint8_t null, status; + hard_assert (msg_unpacker_u8 (&unpacker, &null)); + hard_assert (msg_unpacker_u8 (&unpacker, &status)); + str_remove_slice (&self->read_buffer, 0, unpacker.offset); + + if (null != 0) SOCKS_FAIL ("protocol error"); - switch (resp[1]) + switch (status) { case 90: - break; + SOCKS_DONE (); case 91: SOCKS_FAIL ("request rejected or failed"); case 92: @@ -476,141 +533,138 @@ socks_4a_connect (struct addrinfo *addresses, struct socks_data *data, default: SOCKS_FAIL ("protocol error"); } - - *fd = sockfd; - return true; - -fail: - xclose (sockfd); - return false; -} - -#undef SOCKS_FAIL -#define SOCKS_FAIL(...) \ - BLOCK_START \ - error_set (e, __VA_ARGS__); \ - return false; \ - BLOCK_END - -static bool -socks_5_userpass_auth (int sockfd, struct socks_data *data, struct error **e) -{ - size_t ulen = strlen (data->username); - if (ulen > 255) - ulen = 255; - - size_t plen = strlen (data->password); - if (plen > 255) - plen = 255; - - uint8_t req[3 + ulen + plen], *p = req; - *p++ = 0x01; // version - *p++ = ulen; // username length - memcpy (p, data->username, ulen); - p += ulen; - *p++ = plen; // password length - memcpy (p, data->password, plen); - p += plen; - - ssize_t n = send (sockfd, req, p - req, 0); - if (n == -1) - SOCKS_FAIL ("%s: %s", "send", strerror (errno)); - - uint8_t resp[2]; - SOCKS_RECV (resp, sizeof resp); - if (resp[0] != 0x01) - SOCKS_FAIL ("protocol error"); - if (resp[1] != 0x00) - SOCKS_FAIL ("authentication failure"); - return true; } static bool -socks_5_auth (int sockfd, struct socks_data *data, struct error **e) +socks_4a_start (struct socks_connector *self) { - bool can_auth = data->username && data->password; + struct socks_target *target = self->targets_iter; + const void *dest_ipv4 = "\x00\x00\x00\x01"; + const char *dest_domain = NULL; - uint8_t hello[4]; - hello[0] = 0x05; // version - hello[1] = 1 + can_auth; // number of authentication methods - hello[2] = 0x00; // no authentication required - hello[3] = 0x02; // username/password - - ssize_t n = send (sockfd, hello, 3 + can_auth, 0); - if (n == -1) - SOCKS_FAIL ("%s: %s", "send", strerror (errno)); - - uint8_t resp[2]; - SOCKS_RECV (resp, sizeof resp); - if (resp[0] != 0x05) - SOCKS_FAIL ("protocol error"); - - switch (resp[1]) - { - case 0x02: - if (!can_auth) - SOCKS_FAIL ("protocol error"); - if (!socks_5_userpass_auth (sockfd, data, e)) - return false; - case 0x00: - break; - case 0xFF: - SOCKS_FAIL ("no acceptable authentication methods"); - default: - SOCKS_FAIL ("protocol error"); - } - return true; -} - -static bool -socks_5_send_req (int sockfd, struct socks_data *data, struct error **e) -{ - uint8_t req[4 + 256 + 2], *p = req; - *p++ = 0x05; // version - *p++ = 0x01; // connect - *p++ = 0x00; // reserved - *p++ = data->address.type; - - switch (data->address.type) + char buf[INET6_ADDRSTRLEN]; + switch (target->address.type) { case SOCKS_IPV4: - memcpy (p, data->address.data.ipv4, sizeof data->address.data.ipv4); - p += sizeof data->address.data.ipv4; + dest_ipv4 = target->address.data.ipv4; + break; + case SOCKS_IPV6: + // About the best thing we can do, not sure if it works anywhere at all + if (!inet_ntop (AF_INET6, &target->address.data.ipv6, buf, sizeof buf)) + SOCKS_FAIL ("%s: %s", "inet_ntop", strerror (errno)); + dest_domain = buf; break; case SOCKS_DOMAIN: + dest_domain = target->address.data.domain; + } + + struct str *wb = &self->write_buffer; + str_init (wb); + str_append_c (wb, 4); // version + str_append_c (wb, 1); // connect + + str_append_c (wb, target->port >> 8); // higher bits of port + str_append_c (wb, target->port); // lower bits of port + str_append_data (wb, dest_ipv4, 4); // destination address + + if (self->username) + str_append (wb, self->username); + str_append_c (wb, '\0'); + + if (dest_domain) { - size_t dlen = strlen (data->address.data.domain); - if (dlen > 255) - dlen = 255; - - *p++ = dlen; - memcpy (p, data->address.data.domain, dlen); - p += dlen; - break; + str_append (wb, dest_domain); + str_append_c (wb, '\0'); } - case SOCKS_IPV6: - memcpy (p, data->address.data.ipv6, sizeof data->address.data.ipv6); - p += sizeof data->address.data.ipv6; - break; - } - *p++ = data->port >> 8; - *p++ = data->port; - if (send (sockfd, req, p - req, 0) == -1) - SOCKS_FAIL ("%s: %s", "send", strerror (errno)); + self->on_data = socks_4a_finish; + return true; +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +static bool +socks_5_request_port (struct socks_connector *self) +{ + SOCKS_NEED_DATA (2); + + struct msg_unpacker unpacker; + msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); + hard_assert (msg_unpacker_u16 (&unpacker, &self->bound_port)); + str_remove_slice (&self->read_buffer, 0, unpacker.offset); + + SOCKS_DONE (); +} + +static bool +socks_5_request_ipv4 (struct socks_connector *self) +{ + size_t len = sizeof self->bound_address.data.ipv4; + SOCKS_NEED_DATA (len); + memcpy (self->bound_address.data.ipv4, self->read_buffer.str, len); + str_remove_slice (&self->read_buffer, 0, len); + + self->on_data = socks_5_request_port; return true; } static bool -socks_5_process_resp (int sockfd, struct socks_data *data, struct error **e) +socks_5_request_ipv6 (struct socks_connector *self) { - uint8_t resp_header[4]; - ssize_t n; - SOCKS_RECV (resp_header, sizeof resp_header); - if (resp_header[0] != 0x05) + size_t len = sizeof self->bound_address.data.ipv6; + SOCKS_NEED_DATA (len); + memcpy (self->bound_address.data.ipv6, self->read_buffer.str, len); + str_remove_slice (&self->read_buffer, 0, len); + + self->on_data = socks_5_request_port; + return true; +} + +static bool +socks_5_request_domain_data (struct socks_connector *self) +{ + size_t len = self->bound_address_len; + SOCKS_NEED_DATA (len); + self->bound_address.data.domain = xstrndup (self->read_buffer.str, len); + str_remove_slice (&self->read_buffer, 0, len); + + self->on_data = socks_5_request_port; + return true; +} + +static bool +socks_5_request_domain (struct socks_connector *self) +{ + SOCKS_NEED_DATA (1); + + struct msg_unpacker unpacker; + msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); + hard_assert (msg_unpacker_u8 (&unpacker, &self->bound_address_len)); + str_remove_slice (&self->read_buffer, 0, unpacker.offset); + + self->on_data = socks_5_request_domain_data; + return true; +} + +static bool +socks_5_request_finish (struct socks_connector *self) +{ + SOCKS_NEED_DATA (4); + + struct msg_unpacker unpacker; + msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); + + uint8_t version, status, reserved, type; + hard_assert (msg_unpacker_u8 (&unpacker, &version)); + hard_assert (msg_unpacker_u8 (&unpacker, &status)); + hard_assert (msg_unpacker_u8 (&unpacker, &reserved)); + hard_assert (msg_unpacker_u8 (&unpacker, &type)); + str_remove_slice (&self->read_buffer, 0, unpacker.offset); + + if (version != 0x05) SOCKS_FAIL ("protocol error"); - switch (resp_header[1]) + switch (status) { case 0x00: break; @@ -625,109 +679,368 @@ socks_5_process_resp (int sockfd, struct socks_data *data, struct error **e) default: SOCKS_FAIL ("protocol error"); } - switch ((data->bound_address.type = resp_header[3])) + switch ((self->bound_address.type = type)) + { + case SOCKS_IPV4: self->on_data = socks_5_request_ipv4; return true; + case SOCKS_IPV6: self->on_data = socks_5_request_ipv6; return true; + case SOCKS_DOMAIN: self->on_data = socks_5_request_domain; return true; + default: SOCKS_FAIL ("protocol error"); + } +} + +static bool +socks_5_request_start (struct socks_connector *self) +{ + struct socks_target *target = self->targets_iter; + struct str *wb = &self->write_buffer; + str_append_c (wb, 0x05); // version + str_append_c (wb, 0x01); // connect + str_append_c (wb, 0x00); // reserved + str_append_c (wb, target->address.type); + + switch (target->address.type) { case SOCKS_IPV4: - SOCKS_RECV (data->bound_address.data.ipv4, - sizeof data->bound_address.data.ipv4); - break; - case SOCKS_IPV6: - SOCKS_RECV (data->bound_address.data.ipv6, - sizeof data->bound_address.data.ipv6); + str_append_data (wb, + target->address.data.ipv4, sizeof target->address.data.ipv4); break; case SOCKS_DOMAIN: { - uint8_t len; - SOCKS_RECV (&len, sizeof len); + size_t dlen = strlen (target->address.data.domain); + if (dlen > 255) + dlen = 255; - char domain[len + 1]; - SOCKS_RECV (domain, len); - domain[len] = '\0'; - - data->bound_address.data.domain = xstrdup (domain); + str_append_c (wb, dlen); + str_append_data (wb, target->address.data.domain, dlen); break; } + case SOCKS_IPV6: + str_append_data (wb, + target->address.data.ipv6, sizeof target->address.data.ipv6); + break; + } + str_append_c (wb, target->port >> 8); + str_append_c (wb, target->port); + + self->on_data = socks_5_request_finish; + return true; +} + +static bool +socks_5_userpass_finish (struct socks_connector *self) +{ + SOCKS_NEED_DATA (2); + + struct msg_unpacker unpacker; + msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); + + uint8_t version, status; + hard_assert (msg_unpacker_u8 (&unpacker, &version)); + hard_assert (msg_unpacker_u8 (&unpacker, &status)); + str_remove_slice (&self->read_buffer, 0, unpacker.offset); + + if (version != 0x01) + SOCKS_FAIL ("protocol error"); + if (status != 0x00) + SOCKS_FAIL ("authentication failure"); + return socks_5_request_start (self); +} + +static bool +socks_5_userpass_start (struct socks_connector *self) +{ + size_t ulen = strlen (self->username); + if (ulen > 255) + ulen = 255; + + size_t plen = strlen (self->password); + if (plen > 255) + plen = 255; + + struct str *wb = &self->write_buffer; + str_append_c (wb, 0x01); // version + str_append_c (wb, ulen); // username length + str_append_data (wb, self->username, ulen); + str_append_c (wb, plen); // password length + str_append_data (wb, self->password, plen); + + self->on_data = socks_5_userpass_finish; + return true; +} + +static bool +socks_5_auth_finish (struct socks_connector *self) +{ + SOCKS_NEED_DATA (2); + + struct msg_unpacker unpacker; + msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); + + uint8_t version, method; + hard_assert (msg_unpacker_u8 (&unpacker, &version)); + hard_assert (msg_unpacker_u8 (&unpacker, &method)); + str_remove_slice (&self->read_buffer, 0, unpacker.offset); + + if (version != 0x05) + SOCKS_FAIL ("protocol error"); + + bool can_auth = self->username && self->password; + + switch (method) + { + case 0x02: + if (!can_auth) + SOCKS_FAIL ("protocol error"); + + return socks_5_userpass_start (self); + case 0x00: + return socks_5_request_start (self); + case 0xFF: + SOCKS_FAIL ("no acceptable authentication methods"); default: SOCKS_FAIL ("protocol error"); } - - uint16_t port; - SOCKS_RECV (&port, sizeof port); - data->bound_port = ntohs (port); - return true; } -#undef SOCKS_FAIL -#undef SOCKS_RECV - static bool -socks_5_connect (struct addrinfo *addresses, struct socks_data *data, - int *fd, struct error **e) +socks_5_auth_start (struct socks_connector *self) { - int sockfd; - if (!socks_get_socket (addresses, &sockfd, e)) - return false; + bool can_auth = self->username && self->password; - if (!socks_5_auth (sockfd, data, e) - || !socks_5_send_req (sockfd, data, e) - || !socks_5_process_resp (sockfd, data, e)) - { - xclose (sockfd); - return false; - } + struct str *wb = &self->write_buffer; + str_append_c (wb, 0x05); // version + str_append_c (wb, 1 + can_auth); // number of authentication methods + str_append_c (wb, 0x00); // no authentication required + str_append_c (wb, 0x02); // username/password - *fd = sockfd; + self->on_data = socks_5_auth_finish; return true; } -static int -socks_connect (const char *socks_host, const char *socks_port, - const char *host, const char *port, - const char *username, const char *password, struct error **e) -{ - int result = -1; - struct addrinfo gai_hints, *gai_result; - memset (&gai_hints, 0, sizeof gai_hints); - gai_hints.ai_socktype = SOCK_STREAM; +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - unsigned long port_no; +static void socks_connector_step (struct socks_connector *self); + +static void +socks_connector_on_connected (void *user_data, int socket_fd) +{ + set_blocking (socket_fd, false); + + struct socks_connector *self = user_data; + self->socket_fd = socket_fd; + self->socket_event.fd = socket_fd; + poller_fd_set (&self->socket_event, POLLIN | POLLOUT); + str_reset (&self->read_buffer); + str_reset (&self->write_buffer); + + if ((self->protocol_iter == SOCKS_5 && socks_5_auth_start (self)) + || (self->protocol_iter == SOCKS_4A && socks_4a_start (self))) + return; + + self->on_failure (self->user_data); +} + +static void +socks_connector_on_failure (void *user_data) +{ + struct socks_connector *self = user_data; + // TODO: skip SOCKS server on connection failure + socks_connector_step (self); +} + +static void +socks_connector_on_connecting (void *user_data, const char *via) +{ + struct socks_connector *self = user_data; + if (!self->on_connecting) + return; + + // TODO: reconstruct the address from the current target iterator, + // or just store it in unprocessed form + char *address = format_host_port_pair ("", ""); + self->on_connecting (self->user_data, address, via, + self->protocol_iter ? "SOCKS4A" : "SOCKS5"); + free (address); +} + +static void +socks_connector_on_error (void *user_data, const char *error) +{ + struct socks_connector *self = user_data; + // TODO: skip protocol on protocol failure + self->on_error (self->user_data, error); +} + +static void +socks_connector_start (struct socks_connector *self) +{ + struct connector *connector = + self->connector = xcalloc (1, sizeof *connector); + connector_init (connector, self->socket_event.poller); + + connector->user_data = self; + connector->on_connected = socks_connector_on_connected; + connector->on_connecting = socks_connector_on_connecting; + connector->on_error = socks_connector_on_error; + connector->on_failure = socks_connector_on_failure; + + // TODO: let's rather call on_error and on_failure instead on error + hard_assert (connector_add_target (connector, + self->hostname, self->service, NULL)); + + connector_step (connector); + poller_timer_set (&self->timeout, 60 * 1000); +} + +static void +socks_connector_step (struct socks_connector *self) +{ + // Destroy current connector if needed + if (self->connector) + { + connector_free (self->connector); + free (self->connector); + self->connector = NULL; + } + + // At the lowest level we iterate over all addresses for the SOCKS server; + // this is done automatically by the connector + + // Then we iterate over available protocols + if (++self->protocol_iter != SOCKS_MAX) + { + socks_connector_start (self); + return; + } + + // At the highest level we iterate over possible targets + self->protocol_iter = 0; + if (self->targets_iter && (self->targets_iter = self->targets_iter->next)) + { + socks_connector_start (self); + return; + } + + // FIXME: we need to cancel all events + self->on_failure (self->user_data); +} + +static void +socks_connector_on_ready + (const struct pollfd *pfd, struct socks_connector *self) +{ + (void) pfd; + + if (!self->on_data (self) || !socks_try_flush_write_buffer (self)) + { + // We've failed this target, let's try to move on + // FIXME: we need to cancel all events + socks_connector_step (self); + } + // If we successfully establish the connection, then the FD is reset to -1 + else if (self->socket_fd != -1) + { + poller_fd_set (&self->socket_event, + self->write_buffer.len ? (POLLIN | POLLOUT) : POLLIN); + } +} + +static void +socks_connector_on_timeout (struct socks_connector *self) +{ + if (self->on_error) + self->on_error (self->user_data, "timeout"); + + // FIXME: we need to cancel all events + self->on_failure (self->user_data); +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +static void +socks_connector_init (struct socks_connector *self, struct poller *poller) +{ + memset (self, 0, sizeof *self); + + poller_fd_init (&self->socket_event, poller, (self->socket_fd = -1)); + self->socket_event.dispatcher = (poller_fd_fn) socks_connector_on_ready; + self->socket_event.user_data = self; + + poller_timer_init (&self->timeout, poller); + self->timeout.dispatcher = (poller_timer_fn) socks_connector_on_timeout; + self->timeout.user_data = self; + + str_init (&self->read_buffer); + str_init (&self->write_buffer); +} + +static void +socks_connector_free (struct socks_connector *self) +{ + if (self->connector) + { + connector_free (self->connector); + free (self->connector); + } + + poller_fd_reset (&self->socket_event); + poller_timer_reset (&self->timeout); + + if (self->socket_fd != -1) + xclose (self->socket_fd); + + str_free (&self->read_buffer); + str_free (&self->write_buffer); + + LIST_FOR_EACH (struct socks_target, iter, self->targets) + { + socks_addr_free (&iter->address); + free (iter); + } + + socks_addr_free (&self->bound_address); +} + +static bool +socks_connector_add_target (struct socks_connector *self, + const char *host, const char *service, struct error **e) +{ + unsigned long port; const struct servent *serv; - if ((serv = getservbyname (port, "tcp"))) - port_no = (uint16_t) ntohs (serv->s_port); - else if (!xstrtoul (&port_no, port, 10) || !port_no || port_no > UINT16_MAX) + if ((serv = getservbyname (service, "tcp"))) + port = (uint16_t) ntohs (serv->s_port); + else if (!xstrtoul (&port, service, 10) || !port || port > UINT16_MAX) { error_set (e, "invalid port number"); - goto fail; + return false; } - int err = getaddrinfo (socks_host, socks_port, &gai_hints, &gai_result); - if (err) - { - error_set (e, "%s: %s", "getaddrinfo", gai_strerror (err)); - goto fail; - } - - struct socks_data data = - { .username = username, .password = password, .port = port_no }; - - if (inet_pton (AF_INET, host, &data.address.data.ipv4) == 1) - data.address.type = SOCKS_IPV4; - else if (inet_pton (AF_INET6, host, &data.address.data.ipv6) == 1) - data.address.type = SOCKS_IPV6; + struct socks_target *target = xcalloc (1, sizeof *target); + if (inet_pton (AF_INET, host, &target->address.data.ipv4) == 1) + target->address.type = SOCKS_IPV4; + else if (inet_pton (AF_INET6, host, &target->address.data.ipv6) == 1) + target->address.type = SOCKS_IPV6; else { - data.address.type = SOCKS_DOMAIN; - data.address.data.domain = host; + target->address.type = SOCKS_DOMAIN; + target->address.data.domain = xstrdup (host); } - if (!socks_5_connect (gai_result, &data, &result, NULL)) - socks_4a_connect (gai_result, &data, &result, e); + target->port = port; + LIST_APPEND_WITH_TAIL (self->targets, self->targets_tail, target); + return true; +} - if (data.bound_address.type == SOCKS_DOMAIN) - free ((char *) data.bound_address.data.domain); - freeaddrinfo (gai_result); -fail: - return result; +static void +socks_connector_run (struct socks_connector *self) +{ + // XXX: do we need some better error checking in here? + hard_assert (self->hostname); + hard_assert (self->targets); + + self->targets_iter = self->targets; + self->protocol_iter = 0; + socks_connector_start (self); } // --- CTCP decoding ----------------------------------------------------------- diff --git a/degesch.c b/degesch.c index d1d2a63..d1fc783 100644 --- a/degesch.c +++ b/degesch.c @@ -1116,6 +1116,7 @@ struct server enum server_state state; ///< Connection state struct connector *connector; ///< Connection establisher + struct socks_connector *socks_conn; ///< SOCKS connection establisher unsigned reconnect_attempt; ///< Number of reconnect attempt bool manual_disconnect; ///< Don't reconnect after disconnect @@ -1265,6 +1266,11 @@ server_free (struct server *self) connector_free (self->connector); free (self->connector); } + if (self->socks_conn) + { + socks_connector_free (self->socks_conn); + free (self->socks_conn); + } if (self->transport && self->transport->cleanup) @@ -3591,10 +3597,16 @@ irc_shutdown (struct server *s) static void irc_destroy_connector (struct server *s) { - connector_free (s->connector); + if (s->connector) + connector_free (s->connector); free (s->connector); s->connector = NULL; + if (s->socks_conn) + socks_connector_free (s->socks_conn); + free (s->socks_conn); + s->socks_conn = NULL; + // Not connecting anymore s->state = IRC_DISCONNECTED; } @@ -4352,6 +4364,28 @@ irc_finish_connection (struct server *s, int socket) refresh_prompt (s->ctx); } +static void +irc_split_host_port (char *s, char **host, char **port) +{ + char *colon = strrchr (s, ':'); + if (colon) + { + *colon = '\0'; + *port = ++colon; + } + else + *port = "6667"; + + // Unwrap IPv6 addresses in format_host_port_pair() format + size_t host_end = strlen (s) - 1; + if (*s == '[' && s[host_end] == ']') + s++[host_end] = '\0'; + + *host = s; +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + static void irc_on_connector_connecting (void *user_data, const char *address) { @@ -4382,32 +4416,13 @@ irc_on_connector_connected (void *user_data, int socket) irc_finish_connection (s, socket); } -static void -irc_split_host_port (char *s, char **host, char **port) -{ - char *colon = strrchr (s, ':'); - if (colon) - { - *colon = '\0'; - *port = ++colon; - } - else - *port = "6667"; - - // Unwrap IPv6 addresses in format_host_port_pair() format - size_t host_end = strlen (s) - 1; - if (*s == '[' && s[host_end] == ']') - s++[host_end] = '\0'; - - *host = s; -} - static bool irc_setup_connector (struct server *s, const struct str_vector *addresses, struct error **e) { struct connector *connector = xmalloc (sizeof *connector); connector_init (connector, &s->ctx->poller); + s->connector = connector; connector->user_data = s; connector->on_connecting = irc_on_connector_connecting; @@ -4415,69 +4430,76 @@ irc_setup_connector (struct server *s, connector->on_connected = irc_on_connector_connected; connector->on_failure = irc_on_connector_failure; - s->state = IRC_CONNECTING; - s->connector = connector; - for (size_t i = 0; i < addresses->len; i++) { char *host, *port; irc_split_host_port (addresses->vector[i], &host, &port); if (!connector_add_target (connector, host, port, e)) - { - irc_destroy_connector (s); return false; - } } connector_step (connector); return true; } +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +// TODO: see if we can further merge code for the two connectors, for example +// by making SOCKS 4A and 5 mere plugins for the connector, or by using +// a virtual interface common to them both (seems more likely) + +static void +irc_on_socks_connecting (void *user_data, + const char *address, const char *via, const char *version) +{ + struct server *s = user_data; + log_server_status (s, s->buffer, + "Connecting to #s via #s (#s)...", address, via, version); +} + static bool -irc_initiate_connect_socks (struct server *s, +irc_setup_connector_socks (struct server *s, const struct str_vector *addresses, struct error **e) { const char *socks_host = get_config_string (s->config, "socks_host"); int64_t socks_port_int = get_config_integer (s->config, "socks_port"); - const char *socks_username = - get_config_string (s->config, "socks_username"); - const char *socks_password = - get_config_string (s->config, "socks_password"); - if (!socks_host) return false; - // FIXME: we only try the first address (still better than nothing) - char *irc_host, *irc_port; - irc_split_host_port (addresses->vector[0], &irc_host, &irc_port); + struct socks_connector *connector = xmalloc (sizeof *connector); + socks_connector_init (connector, &s->ctx->poller); + s->socks_conn = connector; - char *socks_port = xstrdup_printf ("%" PRIi64, socks_port_int); + // FIXME: the SOCKS connector may outlive these values + connector->hostname = socks_host; + // FIXME: memory leak + connector->service = xstrdup_printf ("%" PRIi64, socks_port_int); + connector->username = get_config_string (s->config, "socks_username"); + connector->password = get_config_string (s->config, "socks_password"); - log_server_status (s, s->buffer, "Connecting to #&s via #&s...", - format_host_port_pair (irc_host, irc_port), - format_host_port_pair (socks_host, socks_port)); + connector->user_data = s; + connector->on_connecting = irc_on_socks_connecting; + connector->on_error = irc_on_connector_error; + connector->on_connected = irc_on_connector_connected; + connector->on_failure = irc_on_connector_failure; - // TODO: the SOCKS code needs a rewrite so that we don't block on it either; - // perhaps it could act as a special kind of connector - struct error *error = NULL; - bool result = true; - int fd = socks_connect (socks_host, socks_port, irc_host, irc_port, - socks_username, socks_password, &error); - if (fd != -1) - irc_finish_connection (s, fd); - else + for (size_t i = 0; i < addresses->len; i++) { - error_set (e, "%s: %s", "SOCKS connection failed", error->message); - error_free (error); - result = false; + char *host, *port; + irc_split_host_port (addresses->vector[i], &host, &port); + + if (!socks_connector_add_target (connector, host, port, e)) + return false; } - free (socks_port); - return result; + socks_connector_run (connector); + return true; } +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + static void irc_initiate_connect (struct server *s) { @@ -4497,13 +4519,17 @@ irc_initiate_connect (struct server *s) cstr_split_ignore_empty (addresses, ',', &servers); struct error *e = NULL; - if (!irc_initiate_connect_socks (s, &servers, &e) && !e) + if (!irc_setup_connector_socks (s, &servers, &e) && !e) irc_setup_connector (s, &servers, &e); str_vector_free (&servers); - if (e) + if (!e) + s->state = IRC_CONNECTING; + else { + irc_destroy_connector (s); + log_server_error (s, s->buffer, "#s", e->message); error_free (e); irc_queue_reconnect (s); diff --git a/zyklonb.c b/zyklonb.c index 49d4fc5..991ef13 100644 --- a/zyklonb.c +++ b/zyklonb.c @@ -1647,16 +1647,99 @@ end: irc_reset_connection_timeouts (ctx); } +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +// The bot is currently mostly synchronous (which also makes it shorter), +// however our current SOCKS code is not, hence we must wrap it. + +struct irc_socks_data +{ + struct bot_context *ctx; ///< Bot context + struct poller inner_poller; ///< Special inner poller + bool polling; ///< Inner poller is no longer needed + struct socks_connector connector; ///< SOCKS connector + bool succeeded; ///< Were we successful in connecting? +}; + +static void +irc_on_socks_connected (void *user_data, int socket) +{ + struct irc_socks_data *data = user_data; + data->ctx->irc_fd = socket; + data->succeeded = true; + data->polling = true; +} + +static void +irc_on_socks_failure (void *user_data) +{ + struct irc_socks_data *data = user_data; + data->succeeded = false; + data->polling = true; +} + +static void +irc_on_socks_connecting (void *user_data, + const char *address, const char *via, const char *version) +{ + (void) user_data; + print_status ("connecting to %s via %s (%s)...", address, via, version); +} + +static void +irc_on_socks_error (void *user_data, const char *error) +{ + (void) user_data; + print_error ("%s: %s", "SOCKS connection failed", error); +} + +static bool +irc_establish_connection_socks (struct bot_context *ctx, + const char *socks_host, const char *socks_port, + const char *host, const char *service, struct error **e) +{ + struct irc_socks_data data; + struct poller *poller = &data.inner_poller; + struct socks_connector *connector = &data.connector; + + data.ctx = ctx; + poller_init (poller); + data.polling = true; + socks_connector_init (connector, poller); + data.succeeded = false; + + connector->hostname = socks_host; + connector->service = socks_port; + connector->username = str_map_find (&ctx->config, "socks_username"); + connector->password = str_map_find (&ctx->config, "socks_password"); + + connector->on_connected = irc_on_socks_connected; + connector->on_connecting = irc_on_socks_connecting; + connector->on_error = irc_on_socks_error; + connector->on_failure = irc_on_socks_failure; + connector->user_data = &data; + + if (socks_connector_add_target (connector, host, service, e)) + { + socks_connector_run (connector); + while (data.polling) + poller_run (poller); + } + + socks_connector_free (connector); + poller_free (poller); + return data.succeeded; +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + static bool irc_connect (struct bot_context *ctx, struct error **e) { const char *irc_host = str_map_find (&ctx->config, "irc_host"); const char *irc_port = str_map_find (&ctx->config, "irc_port"); - const char *socks_host = str_map_find (&ctx->config, "socks_host"); const char *socks_port = str_map_find (&ctx->config, "socks_port"); - const char *socks_username = str_map_find (&ctx->config, "socks_username"); - const char *socks_password = str_map_find (&ctx->config, "socks_password"); const char *nickname = str_map_find (&ctx->config, "nickname"); const char *username = str_map_find (&ctx->config, "username"); @@ -1678,26 +1761,11 @@ irc_connect (struct bot_context *ctx, struct error **e) if (!irc_get_boolean_from_config (ctx, "ssl", &use_ssl, e)) return false; - if (socks_host) - { - char *address = format_host_port_pair (irc_host, irc_port); - char *socks_address = format_host_port_pair (socks_host, socks_port); - print_status ("connecting to %s via %s...", address, socks_address); - free (socks_address); - free (address); - - struct error *error = NULL; - int fd = socks_connect (socks_host, socks_port, irc_host, irc_port, - socks_username, socks_password, &error); - if (fd == -1) - { - error_set (e, "%s: %s", "SOCKS connection failed", error->message); - error_free (error); - return false; - } - ctx->irc_fd = fd; - } - else if (!irc_establish_connection (ctx, irc_host, irc_port, e)) + bool connected = socks_host + ? irc_establish_connection_socks (ctx, + socks_host, socks_port, irc_host, irc_port, e) + : irc_establish_connection (ctx, irc_host, irc_port, e); + if (!connected) return false; if (use_ssl && !irc_initialize_ssl (ctx, e))