//
// Copyright (c) 2018, Přemysl Janouch
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
//

//
// This is an example TLS-autodetecting chat server.
//
// These clients are unable to properly shutdown the connection on their exit:
//  telnet localhost 1234
//  openssl s_client -connect localhost:1234
//
// While this one doesn't react to an EOF from the server:
//  ncat -C localhost 1234
//  ncat -C --ssl localhost 1234
//
package main

import (
	"bufio"
	"crypto/tls"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/signal"
	"syscall"
	"time"
)

// --- Utilities ---------------------------------------------------------------

//
// Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom
// must be at least three octets long for this to work reliably, but that should
// not pose a problem in practice. We might try waiting for them.
//
//  SSL2:      1xxx xxxx | xxxx xxxx |    <1>
//                (message length)     (client hello)
//  SSL3/TLS:    <22>    |    <3>    | xxxx xxxx
//              (handshake)|  (protocol version)
//
// detectTLS peeks (MSG_PEEK, so nothing is consumed) at up to three octets
// pending on the socket behind sysconn and reports whether they look like
// the start of an SSL2 or SSL3/TLS client hello. It waits until the socket
// becomes readable; a failure of the read side is logged and reported as
// "not TLS".
func detectTLS(sysconn syscall.RawConn) (isTLS bool) {
	err := sysconn.Read(func(fd uintptr) (done bool) {
		var buf [3]byte
		for {
			n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK)
			switch {
			case err == syscall.EINTR:
				// Interrupted before any data was examined: retry in place
				// rather than mistaking it for a final answer.
				continue
			case err == syscall.EAGAIN:
				// Nothing to read yet, have RawConn wait for readability.
				return false
			case n > 0:
				isTLS = looksLikeTLS(buf[:n])
			}
			return true
		}
	})
	if err != nil {
		// Detection is best-effort; the caller falls back to plain text.
		log.Println(err)
	}
	return isTLS
}

// looksLikeTLS reports whether the leading octets of a stream resemble
// an SSL2 or SSL3/TLS client hello, guessing from fewer than three octets
// when that is all there is.
func looksLikeTLS(prefix []byte) bool {
	switch {
	case len(prefix) >= 3 && prefix[0]&0x80 != 0 && prefix[2] == 1:
		return true // SSL2 client hello
	case len(prefix) >= 2:
		return prefix[0] == 22 && prefix[1] == 3
	case len(prefix) == 1:
		return prefix[0] == 22
	}
	return false
}

// --- Declarations ------------------------------------------------------------

// connCloseWriter is a connection whose write side can be shut down
// separately from its read side.
type connCloseWriter interface {
	net.Conn
	CloseWrite() error
}

// client holds the state of one connection. It is only ever touched from
// the goroutine running processOneEvent; the worker goroutines communicate
// with it through the event channels below.
type client struct {
	transport net.Conn        // underlying connection
	tls       *tls.Conn       // TLS, if detected
	conn      connCloseWriter // high-level connection

	inQ  []byte // unprocessed input
	outQ []byte // unprocessed output

	reading bool // whether a reading goroutine is running
	writing bool // whether a writing goroutine is running
	closing bool // whether we're closing the connection

	killTimer *time.Timer // timeout
}

// preparedEvent is sent by prepare once a new connection has been examined.
type preparedEvent struct {
	client *client
	host   string // client's hostname or literal IP address
	isTLS  bool   // the client seems to use TLS
}

// readEvent is sent by read after each attempt to read from a client.
type readEvent struct {
	client *client
	data   []byte // new data from the client
	err    error  // read error
}

// writeEvent is sent by write after an attempt to write to a client.
type writeEvent struct {
	client  *client
	written int   // amount of bytes written
	err     error // write error
}

var (
	sigs     = make(chan os.Signal, 1)
	conns    = make(chan net.Conn)
	prepared = make(chan preparedEvent)
	reads    = make(chan readEvent)
	writes   = make(chan writeEvent)
	timeouts = make(chan *client)

	tlsConf       *tls.Config
	clients       = make(map[*client]bool)
	listener      net.Listener
	inShutdown    bool
	shutdownTimer <-chan time.Time
)

// --- Server ------------------------------------------------------------------

// Broadcast to all /other/ clients (telnet-friendly, also in accordance to
// the plan of extending this to an IRCd).
func broadcast(line string, except *client) {
	for c := range clients {
		if c != except {
			c.send(line)
		}
	}
}

// Initiate a clean shutdown of the whole daemon: stop listening, ask every
// client to leave, and arm a timer after which the shutdown gets forced.
func initiateShutdown() {
	log.Println("shutting down")
	if err := listener.Close(); err != nil {
		log.Println(err)
	}
	for c := range clients {
		c.closeLink()
	}

	shutdownTimer = time.After(3 * time.Second)
	inShutdown = true
}

