diff --git a/common.c b/common.c index e885945..ae92ecc 100644 --- a/common.c +++ b/common.c @@ -436,11 +436,16 @@ struct socks_connector return false; \ BLOCK_END -#define SOCKS_NEED_DATA(n) \ +#define SOCKS_READ_START(n) \ if (!socks_try_fill_read_buffer (self, (n))) \ return false; \ if (self->read_buffer.len < n) \ - return true + return true; \ + struct msg_unpacker unpacker; \ + msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len) + +#define SOCKS_READ_END \ + str_remove_slice (&self->read_buffer, 0, unpacker.offset) static bool socks_try_fill_read_buffer (struct socks_connector *self, size_t n) @@ -496,15 +501,11 @@ socks_try_flush_write_buffer (struct socks_connector *self) static bool socks_4a_finish (struct socks_connector *self) { - SOCKS_NEED_DATA (8); - - struct msg_unpacker unpacker; - msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); - + SOCKS_READ_START (8); uint8_t null, status; hard_assert (msg_unpacker_u8 (&unpacker, &null)); hard_assert (msg_unpacker_u8 (&unpacker, &status)); - str_remove_slice (&self->read_buffer, 0, unpacker.offset); + SOCKS_READ_END; if (null != 0) SOCKS_FAIL ("protocol error"); @@ -578,12 +579,9 @@ socks_4a_start (struct socks_connector *self) static bool socks_5_request_port (struct socks_connector *self) { - SOCKS_NEED_DATA (2); - - struct msg_unpacker unpacker; - msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); + SOCKS_READ_START (2); hard_assert (msg_unpacker_u16 (&unpacker, &self->bound_port)); - str_remove_slice (&self->read_buffer, 0, unpacker.offset); + SOCKS_READ_END; self->done = true; return false; @@ -593,7 +591,7 @@ static bool socks_5_request_ipv4 (struct socks_connector *self) { size_t len = sizeof self->bound_address.data.ipv4; - SOCKS_NEED_DATA (len); + SOCKS_READ_START (len); memcpy (self->bound_address.data.ipv4, self->read_buffer.str, len); str_remove_slice (&self->read_buffer, 0, len); @@ -605,7 +603,7 @@ static bool socks_5_request_ipv6 (struct socks_connector *self) { size_t len = sizeof self->bound_address.data.ipv6; - SOCKS_NEED_DATA (len); + SOCKS_READ_START (len); memcpy (self->bound_address.data.ipv6, self->read_buffer.str, len); str_remove_slice (&self->read_buffer, 0, len); @@ -617,7 +615,7 @@ static bool socks_5_request_domain_data (struct socks_connector *self) { size_t len = self->bound_address_len; - SOCKS_NEED_DATA (len); + SOCKS_READ_START (len); self->bound_address.data.domain = xstrndup (self->read_buffer.str, len); str_remove_slice (&self->read_buffer, 0, len); @@ -628,12 +626,9 @@ socks_5_request_domain_data (struct socks_connector *self) static bool socks_5_request_domain (struct socks_connector *self) { - SOCKS_NEED_DATA (1); - - struct msg_unpacker unpacker; - msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); + SOCKS_READ_START (1); hard_assert (msg_unpacker_u8 (&unpacker, &self->bound_address_len)); - str_remove_slice (&self->read_buffer, 0, unpacker.offset); + SOCKS_READ_END; self->on_data = socks_5_request_domain_data; return true; @@ -642,17 +637,13 @@ socks_5_request_domain (struct socks_connector *self) static bool socks_5_request_finish (struct socks_connector *self) { - SOCKS_NEED_DATA (4); - - struct msg_unpacker unpacker; - msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); - + SOCKS_READ_START (4); uint8_t version, status, reserved, type; hard_assert (msg_unpacker_u8 (&unpacker, &version)); hard_assert (msg_unpacker_u8 (&unpacker, &status)); hard_assert (msg_unpacker_u8 (&unpacker, &reserved)); hard_assert (msg_unpacker_u8 (&unpacker, &type)); - str_remove_slice (&self->read_buffer, 0, unpacker.offset); + SOCKS_READ_END; if (version != 0x05) SOCKS_FAIL ("protocol error"); @@ -722,15 +713,11 @@ socks_5_request_start (struct socks_connector *self) static bool socks_5_userpass_finish (struct socks_connector *self) { - SOCKS_NEED_DATA (2); - - struct msg_unpacker unpacker; - msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); - + SOCKS_READ_START (2); uint8_t version, status; hard_assert (msg_unpacker_u8 (&unpacker, &version)); hard_assert (msg_unpacker_u8 (&unpacker, &status)); - str_remove_slice (&self->read_buffer, 0, unpacker.offset); + SOCKS_READ_END; if (version != 0x01) SOCKS_FAIL ("protocol error"); @@ -764,15 +751,11 @@ socks_5_userpass_start (struct socks_connector *self) static bool socks_5_auth_finish (struct socks_connector *self) { - SOCKS_NEED_DATA (2); - - struct msg_unpacker unpacker; - msg_unpacker_init (&unpacker, self->read_buffer.str, self->read_buffer.len); - + SOCKS_READ_START (2); uint8_t version, method; hard_assert (msg_unpacker_u8 (&unpacker, &version)); hard_assert (msg_unpacker_u8 (&unpacker, &method)); - str_remove_slice (&self->read_buffer, 0, unpacker.offset); + SOCKS_READ_END; if (version != 0x05) SOCKS_FAIL ("protocol error");