tls-autodetect: put most of the server code in place
So far we act up when it is the client who initializes the shutdown.
This commit is contained in:
		@@ -13,12 +13,17 @@
 | 
			
		||||
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
//
 | 
			
		||||
// This is an example TLS-autodetecting chat server.
 | 
			
		||||
//
 | 
			
		||||
// You may connect to it either using:
 | 
			
		||||
// You may connect to it using either of these:
 | 
			
		||||
//  ncat -C localhost 1234
 | 
			
		||||
//  ncat -C --ssl localhost 1234
 | 
			
		||||
//
 | 
			
		||||
// These clients are unable to properly shutdown the connection:
 | 
			
		||||
//  telnet localhost 1234
 | 
			
		||||
// or
 | 
			
		||||
//  openssl s_client -connect localhost:1234
 | 
			
		||||
//
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
@@ -37,6 +42,7 @@ import (
 | 
			
		||||
 | 
			
		||||
// --- Utilities ---------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
//
 | 
			
		||||
// Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom
 | 
			
		||||
// must be at least three octets long for this to work reliably, but that should
 | 
			
		||||
// not pose a problem in practice. We might try waiting for them.
 | 
			
		||||
@@ -46,8 +52,7 @@ import (
 | 
			
		||||
//  SSL3/TLS:    <22>    |    <3>    | xxxx xxxx
 | 
			
		||||
//            (handshake)|  (protocol version)
 | 
			
		||||
//
 | 
			
		||||
func detectTLS(sysconn syscall.RawConn) bool {
 | 
			
		||||
	isTLS := false
 | 
			
		||||
func detectTLS(sysconn syscall.RawConn) (isTLS bool) {
 | 
			
		||||
	sysconn.Read(func(fd uintptr) (done bool) {
 | 
			
		||||
		var buf [3]byte
 | 
			
		||||
		n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK)
 | 
			
		||||
@@ -78,30 +83,38 @@ type client struct {
 | 
			
		||||
	transport  net.Conn       // underlying connection
 | 
			
		||||
	tls        *tls.Conn      // TLS, if detected
 | 
			
		||||
	conn       connCloseWrite // high-level connection
 | 
			
		||||
	connReady  bool           // conn is safe to read from the main goroutine
 | 
			
		||||
	inQ        []byte         // unprocessed input
 | 
			
		||||
	outQ       []byte         // unprocessed output
 | 
			
		||||
	writing    bool           // whether a writing goroutine is running
 | 
			
		||||
	inShutdown bool           // whether we're closing connection
 | 
			
		||||
	killTimer  *time.Timer    // timeout
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type preparedEvent struct {
 | 
			
		||||
	client *client
 | 
			
		||||
	host   string // client's hostname or literal IP address
 | 
			
		||||
	isTLS  bool   // the client seems to use TLS
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type readEvent struct {
 | 
			
		||||
	client *client // client
 | 
			
		||||
	data   []byte  // new data from the client
 | 
			
		||||
	err    error   // read error
 | 
			
		||||
	client *client
 | 
			
		||||
	data   []byte // new data from the client
 | 
			
		||||
	err    error  // read error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type writeEvent struct {
 | 
			
		||||
	client  *client // client
 | 
			
		||||
	written int     // amount of bytes written
 | 
			
		||||
	err     error   // write error
 | 
			
		||||
	client  *client
 | 
			
		||||
	written int   // amount of bytes written
 | 
			
		||||
	err     error // write error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	sigs   = make(chan os.Signal, 1)
 | 
			
		||||
	conns  = make(chan net.Conn)
 | 
			
		||||
	reads  = make(chan readEvent)
 | 
			
		||||
	writes = make(chan writeEvent)
 | 
			
		||||
	sigs     = make(chan os.Signal, 1)
 | 
			
		||||
	conns    = make(chan net.Conn)
 | 
			
		||||
	prepared = make(chan preparedEvent)
 | 
			
		||||
	reads    = make(chan readEvent)
 | 
			
		||||
	writes   = make(chan writeEvent)
 | 
			
		||||
	timeouts = make(chan *client)
 | 
			
		||||
 | 
			
		||||
	tlsConf       *tls.Config
 | 
			
		||||
	clients       = make(map[*client]bool)
 | 
			
		||||
@@ -122,6 +135,7 @@ func broadcast(line string, except *client) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Initiate a clean shutdown of the whole daemon.
 | 
			
		||||
func initiateShutdown() {
 | 
			
		||||
	log.Println("shutting down")
 | 
			
		||||
	if err := listener.Close(); err != nil {
 | 
			
		||||
@@ -135,7 +149,12 @@ func initiateShutdown() {
 | 
			
		||||
	inShutdown = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Forcefully tear down all connections.
 | 
			
		||||
func forceShutdown(reason string) {
 | 
			
		||||
	if !inShutdown {
 | 
			
		||||
		log.Fatalln("forceShutdown called without initiateShutdown")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Printf("forced shutdown (%s)\n", reason)
 | 
			
		||||
	for c := range clients {
 | 
			
		||||
		c.destroy()
 | 
			
		||||
@@ -151,67 +170,84 @@ func (c *client) send(line string) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *client) shutdown() {
 | 
			
		||||
// Tear down the client connection, trying to do so in a graceful manner.
 | 
			
		||||
func (c *client) kill() {
 | 
			
		||||
	if c.inShutdown {
 | 
			
		||||
		log.Println("client double shutdown")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if c.conn == nil {
 | 
			
		||||
		c.destroy()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: We must set a timer and destroy the client on timeout. Since we
 | 
			
		||||
	//   have a central event loop, we probably need an event. Since we also
 | 
			
		||||
	//   seem to need an event for TLS autodetection because of conn, we might
 | 
			
		||||
	//   want to send an enumeration value.
 | 
			
		||||
	c.inShutdown = true
 | 
			
		||||
	c.conn.CloseWrite()
 | 
			
		||||
}
 | 
			
		||||
	// Since we send this goodbye, we don't need to call CloseWrite.
 | 
			
		||||
	c.send("Goodbye")
 | 
			
		||||
	c.killTimer = time.AfterFunc(3*time.Second, func() {
 | 
			
		||||
		timeouts <- c
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
// Tear down the client connection, trying to do so in a graceful manner.
 | 
			
		||||
func (c *client) kill() {
 | 
			
		||||
	if c.connReady {
 | 
			
		||||
		c.send("Goodbye")
 | 
			
		||||
		c.shutdown()
 | 
			
		||||
	} else {
 | 
			
		||||
		c.destroy()
 | 
			
		||||
	}
 | 
			
		||||
	c.inShutdown = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close the connection and forget about the client.
 | 
			
		||||
func (c *client) destroy() {
 | 
			
		||||
	// Try to send a "close notify" alert if the TLS object is ready,
 | 
			
		||||
	// otherwise just tear down the transport.
 | 
			
		||||
	if c.connReady {
 | 
			
		||||
	if c.conn != nil {
 | 
			
		||||
		_ = c.conn.Close()
 | 
			
		||||
	} else {
 | 
			
		||||
		_ = c.transport.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Clean up the goroutine, although a spurious event may still be sent.
 | 
			
		||||
	if c.killTimer != nil {
 | 
			
		||||
		c.killTimer.Stop()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	delete(clients, c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handle the results from initializing the client's connection.
 | 
			
		||||
func (c *client) onPrepared(host string, isTLS bool) {
 | 
			
		||||
	if isTLS {
 | 
			
		||||
		c.tls = tls.Server(c.transport, tlsConf)
 | 
			
		||||
		c.conn = c.tls
 | 
			
		||||
	} else {
 | 
			
		||||
		c.conn = c.transport.(connCloseWrite)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: Save the host in the client structure.
 | 
			
		||||
	go read(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handle the results from trying to read from the client connection.
 | 
			
		||||
func (c *client) onRead(data []byte, readErr error) {
 | 
			
		||||
	c.inQ = append(c.inQ, data...)
 | 
			
		||||
	for {
 | 
			
		||||
		advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */)
 | 
			
		||||
		c.inQ = c.inQ[advance:]
 | 
			
		||||
		if advance == 0 {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c.inQ = c.inQ[advance:]
 | 
			
		||||
		line := string(token)
 | 
			
		||||
		fmt.Println(line)
 | 
			
		||||
		broadcast(line, c)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: Inform the client about the inQ overrun in the farewell message.
 | 
			
		||||
	// TODO: We should stop receiving any more data from this client.
 | 
			
		||||
	if len(c.inQ) > 8192 {
 | 
			
		||||
		c.kill()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if readErr == io.EOF {
 | 
			
		||||
		// TODO: What if we're already in shutdown?
 | 
			
		||||
		c.shutdown()
 | 
			
		||||
		if c.inShutdown {
 | 
			
		||||
			c.destroy()
 | 
			
		||||
		} else {
 | 
			
		||||
			c.kill()
 | 
			
		||||
		}
 | 
			
		||||
	} else if readErr != nil {
 | 
			
		||||
		log.Println(readErr)
 | 
			
		||||
		c.destroy()
 | 
			
		||||
@@ -221,7 +257,7 @@ func (c *client) onRead(data []byte, readErr error) {
 | 
			
		||||
// Spawn a goroutine to flush the outQ if possible and necessary.  If the
 | 
			
		||||
// connection is not ready yet, it needs to be retried as soon as it becomes.
 | 
			
		||||
func (c *client) flushOutQ() {
 | 
			
		||||
	if c.connReady && !c.writing {
 | 
			
		||||
	if c.conn != nil && !c.writing {
 | 
			
		||||
		go write(c, c.outQ)
 | 
			
		||||
		c.writing = true
 | 
			
		||||
	}
 | 
			
		||||
@@ -238,41 +274,77 @@ func (c *client) onWrite(written int, writeErr error) {
 | 
			
		||||
	} else if len(c.outQ) > 0 {
 | 
			
		||||
		c.flushOutQ()
 | 
			
		||||
	} else if c.inShutdown {
 | 
			
		||||
		c.destroy()
 | 
			
		||||
		if c.conn != nil {
 | 
			
		||||
			// FIXME: This is only correct for when /we/ initiate the shutdown,
 | 
			
		||||
			// otherwise we should perhaps just Close. Though even if we
 | 
			
		||||
			// Close, there's a/ no writer to fail on it, and b/ the reader
 | 
			
		||||
			// has already exited, too, which is why the client stays alive
 | 
			
		||||
			// up until the timeout. It seems that in that case we need to
 | 
			
		||||
			// call c.destroy().
 | 
			
		||||
			c.conn.CloseWrite()
 | 
			
		||||
		} else {
 | 
			
		||||
			c.destroy()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// --- Worker goroutines -------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
func accept(ln net.Listener) {
 | 
			
		||||
	// TODO: Consider specific cases in error handling, some errors
 | 
			
		||||
	//   are transitional while others are fatal.
 | 
			
		||||
	for {
 | 
			
		||||
		if conn, err := ln.Accept(); err != nil {
 | 
			
		||||
			// TODO: Consider specific cases in error handling, some errors
 | 
			
		||||
			// are transitional while others are fatal.
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			break
 | 
			
		||||
		} else {
 | 
			
		||||
			conns <- conn
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func read(client *client) {
 | 
			
		||||
	// TODO: Either here or elsewhere we need to set a timeout.
 | 
			
		||||
 | 
			
		||||
	client.conn = client.transport.(connCloseWrite)
 | 
			
		||||
	if sysconn, err := client.transport.(syscall.Conn).SyscallConn(); err != nil {
 | 
			
		||||
		// This is just for the TLS detection and doesn't need to be fatal.
 | 
			
		||||
		log.Println(err)
 | 
			
		||||
	} else if detectTLS(sysconn) {
 | 
			
		||||
		client.tls = tls.Server(client.transport, tlsConf)
 | 
			
		||||
		client.conn = client.tls
 | 
			
		||||
func prepare(client *client) {
 | 
			
		||||
	conn := client.transport
 | 
			
		||||
	host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		// In effect, we require TCP/UDP, as they have port numbers.
 | 
			
		||||
		log.Fatalln(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: Signal the main goroutine that conn is ready. In fact, the upper
 | 
			
		||||
	//   part could be mostly moved to the main goroutine and we'd only spawn
 | 
			
		||||
	//   a thin wrapper around detectTLS, sending back {*client, bool}. Heck,
 | 
			
		||||
	//   I could get rid of connReady.
 | 
			
		||||
	// The Cgo resolver doesn't pthread_cancel getnameinfo threads, so not
 | 
			
		||||
	// bothering with pointless contexts.
 | 
			
		||||
	ch := make(chan string)
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer close(ch)
 | 
			
		||||
		if names, err := net.LookupAddr(host); err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
		} else {
 | 
			
		||||
			ch <- names[0]
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// While we can't cancel it, we still want to set a timeout on it.
 | 
			
		||||
	select {
 | 
			
		||||
	case <-time.After(5 * time.Second):
 | 
			
		||||
	case resolved, ok := <-ch:
 | 
			
		||||
		if ok {
 | 
			
		||||
			host = resolved
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	isTLS := false
 | 
			
		||||
	if sysconn, err := conn.(syscall.Conn).SyscallConn(); err != nil {
 | 
			
		||||
		// This is just for the TLS detection and doesn't need to be fatal.
 | 
			
		||||
		log.Println(err)
 | 
			
		||||
	} else {
 | 
			
		||||
		isTLS = detectTLS(sysconn)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	prepared <- preparedEvent{client, host, isTLS}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func read(client *client) {
 | 
			
		||||
	// A new buffer is allocated each time we receive some bytes, because of
 | 
			
		||||
	// thread-safety. Therefore the buffer shouldn't be too large, or we'd
 | 
			
		||||
	// need to copy it each time into a precisely sized new buffer.
 | 
			
		||||
@@ -312,7 +384,13 @@ func processOneEvent() {
 | 
			
		||||
		log.Println("accepted client connection")
 | 
			
		||||
		c := &client{transport: conn}
 | 
			
		||||
		clients[c] = true
 | 
			
		||||
		go read(c)
 | 
			
		||||
		go prepare(c)
 | 
			
		||||
 | 
			
		||||
	case ev := <-prepared:
 | 
			
		||||
		log.Println("client is ready:", ev.host)
 | 
			
		||||
		if _, ok := clients[ev.client]; ok {
 | 
			
		||||
			ev.client.onPrepared(ev.host, ev.isTLS)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	case ev := <-reads:
 | 
			
		||||
		log.Println("received data from client")
 | 
			
		||||
@@ -325,6 +403,12 @@ func processOneEvent() {
 | 
			
		||||
		if _, ok := clients[ev.client]; ok {
 | 
			
		||||
			ev.client.onWrite(ev.written, ev.err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	case c := <-timeouts:
 | 
			
		||||
		if _, ok := clients[c]; ok {
 | 
			
		||||
			log.Println("client timeouted")
 | 
			
		||||
			c.destroy()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user