tls-autodetect: put most of the server code in place

So far we act up when it is the client who initializes the shutdown.
This commit is contained in:
Přemysl Eric Janouch 2018-07-15 10:45:12 +02:00
parent b5b64db075
commit 728fa4e548
Signed by: p
GPG Key ID: A0420B94F92B9493

View File

@ -13,12 +13,17 @@
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// //
//
// This is an example TLS-autodetecting chat server. // This is an example TLS-autodetecting chat server.
// //
// You may connect to it either using: // You may connect to it using either of these:
// ncat -C localhost 1234
// ncat -C --ssl localhost 1234
//
// These clients are unable to properly shutdown the connection:
// telnet localhost 1234 // telnet localhost 1234
// or
// openssl s_client -connect localhost:1234 // openssl s_client -connect localhost:1234
//
package main package main
import ( import (
@ -37,6 +42,7 @@ import (
// --- Utilities --------------------------------------------------------------- // --- Utilities ---------------------------------------------------------------
//
// Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom // Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom
// must be at least three octets long for this to work reliably, but that should // must be at least three octets long for this to work reliably, but that should
// not pose a problem in practice. We might try waiting for them. // not pose a problem in practice. We might try waiting for them.
@ -46,8 +52,7 @@ import (
// SSL3/TLS: <22> | <3> | xxxx xxxx // SSL3/TLS: <22> | <3> | xxxx xxxx
// (handshake)| (protocol version) // (handshake)| (protocol version)
// //
func detectTLS(sysconn syscall.RawConn) bool { func detectTLS(sysconn syscall.RawConn) (isTLS bool) {
isTLS := false
sysconn.Read(func(fd uintptr) (done bool) { sysconn.Read(func(fd uintptr) (done bool) {
var buf [3]byte var buf [3]byte
n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK) n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK)
@ -78,30 +83,38 @@ type client struct {
transport net.Conn // underlying connection transport net.Conn // underlying connection
tls *tls.Conn // TLS, if detected tls *tls.Conn // TLS, if detected
conn connCloseWrite // high-level connection conn connCloseWrite // high-level connection
connReady bool // conn is safe to read from the main goroutine
inQ []byte // unprocessed input inQ []byte // unprocessed input
outQ []byte // unprocessed output outQ []byte // unprocessed output
writing bool // whether a writing goroutine is running writing bool // whether a writing goroutine is running
inShutdown bool // whether we're closing connection inShutdown bool // whether we're closing connection
killTimer *time.Timer // timeout
}
type preparedEvent struct {
client *client
host string // client's hostname or literal IP address
isTLS bool // the client seems to use TLS
} }
type readEvent struct { type readEvent struct {
client *client // client client *client
data []byte // new data from the client data []byte // new data from the client
err error // read error err error // read error
} }
type writeEvent struct { type writeEvent struct {
client *client // client client *client
written int // amount of bytes written written int // amount of bytes written
err error // write error err error // write error
} }
var ( var (
sigs = make(chan os.Signal, 1) sigs = make(chan os.Signal, 1)
conns = make(chan net.Conn) conns = make(chan net.Conn)
reads = make(chan readEvent) prepared = make(chan preparedEvent)
writes = make(chan writeEvent) reads = make(chan readEvent)
writes = make(chan writeEvent)
timeouts = make(chan *client)
tlsConf *tls.Config tlsConf *tls.Config
clients = make(map[*client]bool) clients = make(map[*client]bool)
@ -122,6 +135,7 @@ func broadcast(line string, except *client) {
} }
} }
// Initiate a clean shutdown of the whole daemon.
func initiateShutdown() { func initiateShutdown() {
log.Println("shutting down") log.Println("shutting down")
if err := listener.Close(); err != nil { if err := listener.Close(); err != nil {
@ -135,7 +149,12 @@ func initiateShutdown() {
inShutdown = true inShutdown = true
} }
// Forcefully tear down all connections.
func forceShutdown(reason string) { func forceShutdown(reason string) {
if !inShutdown {
log.Fatalln("forceShutdown called without initiateShutdown")
}
log.Printf("forced shutdown (%s)\n", reason) log.Printf("forced shutdown (%s)\n", reason)
for c := range clients { for c := range clients {
c.destroy() c.destroy()
@ -151,67 +170,84 @@ func (c *client) send(line string) {
} }
} }
func (c *client) shutdown() { // Tear down the client connection, trying to do so in a graceful manner.
func (c *client) kill() {
if c.inShutdown { if c.inShutdown {
log.Println("client double shutdown") return
}
if c.conn == nil {
c.destroy()
return return
} }
// TODO: We must set a timer and destroy the client on timeout. Since we // Since we send this goodbye, we don't need to call CloseWrite.
// have a central event loop, we probably need an event. Since we also c.send("Goodbye")
// seem to need an event for TLS autodetection because of conn, we might c.killTimer = time.AfterFunc(3*time.Second, func() {
// want to send an enumeration value. timeouts <- c
c.inShutdown = true })
c.conn.CloseWrite()
}
// Tear down the client connection, trying to do so in a graceful manner. c.inShutdown = true
func (c *client) kill() {
if c.connReady {
c.send("Goodbye")
c.shutdown()
} else {
c.destroy()
}
} }
// Close the connection and forget about the client. // Close the connection and forget about the client.
func (c *client) destroy() { func (c *client) destroy() {
// Try to send a "close notify" alert if the TLS object is ready, // Try to send a "close notify" alert if the TLS object is ready,
// otherwise just tear down the transport. // otherwise just tear down the transport.
if c.connReady { if c.conn != nil {
_ = c.conn.Close() _ = c.conn.Close()
} else { } else {
_ = c.transport.Close() _ = c.transport.Close()
} }
// Clean up the goroutine, although a spurious event may still be sent.
if c.killTimer != nil {
c.killTimer.Stop()
}
delete(clients, c) delete(clients, c)
} }
// Handle the results from initializing the client's connection.
func (c *client) onPrepared(host string, isTLS bool) {
if isTLS {
c.tls = tls.Server(c.transport, tlsConf)
c.conn = c.tls
} else {
c.conn = c.transport.(connCloseWrite)
}
// TODO: Save the host in the client structure.
go read(c)
}
// Handle the results from trying to read from the client connection. // Handle the results from trying to read from the client connection.
func (c *client) onRead(data []byte, readErr error) { func (c *client) onRead(data []byte, readErr error) {
c.inQ = append(c.inQ, data...) c.inQ = append(c.inQ, data...)
for { for {
advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */) advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */)
c.inQ = c.inQ[advance:]
if advance == 0 { if advance == 0 {
break break
} }
c.inQ = c.inQ[advance:]
line := string(token) line := string(token)
fmt.Println(line) fmt.Println(line)
broadcast(line, c) broadcast(line, c)
} }
// TODO: Inform the client about the inQ overrun in the farewell message. // TODO: Inform the client about the inQ overrun in the farewell message.
// TODO: We should stop receiving any more data from this client.
if len(c.inQ) > 8192 { if len(c.inQ) > 8192 {
c.kill() c.kill()
return return
} }
if readErr == io.EOF { if readErr == io.EOF {
// TODO: What if we're already in shutdown? if c.inShutdown {
c.shutdown() c.destroy()
} else {
c.kill()
}
} else if readErr != nil { } else if readErr != nil {
log.Println(readErr) log.Println(readErr)
c.destroy() c.destroy()
@ -221,7 +257,7 @@ func (c *client) onRead(data []byte, readErr error) {
// Spawn a goroutine to flush the outQ if possible and necessary. If the // Spawn a goroutine to flush the outQ if possible and necessary. If the
// connection is not ready yet, it needs to be retried as soon as it becomes. // connection is not ready yet, it needs to be retried as soon as it becomes.
func (c *client) flushOutQ() { func (c *client) flushOutQ() {
if c.connReady && !c.writing { if c.conn != nil && !c.writing {
go write(c, c.outQ) go write(c, c.outQ)
c.writing = true c.writing = true
} }
@ -238,41 +274,77 @@ func (c *client) onWrite(written int, writeErr error) {
} else if len(c.outQ) > 0 { } else if len(c.outQ) > 0 {
c.flushOutQ() c.flushOutQ()
} else if c.inShutdown { } else if c.inShutdown {
c.destroy() if c.conn != nil {
// FIXME: This is only correct for when /we/ initiate the shutdown,
// otherwise we should perhaps just Close. Though even if we
// Close, there's a/ no writer to fail on it, and b/ the reader
// has already exited, too, which is why the client stays alive
// up until the timeout. It seems that in that case we need to
// call c.destroy().
c.conn.CloseWrite()
} else {
c.destroy()
}
} }
} }
// --- Worker goroutines ------------------------------------------------------- // --- Worker goroutines -------------------------------------------------------
func accept(ln net.Listener) { func accept(ln net.Listener) {
// TODO: Consider specific cases in error handling, some errors
// are transitional while others are fatal.
for { for {
if conn, err := ln.Accept(); err != nil { if conn, err := ln.Accept(); err != nil {
// TODO: Consider specific cases in error handling, some errors
// are transitional while others are fatal.
log.Println(err) log.Println(err)
break
} else { } else {
conns <- conn conns <- conn
} }
} }
} }
func read(client *client) { func prepare(client *client) {
// TODO: Either here or elsewhere we need to set a timeout. conn := client.transport
host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
client.conn = client.transport.(connCloseWrite) if err != nil {
if sysconn, err := client.transport.(syscall.Conn).SyscallConn(); err != nil { // In effect, we require TCP/UDP, as they have port numbers.
// This is just for the TLS detection and doesn't need to be fatal. log.Fatalln(err)
log.Println(err)
} else if detectTLS(sysconn) {
client.tls = tls.Server(client.transport, tlsConf)
client.conn = client.tls
} }
// TODO: Signal the main goroutine that conn is ready. In fact, the upper // The Cgo resolver doesn't pthread_cancel getnameinfo threads, so not
// part could be mostly moved to the main goroutine and we'd only spawn // bothering with pointless contexts.
// a thin wrapper around detectTLS, sending back {*client, bool}. Heck, ch := make(chan string)
// I could get rid of connReady. go func() {
defer close(ch)
if names, err := net.LookupAddr(host); err != nil {
log.Println(err)
} else {
ch <- names[0]
}
}()
// While we can't cancel it, we still want to set a timeout on it.
select {
case <-time.After(5 * time.Second):
case resolved, ok := <-ch:
if ok {
host = resolved
}
}
isTLS := false
if sysconn, err := conn.(syscall.Conn).SyscallConn(); err != nil {
// This is just for the TLS detection and doesn't need to be fatal.
log.Println(err)
} else {
isTLS = detectTLS(sysconn)
}
prepared <- preparedEvent{client, host, isTLS}
}
func read(client *client) {
// A new buffer is allocated each time we receive some bytes, because of // A new buffer is allocated each time we receive some bytes, because of
// thread-safety. Therefore the buffer shouldn't be too large, or we'd // thread-safety. Therefore the buffer shouldn't be too large, or we'd
// need to copy it each time into a precisely sized new buffer. // need to copy it each time into a precisely sized new buffer.
@ -312,7 +384,13 @@ func processOneEvent() {
log.Println("accepted client connection") log.Println("accepted client connection")
c := &client{transport: conn} c := &client{transport: conn}
clients[c] = true clients[c] = true
go read(c) go prepare(c)
case ev := <-prepared:
log.Println("client is ready:", ev.host)
if _, ok := clients[ev.client]; ok {
ev.client.onPrepared(ev.host, ev.isTLS)
}
case ev := <-reads: case ev := <-reads:
log.Println("received data from client") log.Println("received data from client")
@ -325,6 +403,12 @@ func processOneEvent() {
if _, ok := clients[ev.client]; ok { if _, ok := clients[ev.client]; ok {
ev.client.onWrite(ev.written, ev.err) ev.client.onWrite(ev.written, ev.err)
} }
case c := <-timeouts:
if _, ok := clients[c]; ok {
log.Println("client timeouted")
c.destroy()
}
} }
} }