tls-autodetect: put most of the server code in place
So far we act up when it is the client who initializes the shutdown.
This commit is contained in:
		
							parent
							
								
									b5b64db075
								
							
						
					
					
						commit
						728fa4e548
					
				| @ -13,12 +13,17 @@ | ||||
| // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. | ||||
| // | ||||
| 
 | ||||
| // | ||||
| // This is an example TLS-autodetecting chat server. | ||||
| // | ||||
| // You may connect to it either using: | ||||
| // You may connect to it using either of these: | ||||
| //  ncat -C localhost 1234 | ||||
| //  ncat -C --ssl localhost 1234 | ||||
| // | ||||
| // These clients are unable to properly shutdown the connection: | ||||
| //  telnet localhost 1234 | ||||
| // or | ||||
| //  openssl s_client -connect localhost:1234 | ||||
| // | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| @ -37,6 +42,7 @@ import ( | ||||
| 
 | ||||
| // --- Utilities --------------------------------------------------------------- | ||||
| 
 | ||||
| // | ||||
| // Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom | ||||
| // must be at least three octets long for this to work reliably, but that should | ||||
| // not pose a problem in practice. We might try waiting for them. | ||||
| @ -46,8 +52,7 @@ import ( | ||||
| //  SSL3/TLS:    <22>    |    <3>    | xxxx xxxx | ||||
| //            (handshake)|  (protocol version) | ||||
| // | ||||
| func detectTLS(sysconn syscall.RawConn) bool { | ||||
| 	isTLS := false | ||||
| func detectTLS(sysconn syscall.RawConn) (isTLS bool) { | ||||
| 	sysconn.Read(func(fd uintptr) (done bool) { | ||||
| 		var buf [3]byte | ||||
| 		n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK) | ||||
| @ -78,30 +83,38 @@ type client struct { | ||||
| 	transport  net.Conn       // underlying connection | ||||
| 	tls        *tls.Conn      // TLS, if detected | ||||
| 	conn       connCloseWrite // high-level connection | ||||
| 	connReady  bool           // conn is safe to read from the main goroutine | ||||
| 	inQ        []byte         // unprocessed input | ||||
| 	outQ       []byte         // unprocessed output | ||||
| 	writing    bool           // whether a writing goroutine is running | ||||
| 	inShutdown bool           // whether we're closing connection | ||||
| 	killTimer  *time.Timer    // timeout | ||||
| } | ||||
| 
 | ||||
| type preparedEvent struct { | ||||
| 	client *client | ||||
| 	host   string // client's hostname or literal IP address | ||||
| 	isTLS  bool   // the client seems to use TLS | ||||
| } | ||||
| 
 | ||||
| type readEvent struct { | ||||
| 	client *client // client | ||||
| 	data   []byte  // new data from the client | ||||
| 	err    error   // read error | ||||
| 	client *client | ||||
| 	data   []byte // new data from the client | ||||
| 	err    error  // read error | ||||
| } | ||||
| 
 | ||||
| type writeEvent struct { | ||||
| 	client  *client // client | ||||
| 	written int     // amount of bytes written | ||||
| 	err     error   // write error | ||||
| 	client  *client | ||||
| 	written int   // amount of bytes written | ||||
| 	err     error // write error | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| 	sigs   = make(chan os.Signal, 1) | ||||
| 	conns  = make(chan net.Conn) | ||||
| 	reads  = make(chan readEvent) | ||||
| 	writes = make(chan writeEvent) | ||||
| 	sigs     = make(chan os.Signal, 1) | ||||
| 	conns    = make(chan net.Conn) | ||||
| 	prepared = make(chan preparedEvent) | ||||
| 	reads    = make(chan readEvent) | ||||
| 	writes   = make(chan writeEvent) | ||||
| 	timeouts = make(chan *client) | ||||
| 
 | ||||
| 	tlsConf       *tls.Config | ||||
| 	clients       = make(map[*client]bool) | ||||
| @ -122,6 +135,7 @@ func broadcast(line string, except *client) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Initiate a clean shutdown of the whole daemon. | ||||
| func initiateShutdown() { | ||||
| 	log.Println("shutting down") | ||||
| 	if err := listener.Close(); err != nil { | ||||
| @ -135,7 +149,12 @@ func initiateShutdown() { | ||||
| 	inShutdown = true | ||||
| } | ||||
| 
 | ||||
| // Forcefully tear down all connections. | ||||
| func forceShutdown(reason string) { | ||||
| 	if !inShutdown { | ||||
| 		log.Fatalln("forceShutdown called without initiateShutdown") | ||||
| 	} | ||||
| 
 | ||||
| 	log.Printf("forced shutdown (%s)\n", reason) | ||||
| 	for c := range clients { | ||||
| 		c.destroy() | ||||
| @ -151,67 +170,84 @@ func (c *client) send(line string) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (c *client) shutdown() { | ||||
| // Tear down the client connection, trying to do so in a graceful manner. | ||||
| func (c *client) kill() { | ||||
| 	if c.inShutdown { | ||||
| 		log.Println("client double shutdown") | ||||
| 		return | ||||
| 	} | ||||
| 	if c.conn == nil { | ||||
| 		c.destroy() | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: We must set a timer and destroy the client on timeout. Since we | ||||
| 	//   have a central event loop, we probably need an event. Since we also | ||||
| 	//   seem to need an event for TLS autodetection because of conn, we might | ||||
| 	//   want to send an enumeration value. | ||||
| 	c.inShutdown = true | ||||
| 	c.conn.CloseWrite() | ||||
| } | ||||
| 	// Since we send this goodbye, we don't need to call CloseWrite. | ||||
| 	c.send("Goodbye") | ||||
| 	c.killTimer = time.AfterFunc(3*time.Second, func() { | ||||
| 		timeouts <- c | ||||
| 	}) | ||||
| 
 | ||||
| // Tear down the client connection, trying to do so in a graceful manner. | ||||
| func (c *client) kill() { | ||||
| 	if c.connReady { | ||||
| 		c.send("Goodbye") | ||||
| 		c.shutdown() | ||||
| 	} else { | ||||
| 		c.destroy() | ||||
| 	} | ||||
| 	c.inShutdown = true | ||||
| } | ||||
| 
 | ||||
| // Close the connection and forget about the client. | ||||
| func (c *client) destroy() { | ||||
| 	// Try to send a "close notify" alert if the TLS object is ready, | ||||
| 	// otherwise just tear down the transport. | ||||
| 	if c.connReady { | ||||
| 	if c.conn != nil { | ||||
| 		_ = c.conn.Close() | ||||
| 	} else { | ||||
| 		_ = c.transport.Close() | ||||
| 	} | ||||
| 
 | ||||
| 	// Clean up the goroutine, although a spurious event may still be sent. | ||||
| 	if c.killTimer != nil { | ||||
| 		c.killTimer.Stop() | ||||
| 	} | ||||
| 
 | ||||
| 	delete(clients, c) | ||||
| } | ||||
| 
 | ||||
| // Handle the results from initializing the client's connection. | ||||
| func (c *client) onPrepared(host string, isTLS bool) { | ||||
| 	if isTLS { | ||||
| 		c.tls = tls.Server(c.transport, tlsConf) | ||||
| 		c.conn = c.tls | ||||
| 	} else { | ||||
| 		c.conn = c.transport.(connCloseWrite) | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: Save the host in the client structure. | ||||
| 	go read(c) | ||||
| } | ||||
| 
 | ||||
| // Handle the results from trying to read from the client connection. | ||||
| func (c *client) onRead(data []byte, readErr error) { | ||||
| 	c.inQ = append(c.inQ, data...) | ||||
| 	for { | ||||
| 		advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */) | ||||
| 		c.inQ = c.inQ[advance:] | ||||
| 		if advance == 0 { | ||||
| 			break | ||||
| 		} | ||||
| 
 | ||||
| 		c.inQ = c.inQ[advance:] | ||||
| 		line := string(token) | ||||
| 		fmt.Println(line) | ||||
| 		broadcast(line, c) | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: Inform the client about the inQ overrun in the farewell message. | ||||
| 	// TODO: We should stop receiving any more data from this client. | ||||
| 	if len(c.inQ) > 8192 { | ||||
| 		c.kill() | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if readErr == io.EOF { | ||||
| 		// TODO: What if we're already in shutdown? | ||||
| 		c.shutdown() | ||||
| 		if c.inShutdown { | ||||
| 			c.destroy() | ||||
| 		} else { | ||||
| 			c.kill() | ||||
| 		} | ||||
| 	} else if readErr != nil { | ||||
| 		log.Println(readErr) | ||||
| 		c.destroy() | ||||
| @ -221,7 +257,7 @@ func (c *client) onRead(data []byte, readErr error) { | ||||
| // Spawn a goroutine to flush the outQ if possible and necessary.  If the | ||||
| // connection is not ready yet, it needs to be retried as soon as it becomes. | ||||
| func (c *client) flushOutQ() { | ||||
| 	if c.connReady && !c.writing { | ||||
| 	if c.conn != nil && !c.writing { | ||||
| 		go write(c, c.outQ) | ||||
| 		c.writing = true | ||||
| 	} | ||||
| @ -238,41 +274,77 @@ func (c *client) onWrite(written int, writeErr error) { | ||||
| 	} else if len(c.outQ) > 0 { | ||||
| 		c.flushOutQ() | ||||
| 	} else if c.inShutdown { | ||||
| 		c.destroy() | ||||
| 		if c.conn != nil { | ||||
| 			// FIXME: This is only correct for when /we/ initiate the shutdown, | ||||
| 			// otherwise we should perhaps just Close. Though even if we | ||||
| 			// Close, there's a/ no writer to fail on it, and b/ the reader | ||||
| 			// has already exited, too, which is why the client stays alive | ||||
| 			// up until the timeout. It seems that in that case we need to | ||||
| 			// call c.destroy(). | ||||
| 			c.conn.CloseWrite() | ||||
| 		} else { | ||||
| 			c.destroy() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // --- Worker goroutines ------------------------------------------------------- | ||||
| 
 | ||||
| func accept(ln net.Listener) { | ||||
| 	// TODO: Consider specific cases in error handling, some errors | ||||
| 	//   are transitional while others are fatal. | ||||
| 	for { | ||||
| 		if conn, err := ln.Accept(); err != nil { | ||||
| 			// TODO: Consider specific cases in error handling, some errors | ||||
| 			// are transitional while others are fatal. | ||||
| 			log.Println(err) | ||||
| 			break | ||||
| 		} else { | ||||
| 			conns <- conn | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func read(client *client) { | ||||
| 	// TODO: Either here or elsewhere we need to set a timeout. | ||||
| 
 | ||||
| 	client.conn = client.transport.(connCloseWrite) | ||||
| 	if sysconn, err := client.transport.(syscall.Conn).SyscallConn(); err != nil { | ||||
| 		// This is just for the TLS detection and doesn't need to be fatal. | ||||
| 		log.Println(err) | ||||
| 	} else if detectTLS(sysconn) { | ||||
| 		client.tls = tls.Server(client.transport, tlsConf) | ||||
| 		client.conn = client.tls | ||||
| func prepare(client *client) { | ||||
| 	conn := client.transport | ||||
| 	host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) | ||||
| 	if err != nil { | ||||
| 		// In effect, we require TCP/UDP, as they have port numbers. | ||||
| 		log.Fatalln(err) | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: Signal the main goroutine that conn is ready. In fact, the upper | ||||
| 	//   part could be mostly moved to the main goroutine and we'd only spawn | ||||
| 	//   a thin wrapper around detectTLS, sending back {*client, bool}. Heck, | ||||
| 	//   I could get rid of connReady. | ||||
| 	// The Cgo resolver doesn't pthread_cancel getnameinfo threads, so not | ||||
| 	// bothering with pointless contexts. | ||||
| 	ch := make(chan string) | ||||
| 	go func() { | ||||
| 		defer close(ch) | ||||
| 		if names, err := net.LookupAddr(host); err != nil { | ||||
| 			log.Println(err) | ||||
| 		} else { | ||||
| 			ch <- names[0] | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	// While we can't cancel it, we still want to set a timeout on it. | ||||
| 	select { | ||||
| 	case <-time.After(5 * time.Second): | ||||
| 	case resolved, ok := <-ch: | ||||
| 		if ok { | ||||
| 			host = resolved | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	isTLS := false | ||||
| 	if sysconn, err := conn.(syscall.Conn).SyscallConn(); err != nil { | ||||
| 		// This is just for the TLS detection and doesn't need to be fatal. | ||||
| 		log.Println(err) | ||||
| 	} else { | ||||
| 		isTLS = detectTLS(sysconn) | ||||
| 	} | ||||
| 
 | ||||
| 	prepared <- preparedEvent{client, host, isTLS} | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func read(client *client) { | ||||
| 	// A new buffer is allocated each time we receive some bytes, because of | ||||
| 	// thread-safety. Therefore the buffer shouldn't be too large, or we'd | ||||
| 	// need to copy it each time into a precisely sized new buffer. | ||||
| @ -312,7 +384,13 @@ func processOneEvent() { | ||||
| 		log.Println("accepted client connection") | ||||
| 		c := &client{transport: conn} | ||||
| 		clients[c] = true | ||||
| 		go read(c) | ||||
| 		go prepare(c) | ||||
| 
 | ||||
| 	case ev := <-prepared: | ||||
| 		log.Println("client is ready:", ev.host) | ||||
| 		if _, ok := clients[ev.client]; ok { | ||||
| 			ev.client.onPrepared(ev.host, ev.isTLS) | ||||
| 		} | ||||
| 
 | ||||
| 	case ev := <-reads: | ||||
| 		log.Println("received data from client") | ||||
| @ -325,6 +403,12 @@ func processOneEvent() { | ||||
| 		if _, ok := clients[ev.client]; ok { | ||||
| 			ev.client.onWrite(ev.written, ev.err) | ||||
| 		} | ||||
| 
 | ||||
| 	case c := <-timeouts: | ||||
| 		if _, ok := clients[c]; ok { | ||||
| 			log.Println("client timeouted") | ||||
| 			c.destroy() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user