// Forcefully tear down all connections. It is a fatal error to call this
// before initiateShutdown.
func forceShutdown(reason string) {
	if !inShutdown {
		log.Fatalln("forceShutdown called without initiateShutdown")
	}

	log.Printf("forced shutdown (%s)\n", reason)
	for c := range clients {
		c.destroy()
	}
}

// --- Client ------------------------------------------------------------------

// send queues a CRLF-terminated line for the client and starts flushing it.
// The line is dropped when the connection isn't ready yet or is closing.
func (c *client) send(line string) {
	if c.conn != nil && !c.closing {
		c.outQ = append(c.outQ, (line + "\r\n")...)
		c.flushOutQ()
	}
}

// Tear down the client connection, trying to do so in a graceful manner.
func (c *client) closeLink() {
	if c.closing {
		return
	}
	if c.conn == nil {
		// Still being prepared, there is nobody to say goodbye to.
		c.destroy()
		return
	}

	// Since we send this goodbye, we don't need to call CloseWrite here.
	c.send("Goodbye")
	c.killTimer = time.AfterFunc(3*time.Second, func() {
		timeouts <- c
	})

	c.closing = true
}

// Close the connection and forget about the client.
func (c *client) destroy() {
	// Try to send a "close notify" alert if the TLS object is ready,
	// otherwise just tear down the transport. Close errors are ignored
	// on purpose: the client is being discarded either way.
	if c.conn != nil {
		_ = c.conn.Close()
	} else if c.transport != nil {
		_ = c.transport.Close()
	}

	// Clean up the goroutine, although a spurious event may still be sent.
	if c.killTimer != nil {
		c.killTimer.Stop()
	}

	log.Println("client destroyed")
	delete(clients, c)
}

// Handle the results from initializing the client's connection.
func (c *client) onPrepared(isTLS bool) {
	if isTLS {
		c.tls = tls.Server(c.transport, tlsConf)
		c.conn = c.tls
	} else if plain, ok := c.transport.(connCloseWriter); ok {
		c.conn = plain
	} else {
		// Don't take the whole server down over one odd transport.
		log.Printf("transport %T cannot be half-closed", c.transport)
		c.destroy()
		return
	}

	// TODO: If we've tried to send any data before now, we need to flushOutQ.
	go read(c)
	c.reading = true
}

// Handle the results from trying to read from the client connection.
func (c *client) onRead(data []byte, readErr error) {
	if !c.reading {
		// Abusing the flag to emulate CloseRead and skip over data, see below.
		return
	}

	c.inQ = append(c.inQ, data...)
	for {
		// ScanLines never returns an error, so there is nothing to check.
		advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */)
		if advance == 0 {
			break
		}

		c.inQ = c.inQ[advance:]
		line := string(token)
		fmt.Println(line)
		broadcast(line, c)
	}

	if readErr != nil {
		c.reading = false

		switch {
		case !errors.Is(readErr, io.EOF):
			log.Println(readErr)
			c.destroy()
		case c.closing:
			// Disregarding whether a clean shutdown has happened or not.
			log.Println("client finished shutdown")
			c.destroy()
		default:
			log.Println("client EOF")
			c.closeLink()
		}
		return
	}

	if len(c.inQ) > 8192 {
		log.Println("client inQ overrun")
		// TODO: Inform the client about inQ overrun in the farewell message.
		c.closeLink()

		// tls.Conn doesn't have the CloseRead method (and it needs to be able
		// to read from the TCP connection even for writes, so there isn't much
		// sense in expecting the implementation to do anything useful),
		// otherwise we'd use it to block incoming packet data.
		c.reading = false
	}
}

// Spawn a goroutine to flush the outQ if possible and necessary.
func (c *client) flushOutQ() {
	if !c.writing && c.conn != nil {
		go write(c, c.outQ)
		c.writing = true
	}
}

// Handle the results from trying to write to the client connection.
func (c *client) onWrite(written int, writeErr error) {
	c.outQ = c.outQ[written:]
	c.writing = false

	switch {
	case writeErr != nil:
		log.Println(writeErr)
		c.destroy()
	case len(c.outQ) > 0:
		c.flushOutQ()
	case c.closing && c.reading:
		// Everything has been sent; signal EOF and wait for the peer's.
		// The kill timer covers us should this fail.
		if err := c.conn.CloseWrite(); err != nil {
			log.Println(err)
		}
	case c.closing:
		c.destroy()
	}
}

// --- Worker goroutines -------------------------------------------------------

// accept passes incoming connections to the main goroutine until
// the listener fails or is closed.
func accept(ln net.Listener) {
	for {
		conn, err := ln.Accept()
		if err != nil {
			// TODO: Consider specific cases in error handling, some errors
			// are transitional while others are fatal.
			if !errors.Is(err, net.ErrClosed) {
				// A closed listener is how shutdown stops us, not an error.
				log.Println(err)
			}
			return
		}
		conns <- conn
	}
}

