// haven/prototypes/tls-autodetect.go
//
// 357 lines
// 8.8 KiB
// Go

//
// Copyright (c) 2018, Přemysl Janouch <p@janouch.name>
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
//
// This is an example TLS-autodetecting chat server.
//
// You may connect to it either using:
// telnet localhost 1234
// or
// openssl s_client -connect localhost:1234
package main
import (
	"bufio"
	"crypto/tls"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/signal"
	"syscall"
	"time"
)
// --- Utilities ---------------------------------------------------------------
// detectTLS performs trivial SSL/TLS autodetection: it peeks at the first
// octets the peer has sent and reports whether they look like the beginning
// of an SSL2 or SSL3/TLS handshake. The data is only peeked at (MSG_PEEK),
// so it remains queued for whoever reads from the connection afterwards.
//
// The first block of data returned by Recvfrom must be at least three octets
// long for this to work reliably, but that should not pose a problem
// in practice. We might try waiting for them.
//
//	SSL2:      1xxx xxxx | xxxx xxxx |    <1>
//	              (message length)     (client hello)
//	SSL3/TLS:    <22>    |    <3>    | xxxx xxxx
//	          (handshake)|  (protocol version)
//
// With fewer than three octets available, only the SSL3/TLS prefix can be
// checked. End of stream and any receive error other than EAGAIN yield false.
func detectTLS(sysconn syscall.RawConn) bool {
	isTLS := false
	// Should Read itself fail, isTLS stays false and the caller falls back
	// to treating the connection as plain text.
	_ = sysconn.Read(func(fd uintptr) (done bool) {
		var buf [3]byte
		n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK)
		switch {
		case n == 3:
			// An SSL2 client hello is a match on its own, so the SSL3/TLS
			// check that we fall through to must not clobber this result.
			isTLS = buf[0]&0x80 != 0 && buf[2] == 1
			fallthrough
		case n == 2:
			isTLS = isTLS || (buf[0] == 22 && buf[1] == 3)
		case n == 1:
			isTLS = buf[0] == 22
		case err == syscall.EAGAIN:
			// Nothing has arrived yet, ask to be called again
			// once the descriptor becomes readable.
			return false
		}
		return true
	})
	return isTLS
}
// --- Declarations ------------------------------------------------------------
// connCloseWrite is a net.Conn that can additionally close just its writing
// side. The read goroutine stores either the type-asserted transport
// or the *tls.Conn wrapping it in a field of this type.
type connCloseWrite interface {
net.Conn
CloseWrite() error
}
// client holds the per-connection state: the connection objects, the input
// and output queues, and flags tracking the writer goroutine and shutdown.
//
// NOTE(review): no assignment to connReady is visible in this chunk (see
// the TODO in read), so everything guarded by it appears to be inactive
// for now -- confirm.
type client struct {
transport net.Conn // underlying connection
tls *tls.Conn // TLS, if detected
conn connCloseWrite // high-level connection
connReady bool // conn is safe to read from the main goroutine
inQ []byte // unprocessed input
outQ []byte // unprocessed output
writing bool // whether a writing goroutine is running
inShutdown bool // whether we're closing connection
}
// readEvent carries the outcome of a single Read call from a client's read
// goroutine over the reads channel. It is constructed positionally,
// so the order of fields matters.
type readEvent struct {
client *client // client
data []byte // new data from the client
err error // read error
}
// writeEvent carries the outcome of a single Write call from a write
// goroutine over the writes channel. It is constructed positionally,
// so the order of fields matters.
type writeEvent struct {
client *client // client
written int // amount of bytes written
err error // write error
}
// Package-level state of the server: the channels that processOneEvent
// selects on, followed by the data it operates with.
var (
// Signals registered through signal.Notify in main.
sigs = make(chan os.Signal, 1)
// Connections handed over by the accept goroutine.
conns = make(chan net.Conn)
// Results sent by read goroutines.
reads = make(chan readEvent)
// Results sent by write goroutines.
writes = make(chan writeEvent)
// TLS configuration, set up in main from the key pair given on
// the command line.
tlsConf *tls.Config
// The set of clients that haven't been destroyed yet.
clients = make(map[*client]bool)
// The listening socket, created in main, closed in initiateShutdown.
listener net.Listener
// Set by initiateShutdown; main exits once it's set and clients is empty.
inShutdown bool
// Nil until initiateShutdown arms it; firing forces the shutdown.
shutdownTimer <-chan time.Time
)
// --- Server ------------------------------------------------------------------
// broadcast sends line to every known client apart from except. Leaving
// the originator out is telnet-friendly, and also in accordance with the plan
// of extending this to an IRCd.
func broadcast(line string, except *client) {
	for recipient := range clients {
		if recipient == except {
			continue
		}
		recipient.send(line)
	}
}
// initiateShutdown starts shutting the server down: it closes the listener,
// asks each client to go away, and arms a three-second timer, which
// makes processOneEvent force the shutdown when it fires.
func initiateShutdown() {
	log.Println("shutting down")

	err := listener.Close()
	if err != nil {
		log.Println(err)
	}

	// Deleting map entries during iteration is fine, kill may do just that.
	for victim := range clients {
		victim.kill()
	}

	inShutdown, shutdownTimer = true, time.After(3*time.Second)
}
// forceShutdown logs the reason and destroys all remaining clients
// right away, without attempting anything graceful.
func forceShutdown(reason string) {
	log.Printf("forced shutdown (%s)\n", reason)

	// destroy removes the client from the map we're iterating over,
	// which Go permits.
	for leftover := range clients {
		leftover.destroy()
	}
}
// --- Client ------------------------------------------------------------------
// send appends line, terminated by CR LF, to the output queue and tries
// to get the queue flushed. Nothing happens for clients being shut down.
func (c *client) send(line string) {
	if c.inShutdown {
		return
	}
	c.outQ = append(append(c.outQ, line...), "\r\n"...)
	c.flushOutQ()
}
// shutdown marks the client as being shut down and closes the writing side
// of its connection. A repeated call is merely logged.
func (c *client) shutdown() {
	if !c.inShutdown {
		// TODO: We must set a timer and destroy the client on timeout. Since we
		// have a central event loop, we probably need an event. Since we also
		// seem to need an event for TLS autodetection because of conn, we might
		// want to send an enumeration value.
		c.inShutdown = true
		c.conn.CloseWrite()
		return
	}
	log.Println("client double shutdown")
}
// kill tears down the client connection, gracefully if it can: a client
// with a ready connection receives a farewell line and has its connection
// shut down, any other is destroyed on the spot.
func (c *client) kill() {
	if !c.connReady {
		c.destroy()
		return
	}
	c.send("Goodbye")
	c.shutdown()
}
// destroy closes the connection and forgets about the client.
func (c *client) destroy() {
	// Prefer the high-level connection when it's ready, so that a TLS object
	// gets to try sending a "close notify" alert; otherwise just tear down
	// the transport.
	target := c.transport
	if c.connReady {
		target = c.conn
	}
	// Nothing useful can be done about a failure to close.
	_ = target.Close()

	delete(clients, c)
}
// onRead handles the results from trying to read from the client connection.
//
// data is appended to the input queue, and each complete line found in there
// is printed to standard output and broadcast to the other clients. A client
// whose queue then still exceeds 8192 bytes gets killed.
//
// readErr is the error that came with data: io.EOF shuts the client down,
// any other non-nil value is logged and destroys the client.
func (c *client) onRead(data []byte, readErr error) {
	c.inQ = append(c.inQ, data...)
	for {
		// atEOF stays false so that an unterminated line remains queued.
		advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */)
		if advance == 0 {
			break
		}
		c.inQ = c.inQ[advance:]

		line := string(token)
		fmt.Println(line)
		broadcast(line, c)
	}

	// TODO: Inform the client about the inQ overrun in the farewell message.
	if len(c.inQ) > 8192 {
		c.kill()
		return
	}

	switch {
	case readErr == nil:
	case readErr == io.EOF:
		// We may have begun shutting down on our own already, in which case
		// there is no point in trying again.
		if !c.inShutdown {
			c.shutdown()
		}
		// The read goroutine exits on any error, so unless a write is
		// in flight and onWrite gets to finish the job, no further event
		// would arrive for this client, leaking it with its descriptor.
		if !c.writing {
			c.destroy()
		}
	default:
		log.Println(readErr)
		c.destroy()
	}
}
// flushOutQ spawns a goroutine to write out the output queue, provided that
// the connection is ready and no such goroutine is running already. When
// the connection isn't ready yet, the flush needs to be retried as soon as
// it becomes so.
func (c *client) flushOutQ() {
	if !c.connReady || c.writing {
		return
	}
	c.writing = true
	go write(c, c.outQ)
}
// onWrite handles the results from trying to write to the client connection:
// it drops the written bytes from the output queue, and then either destroys
// the client on error, continues flushing what remains, or, with an empty
// queue and the client in shutdown, destroys the client.
func (c *client) onWrite(written int, writeErr error) {
	c.outQ = c.outQ[written:]
	c.writing = false

	switch {
	case writeErr != nil:
		log.Println(writeErr)
		c.destroy()
	case len(c.outQ) > 0:
		c.flushOutQ()
	case c.inShutdown:
		c.destroy()
	}
}
// --- Worker goroutines -------------------------------------------------------
// accept passes each connection accepted on ln to the main goroutine through
// the conns channel. It returns once ln has been closed; other errors are
// logged and accepting continues.
func accept(ln net.Listener) {
	// TODO: Consider specific cases in error handling, some errors
	// are transitional while others are fatal.
	for {
		conn, err := ln.Accept()
		if errors.Is(err, net.ErrClosed) {
			// Each subsequent Accept would fail immediately just the same,
			// so carrying on would only spin and flood the log.
			return
		}
		if err != nil {
			log.Println(err)
			continue
		}
		conns <- conn
	}
}
// read is the body of the per-client reading goroutine. It first settles on
// the high-level connection, wrapping the transport in TLS when detectTLS
// says so, and then keeps sending the outcome of each Read over the reads
// channel, ending with the first one that reports an error.
func read(client *client) {
	// TODO: Either here or elsewhere we need to set a timeout.
	client.conn = client.transport.(connCloseWrite)

	sysconn, err := client.transport.(syscall.Conn).SyscallConn()
	switch {
	case err != nil:
		// This is just for the TLS detection and doesn't need to be fatal.
		log.Println(err)
	case detectTLS(sysconn):
		client.tls = tls.Server(client.transport, tlsConf)
		client.conn = client.tls
	}

	// TODO: Signal the main goroutine that conn is ready. In fact, the upper
	// part could be mostly moved to the main goroutine and we'd only spawn
	// a thin wrapper around detectTLS, sending back {*client, bool}. Heck,
	// I could get rid of connReady.

	for {
		// A new buffer is allocated each time we receive some bytes, because
		// of thread-safety. Therefore the buffer shouldn't be too large,
		// or we'd need to copy it each time into a precisely sized new buffer.
		var chunk [512]byte
		n, readErr := client.conn.Read(chunk[:])
		reads <- readEvent{client, chunk[:n], readErr}
		if readErr != nil {
			return
		}
	}
}
// write is the body of a writing goroutine. It makes a single attempt to
// write data, which is the outQ passed by parameter so that there are no
// data races, and reports the outcome over the writes channel.
func write(client *client, data []byte) {
	// We just write as much as we can, the main goroutine does the looping.
	var ev writeEvent
	ev.client = client
	ev.written, ev.err = client.conn.Write(data)
	writes <- ev
}
// --- Main --------------------------------------------------------------------
// processOneEvent blocks until one of the event sources fires and handles
// it: a signal, the shutdown timer, a newly accepted connection, or a result
// from a read or a write goroutine. Results concerning clients that are no
// longer in the clients map are dropped.
func processOneEvent() {
	select {
	case <-sigs:
		// The first signal asks nicely, the second one insists.
		if !inShutdown {
			initiateShutdown()
			break
		}
		forceShutdown("requested by user")

	case <-shutdownTimer:
		forceShutdown("timeout")

	case transport := <-conns:
		log.Println("accepted client connection")
		newcomer := &client{transport: transport}
		clients[newcomer] = true
		go read(newcomer)

	case ev := <-reads:
		log.Println("received data from client")
		if _, known := clients[ev.client]; !known {
			break
		}
		ev.client.onRead(ev.data, ev.err)

	case ev := <-writes:
		log.Println("sent data to client")
		if _, known := clients[ev.client]; !known {
			break
		}
		ev.client.onWrite(ev.written, ev.err)
	}
}
// main loads the key pair, starts listening on the given address, and runs
// the event loop until a shutdown has been initiated and no clients remain.
func main() {
	// Just deal with unexpected flags, we don't use any ourselves.
	flag.Parse()
	if flag.NArg() != 3 {
		log.Fatalf("usage: %s KEY CERT ADDRESS\n", os.Args[0])
	}
	keyPath, certPath, address := flag.Arg(0), flag.Arg(1), flag.Arg(2)

	// Note that the certificate comes first here, unlike on the command line.
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		log.Fatalln(err)
	}
	tlsConf = &tls.Config{Certificates: []tls.Certificate{cert}}

	if listener, err = net.Listen("tcp", address); err != nil {
		log.Fatalln(err)
	}
	go accept(listener)

	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
	for {
		if inShutdown && len(clients) == 0 {
			break
		}
		processOneEvent()
	}
}