Finish the WebSocket backend

Of course, everything so far hasn't been tested much.
This commit is contained in:
Přemysl Eric Janouch 2018-10-18 02:55:55 +02:00
parent 580f0a0c59
commit 62945cceb3
Signed by: p
GPG Key ID: A0420B94F92B9493
1 changed files with 164 additions and 111 deletions

View File

@ -540,9 +540,9 @@ enum ws_handler_state
{ {
WS_HANDLER_CONNECTING, ///< Parsing HTTP WS_HANDLER_CONNECTING, ///< Parsing HTTP
WS_HANDLER_OPEN, ///< Parsing WebSockets frames WS_HANDLER_OPEN, ///< Parsing WebSockets frames
WS_HANDLER_CLOSING, ///< Closing the connection WS_HANDLER_CLOSING, ///< Partial closure by us
WS_HANDLER_ALMOST_DEAD, ///< Closing connection after failure WS_HANDLER_FLUSHING, ///< Just waiting for client EOF
WS_HANDLER_CLOSED ///< Dead WS_HANDLER_CLOSED ///< Dead, both sides closed
}; };
struct ws_handler struct ws_handler
@ -584,6 +584,7 @@ struct ws_handler
// TODO: void (*on_handshake) (protocols) that will allow the user // TODO: void (*on_handshake) (protocols) that will allow the user
// to choose any sub-protocol, if the client has provided any. // to choose any sub-protocol, if the client has provided any.
// This may render "on_connected" unnecessary. // This may render "on_connected" unnecessary.
// Should also enable failing the handshake.
/// Called after successfuly connecting (handshake complete) /// Called after successfuly connecting (handshake complete)
bool (*on_connected) (struct ws_handler *); bool (*on_connected) (struct ws_handler *);
@ -626,36 +627,41 @@ static void
ws_handler_close (struct ws_handler *self, ws_handler_close (struct ws_handler *self,
enum ws_status close_code, const char *reason, size_t len) enum ws_status close_code, const char *reason, size_t len)
{ {
hard_assert (self->state == WS_HANDLER_OPEN);
struct str payload = str_make (); struct str payload = str_make ();
str_pack_u16 (&payload, close_code); str_pack_u16 (&payload, close_code);
// XXX: maybe accept a null-terminated string on input? Has to be UTF-8 a/w // XXX: maybe accept a null-terminated string on input? Has to be UTF-8 a/w
str_append_data (&payload, reason, len); str_append_data (&payload, reason, len);
ws_handler_send_control (self, WS_OPCODE_CLOSE, payload.str, payload.len); ws_handler_send_control (self, WS_OPCODE_CLOSE, payload.str, payload.len);
self->close_cb (self, true /* half_close */);
// Close initiated by us; the reason is null-terminated within `payload'
if (self->on_close)
self->on_close (self, close_code, payload.str + 2);
self->state = WS_HANDLER_CLOSING; self->state = WS_HANDLER_CLOSING;
str_free (&payload); str_free (&payload);
} }
static void static bool
ws_handler_fail (struct ws_handler *self, enum ws_status close_code) ws_handler_fail_connection (struct ws_handler *self, enum ws_status close_code)
{ {
ws_handler_close (self, close_code, NULL, 0); hard_assert (self->state == WS_HANDLER_OPEN
self->state = WS_HANDLER_ALMOST_DEAD; || self->state == WS_HANDLER_CLOSING);
// TODO: set the close timer, ignore all further incoming input (either set if (self->state == WS_HANDLER_OPEN)
// some flag for the case that we're in the middle of ws_handler_push(), ws_handler_close (self, close_code, NULL, 0);
// and/or add a mechanism to stop the caller from polling the socket for
// reads). self->state = WS_HANDLER_FLUSHING;
// TODO: make sure we don't send pings after the close if (self->on_close)
self->on_close (self, WS_STATUS_ABNORMAL_CLOSURE, "");
ev_timer_stop (EV_DEFAULT_ &self->ping_timer);
ev_timer_set (&self->close_timeout_watcher, self->close_timeout, 0.);
ev_timer_start (EV_DEFAULT_ &self->close_timeout_watcher);
return false;
} }
// TODO: add support for fragmented responses // TODO: add support for fragmented responses
static void static void
ws_handler_send (struct ws_handler *self, ws_handler_send_frame (struct ws_handler *self,
enum ws_opcode opcode, const void *data, size_t len) enum ws_opcode opcode, const void *data, size_t len)
{ {
if (!soft_assert (self->state == WS_HANDLER_OPEN)) if (!soft_assert (self->state == WS_HANDLER_OPEN))
@ -697,24 +703,26 @@ ws_handler_on_frame_header (void *user_data, const struct ws_parser *parser)
|| (!ws_is_control_frame (parser->opcode) && || (!ws_is_control_frame (parser->opcode) &&
(self->expecting_continuation && parser->opcode != WS_OPCODE_CONT)) (self->expecting_continuation && parser->opcode != WS_OPCODE_CONT))
|| parser->payload_len >= 0x8000000000000000ULL) || parser->payload_len >= 0x8000000000000000ULL)
ws_handler_fail (self, WS_STATUS_PROTOCOL_ERROR); return ws_handler_fail_connection (self, WS_STATUS_PROTOCOL_ERROR);
else if (parser->payload_len > self->max_payload_len)
ws_handler_fail (self, WS_STATUS_MESSAGE_TOO_BIG); if (parser->payload_len > self->max_payload_len
else || (self->expecting_continuation &&
self->message_data.len + parser->payload_len > self->max_payload_len))
return ws_handler_fail_connection (self, WS_STATUS_MESSAGE_TOO_BIG);
return true; return true;
return false;
} }
static bool static bool
ws_handler_on_protocol_close ws_handler_on_control_close
(struct ws_handler *self, const struct ws_parser *parser) (struct ws_handler *self, const struct ws_parser *parser)
{ {
hard_assert (self->state == WS_HANDLER_OPEN
|| self->state == WS_HANDLER_CLOSING);
struct msg_unpacker unpacker = struct msg_unpacker unpacker =
msg_unpacker_make (parser->input.str, parser->payload_len); msg_unpacker_make (parser->input.str, parser->payload_len);
char *reason = NULL; char *reason = NULL;
uint16_t close_code = WS_STATUS_NO_STATUS_RECEIVED; uint16_t close_code = WS_STATUS_NO_STATUS_RECEIVED;
if (parser->payload_len >= 2) if (parser->payload_len >= 2)
{ {
(void) msg_unpacker_u16 (&unpacker, &close_code); (void) msg_unpacker_u16 (&unpacker, &close_code);
@ -723,17 +731,29 @@ ws_handler_on_protocol_close
else else
reason = xstrdup (""); reason = xstrdup ("");
if (self->state != WS_HANDLER_CLOSING) if (close_code < 1000 || close_code > 4999)
// XXX: invalid close code: maybe we should fail the connection instead
close_code = WS_STATUS_PROTOCOL_ERROR;
if (self->state == WS_HANDLER_OPEN)
{ {
// Close initiated by the client // Close initiated by the client
// FIXME: not sending the potentially different close_code
ws_handler_send_control (self, WS_OPCODE_CLOSE, ws_handler_send_control (self, WS_OPCODE_CLOSE,
parser->input.str, parser->payload_len); parser->input.str, parser->payload_len);
self->state = WS_HANDLER_FLUSHING;
if (self->on_close) if (self->on_close)
self->on_close (self, close_code, reason); self->on_close (self, close_code, reason);
} }
else
self->state = WS_HANDLER_FLUSHING;
free (reason); free (reason);
self->state = WS_HANDLER_ALMOST_DEAD;
ev_timer_stop (EV_DEFAULT_ &self->ping_timer);
ev_timer_set (&self->close_timeout_watcher, self->close_timeout, 0.);
ev_timer_start (EV_DEFAULT_ &self->close_timeout_watcher);
return true; return true;
} }
@ -744,21 +764,18 @@ ws_handler_on_control_frame
switch (parser->opcode) switch (parser->opcode)
{ {
case WS_OPCODE_CLOSE: case WS_OPCODE_CLOSE:
return ws_handler_on_protocol_close (self, parser); return ws_handler_on_control_close (self, parser);
case WS_OPCODE_PING: case WS_OPCODE_PING:
ws_handler_send_control (self, WS_OPCODE_PONG, ws_handler_send_control (self, WS_OPCODE_PONG,
parser->input.str, parser->payload_len); parser->input.str, parser->payload_len);
break; break;
case WS_OPCODE_PONG: case WS_OPCODE_PONG:
// XXX: maybe we should check the payload // TODO: check the payload
self->received_pong = true; self->received_pong = true;
break; break;
default: default:
// Unknown control frame // Unknown control frame
ws_handler_fail (self, WS_STATUS_PROTOCOL_ERROR); return ws_handler_fail_connection (self, WS_STATUS_PROTOCOL_ERROR);
// FIXME: we shouldn't close the connection right away;
// also check other places
return false;
} }
return true; return true;
} }
@ -769,29 +786,19 @@ ws_handler_on_frame (void *user_data, const struct ws_parser *parser)
struct ws_handler *self = user_data; struct ws_handler *self = user_data;
if (ws_is_control_frame (parser->opcode)) if (ws_is_control_frame (parser->opcode))
return ws_handler_on_control_frame (self, parser); return ws_handler_on_control_frame (self, parser);
// TODO: do this rather in "on_frame_header"
if (self->message_data.len + parser->payload_len > self->max_payload_len)
{
ws_handler_fail (self, WS_STATUS_MESSAGE_TOO_BIG);
return false;
}
if (!self->expecting_continuation) if (!self->expecting_continuation)
self->message_opcode = parser->opcode; self->message_opcode = parser->opcode;
str_append_data (&self->message_data, str_append_data (&self->message_data,
parser->input.str, parser->payload_len); parser->input.str, parser->payload_len);
self->expecting_continuation = !parser->is_fin; if ((self->expecting_continuation = !parser->is_fin))
if (!parser->is_fin)
return true; return true;
if (self->message_opcode == WS_OPCODE_TEXT if (self->message_opcode == WS_OPCODE_TEXT
&& !utf8_validate (self->message_data.str, self->message_data.len)) && !utf8_validate (self->message_data.str, self->message_data.len))
{ {
ws_handler_fail (self, WS_STATUS_INVALID_PAYLOAD_DATA); return ws_handler_fail_connection
return false; (self, WS_STATUS_INVALID_PAYLOAD_DATA);
} }
bool result = true; bool result = true;
@ -799,6 +806,8 @@ ws_handler_on_frame (void *user_data, const struct ws_parser *parser)
result = self->on_message (self, self->message_opcode, result = self->on_message (self, self->message_opcode,
self->message_data.str, self->message_data.len); self->message_data.str, self->message_data.len);
str_reset (&self->message_data); str_reset (&self->message_data);
// TODO: if (!result), either replace this with a state check,
// or make sure to change the state
return result; return result;
} }
@ -810,11 +819,10 @@ ws_handler_on_ping_timer (EV_P_ ev_timer *watcher, int revents)
struct ws_handler *self = watcher->data; struct ws_handler *self = watcher->data;
if (!self->received_pong) if (!self->received_pong)
{ ws_handler_fail_connection (self, 4000);
// TODO: close/fail the connection?
}
else else
{ {
// TODO: be an annoying server and send a nonce in the data
ws_handler_send_control (self, WS_OPCODE_PING, NULL, 0); ws_handler_send_control (self, WS_OPCODE_PING, NULL, 0);
ev_timer_again (EV_A_ watcher); ev_timer_again (EV_A_ watcher);
} }
@ -823,20 +831,38 @@ ws_handler_on_ping_timer (EV_P_ ev_timer *watcher, int revents)
static void static void
ws_handler_on_close_timeout (EV_P_ ev_timer *watcher, int revents) ws_handler_on_close_timeout (EV_P_ ev_timer *watcher, int revents)
{ {
(void) loop;
(void) revents; (void) revents;
struct ws_handler *self = watcher->data; struct ws_handler *self = watcher->data;
// TODO: anything else to do here? Invalidate our state? hard_assert (self->state == WS_HANDLER_OPEN
if (self->close_cb) || self->state == WS_HANDLER_CLOSING);
if (self->state == WS_HANDLER_CLOSING
&& self->on_close)
self->on_close (self, WS_STATUS_ABNORMAL_CLOSURE, "close timeout");
self->state = WS_HANDLER_CLOSED;
self->close_cb (self, false /* half_close */); self->close_cb (self, false /* half_close */);
} }
static void static void
ws_handler_on_handshake_timeout (EV_P_ ev_timer *watcher, int revents) ws_handler_on_handshake_timeout (EV_P_ ev_timer *watcher, int revents)
{ {
(void) loop;
(void) revents; (void) revents;
struct ws_handler *self = watcher->data; struct ws_handler *self = watcher->data;
// TODO
// XXX: this is a no-op, since this currently doesn't even call shutdown
// immediately but postpones it until later
self->close_cb (self, true /* half_close */);
self->state = WS_HANDLER_FLUSHING;
if (self->on_close)
self->on_close (self, WS_STATUS_ABNORMAL_CLOSURE, "handshake timeout");
self->state = WS_HANDLER_CLOSED;
self->close_cb (self, false /* half_close */);
} }
static void static void
@ -991,6 +1017,7 @@ ws_handler_on_url (http_parser *parser, const char *at, size_t len)
#define HTTP_400_BAD_REQUEST "400 Bad Request" #define HTTP_400_BAD_REQUEST "400 Bad Request"
#define HTTP_405_METHOD_NOT_ALLOWED "405 Method Not Allowed" #define HTTP_405_METHOD_NOT_ALLOWED "405 Method Not Allowed"
#define HTTP_417_EXPECTATION_FAILED "407 Expectation Failed" #define HTTP_417_EXPECTATION_FAILED "407 Expectation Failed"
#define HTTP_426_UPGRADE_REQUIRED "426 Upgrade Required"
#define HTTP_505_VERSION_NOT_SUPPORTED "505 HTTP Version Not Supported" #define HTTP_505_VERSION_NOT_SUPPORTED "505 HTTP Version Not Supported"
static void static void
@ -1024,43 +1051,47 @@ ws_handler_http_responsev (struct ws_handler *self,
str_free (&response); str_free (&response);
} }
static void static bool
ws_handler_http_response (struct ws_handler *self, const char *status, ...) ws_handler_fail_handshake (struct ws_handler *self, const char *status, ...)
{ {
struct strv v = strv_make ();
va_list ap; va_list ap;
va_start (ap, status); va_start (ap, status);
const char *s; const char *s;
struct strv v = strv_make ();
while ((s = va_arg (ap, const char *))) while ((s = va_arg (ap, const char *)))
strv_append (&v, s); strv_append (&v, s);
va_end (ap); va_end (ap);
ws_handler_http_responsev (self, status, v.vector); ws_handler_http_responsev (self, status, v.vector);
strv_free (&v); strv_free (&v);
self->close_cb (self, true /* half_close */);
self->state = WS_HANDLER_FLUSHING;
if (self->on_close)
self->on_close (self, WS_STATUS_ABNORMAL_CLOSURE, status);
return false;
} }
#define FAIL_HANDSHAKE(status, ...) \ #define FAIL_HANDSHAKE(...) \
BLOCK_START \ return ws_handler_fail_handshake (self, __VA_ARGS__, NULL)
self->state = WS_HANDLER_ALMOST_DEAD; \
ws_handler_http_response (self, (status), __VA_ARGS__); \
return false; \
BLOCK_END
static bool static bool
ws_handler_finish_handshake (struct ws_handler *self) ws_handler_finish_handshake (struct ws_handler *self)
{ {
// XXX: we probably shouldn't use 505 to reject the minor version but w/e
if (self->hp.http_major != 1 || self->hp.http_minor < 1)
FAIL_HANDSHAKE (HTTP_505_VERSION_NOT_SUPPORTED, NULL);
if (self->hp.method != HTTP_GET) if (self->hp.method != HTTP_GET)
FAIL_HANDSHAKE (HTTP_405_METHOD_NOT_ALLOWED, "Allow: GET", NULL); FAIL_HANDSHAKE (HTTP_405_METHOD_NOT_ALLOWED, "Allow: GET");
// Technically, it must be /at least/ 1.1 but no other 1.x version of HTTP
// is going to happen and 2.x is entirely incompatible
// XXX: we probably shouldn't use 505 to reject the minor version but w/e
if (self->hp.http_major != 1 || self->hp.http_minor != 1)
FAIL_HANDSHAKE (HTTP_505_VERSION_NOT_SUPPORTED);
// Your expectations are way too high // Your expectations are way too high
if (str_map_find (&self->headers, "Expect")) if (str_map_find (&self->headers, "Expect"))
FAIL_HANDSHAKE (HTTP_417_EXPECTATION_FAILED, NULL); FAIL_HANDSHAKE (HTTP_417_EXPECTATION_FAILED);
// Reject URLs specifying the schema and host; we're not parsing that // Reject URLs specifying the schema and host; we're not parsing that
// TODO: actually do parse this and let our user decide if it matches // TODO: actually do parse this and let our user decide if it matches
@ -1068,11 +1099,11 @@ ws_handler_finish_handshake (struct ws_handler *self)
if (http_parser_parse_url (self->url.str, self->url.len, false, &url) if (http_parser_parse_url (self->url.str, self->url.len, false, &url)
|| (url.field_set & (1 << UF_SCHEMA | 1 << UF_HOST | 1 << UF_PORT)) || (url.field_set & (1 << UF_SCHEMA | 1 << UF_HOST | 1 << UF_PORT))
|| !str_map_find (&self->headers, "Host")) || !str_map_find (&self->headers, "Host"))
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, NULL); FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST);
const char *connection = str_map_find (&self->headers, "Connection"); const char *connection = str_map_find (&self->headers, "Connection");
if (!connection || strcasecmp_ascii (connection, "Upgrade")) if (!connection || strcasecmp_ascii (connection, "Upgrade"))
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, NULL); FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST);
// Check if we can actually upgrade the protocol to WebSockets // Check if we can actually upgrade the protocol to WebSockets
const char *upgrade = str_map_find (&self->headers, "Upgrade"); const char *upgrade = str_map_find (&self->headers, "Upgrade");
@ -1088,7 +1119,8 @@ ws_handler_finish_handshake (struct ws_handler *self)
http_protocol_destroy (iter); http_protocol_destroy (iter);
} }
if (!can_upgrade) if (!can_upgrade)
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, NULL); FAIL_HANDSHAKE (HTTP_426_UPGRADE_REQUIRED,
"Upgrade: websocket", SEC_WS_VERSION ": 13");
// Okay, we're finally past the basic HTTP/1.1 stuff // Okay, we're finally past the basic HTTP/1.1 stuff
const char *key = str_map_find (&self->headers, SEC_WS_KEY); const char *key = str_map_find (&self->headers, SEC_WS_KEY);
@ -1098,19 +1130,17 @@ ws_handler_finish_handshake (struct ws_handler *self)
const char *extensions = str_map_find (&self->headers, SEC_WS_EXTENSIONS); const char *extensions = str_map_find (&self->headers, SEC_WS_EXTENSIONS);
*/ */
if (!key) if (!version)
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, NULL); FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST);
if (strcmp (version, "13"))
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, SEC_WS_VERSION ": 13");
struct str tmp = str_make (); struct str tmp = str_make ();
bool key_is_valid = base64_decode (key, false, &tmp) && tmp.len == 16; bool key_is_valid = key
&& base64_decode (key, false, &tmp) && tmp.len == 16;
str_free (&tmp); str_free (&tmp);
if (!key_is_valid) if (!key_is_valid)
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, NULL); FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST);
if (!version)
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, NULL);
if (strcmp (version, "13"))
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, SEC_WS_VERSION ": 13", NULL);
struct strv fields = strv_make (); struct strv fields = strv_make ();
strv_append_args (&fields, strv_append_args (&fields,
@ -1130,6 +1160,7 @@ ws_handler_finish_handshake (struct ws_handler *self)
strv_free (&fields); strv_free (&fields);
self->state = WS_HANDLER_OPEN;
ev_timer_init (&self->ping_timer, ws_handler_on_ping_timer, ev_timer_init (&self->ping_timer, ws_handler_on_ping_timer,
self->ping_interval, 0); self->ping_interval, 0);
ev_timer_start (EV_DEFAULT_ &self->ping_timer); ev_timer_start (EV_DEFAULT_ &self->ping_timer);
@ -1141,40 +1172,62 @@ ws_handler_finish_handshake (struct ws_handler *self)
static void static void
ws_handler_start (struct ws_handler *self) ws_handler_start (struct ws_handler *self)
{ {
hard_assert (self->state == WS_HANDLER_CONNECTING);
ev_timer_set (&self->handshake_timeout_watcher, ev_timer_set (&self->handshake_timeout_watcher,
self->handshake_timeout, 0.); self->handshake_timeout, 0.);
ev_timer_start (EV_DEFAULT_ &self->handshake_timeout_watcher); ev_timer_start (EV_DEFAULT_ &self->handshake_timeout_watcher);
} }
// The client should normally never close the connection, assume that it's
// either received an EOF from our side, or that it doesn't care about our data
// anymore, having called close() already
static bool
ws_handler_push_eof (struct ws_handler *self)
{
switch (self->state)
{
case WS_HANDLER_CONNECTING:
ev_timer_stop (EV_DEFAULT_ &self->handshake_timeout_watcher);
self->state = WS_HANDLER_FLUSHING;
if (self->on_close)
self->on_close (self, WS_STATUS_ABNORMAL_CLOSURE, "unexpected EOF");
break;
case WS_HANDLER_OPEN:
ev_timer_stop (EV_DEFAULT_ &self->ping_timer);
// Fall-through
case WS_HANDLER_CLOSING:
self->state = WS_HANDLER_CLOSED;
if (self->on_close)
self->on_close (self, WS_STATUS_ABNORMAL_CLOSURE, "");
// Fall-through
case WS_HANDLER_FLUSHING:
ev_timer_stop (EV_DEFAULT_ &self->close_timeout_watcher);
break;
default:
soft_assert(self->state != WS_HANDLER_CLOSED);
}
self->state = WS_HANDLER_CLOSED;
return false;
}
/// Push data to the WebSocket handler. "len == 0" means EOF. /// Push data to the WebSocket handler. "len == 0" means EOF.
/// You are expected to close the connection and dispose of the handler
/// when the function returns false.
static bool static bool
ws_handler_push (struct ws_handler *self, const void *data, size_t len) ws_handler_push (struct ws_handler *self, const void *data, size_t len)
{ {
// TODO: make sure all timers are stopped appropriately
if (!len) if (!len)
{ return ws_handler_push_eof (self);
ev_timer_stop (EV_DEFAULT_ &self->handshake_timeout_watcher);
if (self->state == WS_HANDLER_OPEN) if (self->state == WS_HANDLER_FLUSHING)
{
if (self->on_close)
self->on_close (self, WS_STATUS_ABNORMAL_CLOSURE, "");
}
else
{
// TODO: anything to do besides just closing the connection?
}
self->state = WS_HANDLER_CLOSED;
return false;
}
if (self->state == WS_HANDLER_ALMOST_DEAD)
// We're waiting for an EOF from the client, must not process data // We're waiting for an EOF from the client, must not process data
return true; return true;
if (self->state != WS_HANDLER_CONNECTING) if (self->state != WS_HANDLER_CONNECTING)
return ws_parser_push (&self->parser, data, len); return soft_assert (self->state != WS_HANDLER_CLOSED)
&& ws_parser_push (&self->parser, data, len);
// The handshake hasn't been done yet, process HTTP headers // The handshake hasn't been done yet, process HTTP headers
static const http_parser_settings http_settings = static const http_parser_settings http_settings =
@ -1185,8 +1238,8 @@ ws_handler_push (struct ws_handler *self, const void *data, size_t len)
.on_url = ws_handler_on_url, .on_url = ws_handler_on_url,
}; };
size_t n_parsed = http_parser_execute (&self->hp, size_t n_parsed =
&http_settings, data, len); http_parser_execute (&self->hp, &http_settings, data, len);
if (self->hp.upgrade) if (self->hp.upgrade)
{ {
@ -1195,12 +1248,10 @@ ws_handler_push (struct ws_handler *self, const void *data, size_t len)
// The handshake hasn't been finished, yet there is more data // The handshake hasn't been finished, yet there is more data
// to be processed after the headers already // to be processed after the headers already
if (len - n_parsed) if (len - n_parsed)
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, NULL); FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST);
if (!ws_handler_finish_handshake (self)) if (!ws_handler_finish_handshake (self))
return false; return false;
self->state = WS_HANDLER_OPEN;
if (self->on_connected) if (self->on_connected)
return self->on_connected (self); return self->on_connected (self);
return true; return true;
@ -1217,7 +1268,7 @@ ws_handler_push (struct ws_handler *self, const void *data, size_t len)
print_debug ("WS handshake failed: %s", print_debug ("WS handshake failed: %s",
http_errno_description (err)); http_errno_description (err));
FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST, NULL); FAIL_HANDSHAKE (HTTP_400_BAD_REQUEST);
} }
return true; return true;
} }
@ -2319,15 +2370,15 @@ client_ws_on_message (struct ws_handler *handler,
FIND_CONTAINER (self, handler, struct client_ws, handler); FIND_CONTAINER (self, handler, struct client_ws, handler);
if (type != WS_OPCODE_TEXT) if (type != WS_OPCODE_TEXT)
{ {
ws_handler_fail (&self->handler, WS_STATUS_UNSUPPORTED_DATA); return ws_handler_fail_connection
return false; (&self->handler, WS_STATUS_UNSUPPORTED_DATA);
} }
struct server_context *ctx = ev_userdata (EV_DEFAULT); struct server_context *ctx = ev_userdata (EV_DEFAULT);
struct str response = str_make (); struct str response = str_make ();
process_json_rpc (ctx, data, len, &response); process_json_rpc (ctx, data, len, &response);
if (response.len) if (response.len)
ws_handler_send (&self->handler, ws_handler_send_frame (&self->handler,
WS_OPCODE_TEXT, response.str, response.len); WS_OPCODE_TEXT, response.str, response.len);
str_free (&response); str_free (&response);
return true; return true;
@ -2353,6 +2404,7 @@ static bool
client_ws_push (struct client *client, const void *data, size_t len) client_ws_push (struct client *client, const void *data, size_t len)
{ {
FIND_CONTAINER (self, client, struct client_ws, client); FIND_CONTAINER (self, client, struct client_ws, client);
// client_close() will correctly destroy the client on EOF
return ws_handler_push (&self->handler, data, len); return ws_handler_push (&self->handler, data, len);
} }
@ -2361,7 +2413,8 @@ client_ws_shutdown (struct client *client)
{ {
FIND_CONTAINER (self, client, struct client_ws, client); FIND_CONTAINER (self, client, struct client_ws, client);
if (self->handler.state == WS_HANDLER_CONNECTING) if (self->handler.state == WS_HANDLER_CONNECTING)
; // TODO: abort the connection immediately // No on_close, no problem
client_destroy (&self->client);
else if (self->handler.state == WS_HANDLER_OPEN) else if (self->handler.state == WS_HANDLER_OPEN)
ws_handler_close (&self->handler, WS_STATUS_GOING_AWAY, NULL, 0); ws_handler_close (&self->handler, WS_STATUS_GOING_AWAY, NULL, 0);
} }