// // Copyright (c) 2018, Přemysl Janouch
// // Permission to use, copy, modify, and/or distribute this software for any // purpose with or without fee is hereby granted. // // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. // // This is an example TLS-autodetecting chat server. // // You may connect to it either using: // telnet localhost 1234 // or // openssl s_client -connect localhost:1234 package main import ( "bufio" "crypto/tls" "flag" "fmt" "io" "log" "net" "os" "os/signal" "syscall" "time" ) // --- Utilities --------------------------------------------------------------- // Trivial SSL/TLS autodetection. The first block of data returned by Recvfrom // must be at least three octets long for this to work reliably, but that should // not pose a problem in practice. We might try waiting for them. 
//
//	SSL2:      1xxx xxxx | xxxx xxxx |    <1>
//	               (message length)  (client hello)
//	SSL3/TLS:    <22>    |    <3>    | xxxx xxxx
//	          (handshake)|  (protocol version)
//
// With three octets available, both layouts are tried and a match of either
// one counts; with fewer octets, only the SSL3/TLS prefix can be recognized.
// Any failure to peek at the socket, EOF included, results in false.
func detectTLS(sysconn syscall.RawConn) bool {
	isTLS := false
	// RawConn.Read keeps invoking the callback for as long as it returns false,
	// waiting for the descriptor to become readable in between. Its own error
	// is ignored on purpose: a connection we can't inspect is treated as plain.
	_ = sysconn.Read(func(fd uintptr) (done bool) {
		var buf [3]byte
		n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK)
		switch {
		case n == 3:
			isTLS = buf[0]&0x80 != 0 && buf[2] == 1
			fallthrough
		case n == 2:
			// Must not overwrite an SSL2 match from the case above.
			isTLS = isTLS || (buf[0] == 22 && buf[1] == 3)
		case n == 1:
			isTLS = buf[0] == 22
		case err == syscall.EAGAIN:
			// Nothing to peek at yet, wait for the socket to become readable.
			return false
		}
		return true
	})
	return isTLS
}

// --- Declarations ------------------------------------------------------------

// connCloseWrite is a connection that can also shut down its writing side;
// *tls.Conn satisfies it, and so does *net.TCPConn.
type connCloseWrite interface {
	net.Conn
	CloseWrite() error
}

// client is the per-connection state, owned by the main goroutine.
type client struct {
	transport net.Conn       // underlying connection
	tls       *tls.Conn      // TLS, if detected
	conn      connCloseWrite // high-level connection
	connReady bool           // conn is safe to use from the main goroutine

	inQ        []byte // unprocessed input
	outQ       []byte // unprocessed output
	writing    bool   // whether a writing goroutine is running
	inShutdown bool   // whether we're closing the connection
}

// readEvent is sent by a client's reading goroutine for every Read call.
type readEvent struct {
	client *client // client
	data   []byte  // new data from the client
	err    error   // read error
}

// writeEvent is sent by a writing goroutine once its Write call returns.
type writeEvent struct {
	client  *client // client
	written int     // amount of bytes written
	err     error   // write error
}

var (
	sigs    = make(chan os.Signal, 1)
	conns   = make(chan net.Conn)
	reads   = make(chan readEvent)
	writes  = make(chan writeEvent)
	tlsConf *tls.Config

	clients       = make(map[*client]bool)
	listener      net.Listener
	inShutdown    bool
	shutdownTimer <-chan time.Time
)

// --- Server ------------------------------------------------------------------