// prepare resolves the peer's address to a hostname, guesses whether the peer
// is going to talk TLS, and reports both through the prepared channel.
func prepare(client *client) {
	conn := client.transport
	host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
	if err != nil {
		// In effect, we require TCP/UDP, as they have port numbers.
		log.Fatalln(err)
	}

	// The Cgo resolver doesn't pthread_cancel getnameinfo threads, so not
	// bothering with pointless contexts. The channel is buffered so that
	// the lookup goroutine can finish even after we've stopped waiting.
	ch := make(chan string, 1)
	go func() {
		defer close(ch)
		names, err := net.LookupAddr(host)
		switch {
		case err != nil:
			log.Println(err)
		case len(names) > 0:
			ch <- names[0]
		}
	}()

	// While we can't cancel it, we still want to set a timeout on it.
	select {
	case <-time.After(5 * time.Second):
	case resolved, ok := <-ch:
		if ok {
			host = resolved
		}
	}

	// Note that in this demo application the autodetection prevents non-TLS
	// clients from receiving any messages until they send something.
	isTLS := false
	if sc, ok := conn.(syscall.Conn); !ok {
		log.Printf("%T has no raw connection, assuming plain text", conn)
	} else if sysconn, err := sc.SyscallConn(); err != nil {
		// This is just for the TLS detection and doesn't need to be fatal.
		log.Println(err)
	} else {
		isTLS = detectTLS(sysconn)
	}

	// FIXME: When the client sends no data, we still initialize its conn.
	prepared <- preparedEvent{client, host, isTLS}
}

// read keeps reading from the client and reporting the results through
// the reads channel, finishing after the first error.
func read(client *client) {
	// A new buffer is allocated each time we receive some bytes, because of
	// thread-safety. Therefore the buffer shouldn't be too large, or we'd
	// need to copy it each time into a precisely sized new buffer.
	var err error
	for err == nil {
		var (
			buf [512]byte
			n   int
		)
		n, err = client.conn.Read(buf[:])
		reads <- readEvent{client, buf[:n], err}
	}
}

// Flush outQ, which is passed by parameter so that there are no data races.
func write(client *client, data []byte) {
	// We just write as much as we can, the main goroutine does the looping.
	n, err := client.conn.Write(data)
	writes <- writeEvent{client, n, err}
}

// --- Main --------------------------------------------------------------------

// processOneEvent waits for a single event from a signal, a timer or one of
// the worker goroutines and dispatches it. All client state is mutated from
// here. Events for clients that are no longer registered are dropped.
func processOneEvent() {
	select {
	case <-sigs:
		if inShutdown {
			forceShutdown("requested by user")
		} else {
			initiateShutdown()
		}

	case <-shutdownTimer:
		forceShutdown("timeout")

	case conn := <-conns:
		if inShutdown {
			// Accepted just before the listener got closed. Registering it
			// now would only hold the shutdown up until the force timer.
			log.Println("rejecting client connection during shutdown")
			if err := conn.Close(); err != nil {
				log.Println(err)
			}
			break
		}
		log.Println("accepted client connection")
		c := &client{transport: conn}
		clients[c] = true
		go prepare(c)

	case ev := <-prepared:
		log.Println("client is ready, resolved to", ev.host)
		if _, ok := clients[ev.client]; ok {
			ev.client.onPrepared(ev.isTLS)
		}

	case ev := <-reads:
		log.Println("received data from client")
		if _, ok := clients[ev.client]; ok {
			ev.client.onRead(ev.data, ev.err)
		}

	case ev := <-writes:
		log.Println("sent data to client")
		if _, ok := clients[ev.client]; ok {
			ev.client.onWrite(ev.written, ev.err)
		}

	case c := <-timeouts:
		if _, ok := clients[c]; ok {
			log.Println("client timeouted")
			c.destroy()
		}
	}
}

// main loads the key pair, starts listening on the given address and runs
// the event loop until a shutdown has been requested and all clients are gone.
func main() {
	// Just deal with unexpected flags, we don't use any ourselves.
	flag.Parse()

	if flag.NArg() != 3 {
		log.Fatalf("usage: %s KEY CERT ADDRESS\n", os.Args[0])
	}

	// LoadX509KeyPair wants the certificate first, we take the key first.
	cert, err := tls.LoadX509KeyPair(flag.Arg(1), flag.Arg(0))
	if err != nil {
		log.Fatalln(err)
	}

	tlsConf = &tls.Config{Certificates: []tls.Certificate{cert}}
	listener, err = net.Listen("tcp", flag.Arg(2))
	if err != nil {
		log.Fatalln(err)
	}

	go accept(listener)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

	for !inShutdown || len(clients) > 0 {
		processOneEvent()
	}
}