// Broadcast to all /other/ clients (telnet-friendly, also in accordance with
// the plan of extending this to an IRCd).
func broadcast(line string, except *client) { for c := range clients { if c != except { c.send(line) } } } func initiateShutdown() { log.Println("shutting down") if err := listener.Close(); err != nil { log.Println(err) } for c := range clients { c.kill() } shutdownTimer = time.After(3 * time.Second) inShutdown = true } func forceShutdown(reason string) { log.Printf("forced shutdown (%s)\n", reason) for c := range clients { c.destroy() } } // --- Client ------------------------------------------------------------------ func (c *client) send(line string) { if !c.inShutdown { c.outQ = append(c.outQ, (line + "\r\n")...) c.flushOutQ() } } func (c *client) shutdown() { if c.inShutdown { log.Println("client double shutdown") return } // TODO: We must set a timer and destroy the client on timeout. Since we // have a central event loop, we probably need an event. Since we also // seem to need an event for TLS autodetection because of conn, we might // want to send an enumeration value. c.inShutdown = true c.conn.CloseWrite() } // Tear down the client connection, trying to do so in a graceful manner. func (c *client) kill() { if c.connReady { c.send("Goodbye") c.shutdown() } else { c.destroy() } } // Close the connection and forget about the client. func (c *client) destroy() { // Try to send a "close notify" alert if the TLS object is ready, // otherwise just tear down the transport. if c.connReady { _ = c.conn.Close() } else { _ = c.transport.Close() } delete(clients, c) } // Handle the results from trying to read from the client connection. func (c *client) onRead(data []byte, readErr error) { c.inQ = append(c.inQ, data...) for { advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */) c.inQ = c.inQ[advance:] if advance == 0 { break } line := string(token) fmt.Println(line) broadcast(line, c) } // TODO: Inform the client about the inQ overrun in the farewell message. 
if len(c.inQ) > 8192 { c.kill() return } if readErr == io.EOF { // TODO: What if we're already in shutdown? c.shutdown() } else if readErr != nil { log.Println(readErr) c.destroy() } } // Spawn a goroutine to flush the outQ if possible and necessary. If the // connection is not ready yet, it needs to be retried as soon as it becomes. func (c *client) flushOutQ() { if c.connReady && !c.writing { go write(c, c.outQ) c.writing = true } } // Handle the results from trying to write to the client connection. func (c *client) onWrite(written int, writeErr error) { c.outQ = c.outQ[written:] c.writing = false if writeErr != nil { log.Println(writeErr) c.destroy() } else if len(c.outQ) > 0 { c.flushOutQ() } else if c.inShutdown { c.destroy() } } // --- Worker goroutines ------------------------------------------------------- func accept(ln net.Listener) { // TODO: Consider specific cases in error handling, some errors // are transitional while others are fatal. for { if conn, err := ln.Accept(); err != nil { log.Println(err) } else { conns <- conn } } } func read(client *client) { // TODO: Either here or elsewhere we need to set a timeout. client.conn = client.transport.(connCloseWrite) if sysconn, err := client.transport.(syscall.Conn).SyscallConn(); err != nil { // This is just for the TLS detection and doesn't need to be fatal. log.Println(err) } else if detectTLS(sysconn) { client.tls = tls.Server(client.transport, tlsConf) client.conn = client.tls } // TODO: Signal the main goroutine that conn is ready. In fact, the upper // part could be mostly moved to the main goroutine and we'd only spawn // a thin wrapper around detectTLS, sending back {*client, bool}. Heck, // I could get rid of connReady. // A new buffer is allocated each time we receive some bytes, because of // thread-safety. Therefore the buffer shouldn't be too large, or we'd // need to copy it each time into a precisely sized new buffer. 
var err error for err == nil { var ( buf [512]byte n int ) n, err = client.conn.Read(buf[:]) reads <- readEvent{client, buf[:n], err} } } // Flush outQ, which is passed by parameter so that there are no data races. func write(client *client, data []byte) { // We just write as much as we can, the main goroutine does the looping. n, err := client.conn.Write(data) writes <- writeEvent{client, n, err} } // --- Main -------------------------------------------------------------------- func processOneEvent() { select { case <-sigs: if inShutdown { forceShutdown("requested by user") } else { initiateShutdown() } case <-shutdownTimer: forceShutdown("timeout") case conn := <-conns: log.Println("accepted client connection") c := &client{transport: conn} clients[c] = true go read(c) case ev := <-reads: log.Println("received data from client") if _, ok := clients[ev.client]; ok { ev.client.onRead(ev.data, ev.err) } case ev := <-writes: log.Println("sent data to client") if _, ok := clients[ev.client]; ok { ev.client.onWrite(ev.written, ev.err) } } } func main() { // Just deal with unexpected flags, we don't use any ourselves. flag.Parse() if len(flag.Args()) != 3 { log.Fatalf("usage: %s KEY CERT ADDRESS\n", os.Args[0]) } cert, err := tls.LoadX509KeyPair(flag.Arg(1), flag.Arg(0)) if err != nil { log.Fatalln(err) } tlsConf = &tls.Config{Certificates: []tls.Certificate{cert}} listener, err = net.Listen("tcp", flag.Arg(2)) if err != nil { log.Fatalln(err) } go accept(listener) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) for !inShutdown || len(clients) > 0 { processOneEvent() } }