452 lines
11 KiB
Go
452 lines
11 KiB
Go
//
|
|
// Copyright (c) 2018, Přemysl Eric Janouch <p@janouch.name>
|
|
//
|
|
// Permission to use, copy, modify, and/or distribute this software for any
|
|
// purpose with or without fee is hereby granted.
|
|
//
|
|
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
|
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
|
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
|
|
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
|
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
|
|
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
|
|
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
|
//
|
|
|
|
//
|
|
// This is an example TLS-autodetecting chat server.
|
|
//
|
|
// These clients are unable to properly shut down the connection on exit:
|
|
// telnet localhost 1234
|
|
// openssl s_client -connect localhost:1234
|
|
//
|
|
// While this one doesn't react to an EOF from the server:
|
|
// ncat -C localhost 1234
|
|
// ncat -C --ssl localhost 1234
|
|
//
|
|
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto/tls"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
// --- Utilities ---------------------------------------------------------------
|
|
|
|
// detectTLS sniffs the first bytes a client has sent, without consuming them,
// to guess whether it opened with an SSL/TLS handshake. It waits on the raw
// connection until at least one byte, EOF, or an error shows up.
//
// Only the first chunk reported by Recvfrom is inspected, so the guess is
// most reliable when that chunk holds at least three octets, which should
// not pose a problem in practice. We might try waiting for them.
//
//	SSL2:     1xxx xxxx | xxxx xxxx | <1>
//	           (message length)      (client hello)
//	SSL3/TLS: <22>      | <3>       | xxxx xxxx
//	          (handshake)| (protocol version)
//
// Shorter peeks are judged on what is available: a single octet must be 22,
// two octets must be 22 followed by 3. EOF and errors other than EAGAIN
// yield false.
func detectTLS(sysconn syscall.RawConn) (isTLS bool) {
	// The error from Read is deliberately ignored: on failure isTLS simply
	// stays false and the client gets treated as plain text.
	_ = sysconn.Read(func(fd uintptr) (done bool) {
		var hdr [3]byte
		n, _, err := syscall.Recvfrom(int(fd), hdr[:], syscall.MSG_PEEK)
		if n < 1 {
			// Nothing to look at. Only ask to be woken up again when the
			// socket merely had no data yet; give up on EOF and errors.
			return err != syscall.EAGAIN
		}

		handshake := hdr[0] == 22 && (n < 2 || hdr[1] == 3)
		ssl2Hello := n == 3 && hdr[0]&0x80 != 0 && hdr[2] == 1
		isTLS = handshake || ssl2Hello
		return true
	})
	return isTLS
}
|
|
|
|
// --- Declarations ------------------------------------------------------------
|
|
|
|
// connCloseWriter is a net.Conn that can also half-close its write side.
// client.conn holds either the *tls.Conn created in onPrepared or the raw
// transport asserted to this interface; onWrite calls CloseWrite on it once a
// closing client's output has been flushed but we're still reading.
type connCloseWriter interface {
	net.Conn
	CloseWrite() error
}
|
|
|
|
// client holds the state of a single connection. Its fields are modified only
// by the main event loop; the prepare, read and write goroutines merely read
// transport/conn and report their results back over channels.
type client struct {
	transport net.Conn        // underlying connection
	tls       *tls.Conn       // TLS, if detected
	conn      connCloseWriter // high-level connection; nil until onPrepared
	inQ       []byte          // unprocessed input
	outQ      []byte          // unprocessed output
	reading   bool            // whether a reading goroutine is running (also cleared to ignore input, see onRead)
	writing   bool            // whether a writing goroutine is running
	closing   bool            // whether we're closing the connection
	killTimer *time.Timer     // timeout; armed by closeLink, reports on the timeouts channel
}
|
|
|
|
// preparedEvent is sent by prepare once it has resolved the client's host
// and sniffed its protocol.
type preparedEvent struct {
	client *client
	host   string // client's hostname or literal IP address
	isTLS  bool   // the client seems to use TLS
}
|
|
|
|
// readEvent is sent by read after every Read call on the client connection.
// Both data and err may be set at the same time.
type readEvent struct {
	client *client
	data   []byte // new data from the client
	err    error  // read error
}
|
|
|
|
// writeEvent is sent by write after a single Write call on the client
// connection, so that the main loop can trim outQ accordingly.
type writeEvent struct {
	client  *client
	written int   // amount of bytes written
	err     error // write error
}
|
|
|
|
// Channels through which signals, worker goroutines and timers deliver events
// to processOneEvent, followed by state that the main goroutine works with.
var (
	sigs     = make(chan os.Signal, 1)  // SIGINT/SIGTERM; buffered because signal.Notify doesn't block on send
	conns    = make(chan net.Conn)      // connections handed over by accept
	prepared = make(chan preparedEvent) // results of prepare
	reads    = make(chan readEvent)     // results of read
	writes   = make(chan writeEvent)    // results of write
	timeouts = make(chan *client)       // closing clients whose killTimer has expired

	tlsConf       *tls.Config              // server configuration for clients detected as TLS
	clients       = make(map[*client]bool) // all live clients, used as a set
	listener      net.Listener             // closed by initiateShutdown
	inShutdown    bool                     // set by initiateShutdown
	shutdownTimer <-chan time.Time         // nil (never ready) until a shutdown starts
)
|
|
|
|
// --- Server ------------------------------------------------------------------
|
|
|
|
// Broadcast to all /other/ clients (telnet-friendly, also in accordance to
|
|
// the plan of extending this to an IRCd).
|
|
func broadcast(line string, except *client) {
|
|
for c := range clients {
|
|
if c != except {
|
|
c.send(line)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Initiate a clean shutdown of the whole daemon.
|
|
func initiateShutdown() {
|
|
log.Println("shutting down")
|
|
if err := listener.Close(); err != nil {
|
|
log.Println(err)
|
|
}
|
|
for c := range clients {
|
|
c.closeLink()
|
|
}
|
|
|
|
shutdownTimer = time.After(3 * time.Second)
|
|
inShutdown = true
|
|
}
|
|
|
|
// Forcefully tear down all connections.
|
|
func forceShutdown(reason string) {
|
|
if !inShutdown {
|
|
log.Fatalln("forceShutdown called without initiateShutdown")
|
|
}
|
|
|
|
log.Printf("forced shutdown (%s)\n", reason)
|
|
for c := range clients {
|
|
c.destroy()
|
|
}
|
|
}
|
|
|
|
// --- Client ------------------------------------------------------------------
|
|
|
|
func (c *client) send(line string) {
|
|
if c.conn != nil && !c.closing {
|
|
c.outQ = append(c.outQ, (line + "\r\n")...)
|
|
c.flushOutQ()
|
|
}
|
|
}
|
|
|
|
// Tear down the client connection, trying to do so in a graceful manner.
|
|
func (c *client) closeLink() {
|
|
if c.closing {
|
|
return
|
|
}
|
|
if c.conn == nil {
|
|
c.destroy()
|
|
return
|
|
}
|
|
|
|
// Since we send this goodbye, we don't need to call CloseWrite here.
|
|
c.send("Goodbye")
|
|
c.killTimer = time.AfterFunc(3*time.Second, func() {
|
|
timeouts <- c
|
|
})
|
|
|
|
c.closing = true
|
|
}
|
|
|
|
// Close the connection and forget about the client.
|
|
func (c *client) destroy() {
|
|
// Try to send a "close notify" alert if the TLS object is ready,
|
|
// otherwise just tear down the transport.
|
|
if c.conn != nil {
|
|
_ = c.conn.Close()
|
|
} else {
|
|
_ = c.transport.Close()
|
|
}
|
|
|
|
// Clean up the goroutine, although a spurious event may still be sent.
|
|
if c.killTimer != nil {
|
|
c.killTimer.Stop()
|
|
}
|
|
|
|
log.Println("client destroyed")
|
|
delete(clients, c)
|
|
}
|
|
|
|
// Handle the results from initializing the client's connection.
|
|
func (c *client) onPrepared(isTLS bool) {
|
|
if isTLS {
|
|
c.tls = tls.Server(c.transport, tlsConf)
|
|
c.conn = c.tls
|
|
} else {
|
|
c.conn = c.transport.(connCloseWriter)
|
|
}
|
|
|
|
// TODO: If we've tried to send any data before now, we need to flushOutQ.
|
|
go read(c)
|
|
c.reading = true
|
|
}
|
|
|
|
// Handle the results from trying to read from the client connection.
|
|
func (c *client) onRead(data []byte, readErr error) {
|
|
if !c.reading {
|
|
// Abusing the flag to emulate CloseRead and skip over data, see below.
|
|
return
|
|
}
|
|
|
|
c.inQ = append(c.inQ, data...)
|
|
for {
|
|
advance, token, _ := bufio.ScanLines(c.inQ, false /* atEOF */)
|
|
if advance == 0 {
|
|
break
|
|
}
|
|
|
|
c.inQ = c.inQ[advance:]
|
|
line := string(token)
|
|
fmt.Println(line)
|
|
broadcast(line, c)
|
|
}
|
|
|
|
if readErr != nil {
|
|
c.reading = false
|
|
|
|
if readErr != io.EOF {
|
|
log.Println(readErr)
|
|
c.destroy()
|
|
} else if c.closing {
|
|
// Disregarding whether a clean shutdown has happened or not.
|
|
log.Println("client finished shutdown")
|
|
c.destroy()
|
|
} else {
|
|
log.Println("client EOF")
|
|
c.closeLink()
|
|
}
|
|
} else if len(c.inQ) > 8192 {
|
|
log.Println("client inQ overrun")
|
|
// TODO: Inform the client about inQ overrun in the farewell message.
|
|
c.closeLink()
|
|
|
|
// tls.Conn doesn't have the CloseRead method (and it needs to be able
|
|
// to read from the TCP connection even for writes, so there isn't much
|
|
// sense in expecting the implementation to do anything useful),
|
|
// otherwise we'd use it to block incoming packet data.
|
|
c.reading = false
|
|
}
|
|
}
|
|
|
|
// Spawn a goroutine to flush the outQ if possible and necessary.
|
|
func (c *client) flushOutQ() {
|
|
if !c.writing && c.conn != nil {
|
|
go write(c, c.outQ)
|
|
c.writing = true
|
|
}
|
|
}
|
|
|
|
// Handle the results from trying to write to the client connection.
|
|
func (c *client) onWrite(written int, writeErr error) {
|
|
c.outQ = c.outQ[written:]
|
|
c.writing = false
|
|
|
|
if writeErr != nil {
|
|
log.Println(writeErr)
|
|
c.destroy()
|
|
} else if len(c.outQ) > 0 {
|
|
c.flushOutQ()
|
|
} else if c.closing {
|
|
if c.reading {
|
|
c.conn.CloseWrite()
|
|
} else {
|
|
c.destroy()
|
|
}
|
|
}
|
|
}
|
|
|
|
// --- Worker goroutines -------------------------------------------------------
|
|
|
|
func accept(ln net.Listener) {
|
|
for {
|
|
if conn, err := ln.Accept(); err != nil {
|
|
// TODO: Consider specific cases in error handling, some errors
|
|
// are transitional while others are fatal.
|
|
log.Println(err)
|
|
break
|
|
} else {
|
|
conns <- conn
|
|
}
|
|
}
|
|
}
|
|
|
|
func prepare(client *client) {
|
|
conn := client.transport
|
|
host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
|
if err != nil {
|
|
// In effect, we require TCP/UDP, as they have port numbers.
|
|
log.Fatalln(err)
|
|
}
|
|
|
|
// The Cgo resolver doesn't pthread_cancel getnameinfo threads, so not
|
|
// bothering with pointless contexts.
|
|
ch := make(chan string, 1)
|
|
go func() {
|
|
defer close(ch)
|
|
if names, err := net.LookupAddr(host); err != nil {
|
|
log.Println(err)
|
|
} else {
|
|
ch <- names[0]
|
|
}
|
|
}()
|
|
|
|
// While we can't cancel it, we still want to set a timeout on it.
|
|
select {
|
|
case <-time.After(5 * time.Second):
|
|
case resolved, ok := <-ch:
|
|
if ok {
|
|
host = resolved
|
|
}
|
|
}
|
|
|
|
// Note that in this demo application the autodetection prevents non-TLS
|
|
// clients from receiving any messages until they send something.
|
|
isTLS := false
|
|
if sysconn, err := conn.(syscall.Conn).SyscallConn(); err != nil {
|
|
// This is just for the TLS detection and doesn't need to be fatal.
|
|
log.Println(err)
|
|
} else {
|
|
isTLS = detectTLS(sysconn)
|
|
}
|
|
|
|
// FIXME: When the client sends no data, we still initialize its conn.
|
|
prepared <- preparedEvent{client, host, isTLS}
|
|
}
|
|
|
|
func read(client *client) {
|
|
// A new buffer is allocated each time we receive some bytes, because of
|
|
// thread-safety. Therefore the buffer shouldn't be too large, or we'd
|
|
// need to copy it each time into a precisely sized new buffer.
|
|
var err error
|
|
for err == nil {
|
|
var (
|
|
buf [512]byte
|
|
n int
|
|
)
|
|
n, err = client.conn.Read(buf[:])
|
|
reads <- readEvent{client, buf[:n], err}
|
|
}
|
|
}
|
|
|
|
// Flush outQ, which is passed by parameter so that there are no data races.
|
|
func write(client *client, data []byte) {
|
|
// We just write as much as we can, the main goroutine does the looping.
|
|
n, err := client.conn.Write(data)
|
|
writes <- writeEvent{client, n, err}
|
|
}
|
|
|
|
// --- Main --------------------------------------------------------------------
|
|
|
|
func processOneEvent() {
|
|
select {
|
|
case <-sigs:
|
|
if inShutdown {
|
|
forceShutdown("requested by user")
|
|
} else {
|
|
initiateShutdown()
|
|
}
|
|
|
|
case <-shutdownTimer:
|
|
forceShutdown("timeout")
|
|
|
|
case conn := <-conns:
|
|
log.Println("accepted client connection")
|
|
c := &client{transport: conn}
|
|
clients[c] = true
|
|
go prepare(c)
|
|
|
|
case ev := <-prepared:
|
|
log.Println("client is ready, resolved to", ev.host)
|
|
if _, ok := clients[ev.client]; ok {
|
|
ev.client.onPrepared(ev.isTLS)
|
|
}
|
|
|
|
case ev := <-reads:
|
|
log.Println("received data from client")
|
|
if _, ok := clients[ev.client]; ok {
|
|
ev.client.onRead(ev.data, ev.err)
|
|
}
|
|
|
|
case ev := <-writes:
|
|
log.Println("sent data to client")
|
|
if _, ok := clients[ev.client]; ok {
|
|
ev.client.onWrite(ev.written, ev.err)
|
|
}
|
|
|
|
case c := <-timeouts:
|
|
if _, ok := clients[c]; ok {
|
|
log.Println("client timeouted")
|
|
c.destroy()
|
|
}
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
// Just deal with unexpected flags, we don't use any ourselves.
|
|
flag.Parse()
|
|
|
|
if len(flag.Args()) != 3 {
|
|
log.Fatalf("usage: %s KEY CERT ADDRESS\n", os.Args[0])
|
|
}
|
|
|
|
cert, err := tls.LoadX509KeyPair(flag.Arg(1), flag.Arg(0))
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
|
|
tlsConf = &tls.Config{Certificates: []tls.Certificate{cert}}
|
|
listener, err = net.Listen("tcp", flag.Arg(2))
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
|
|
go accept(listener)
|
|
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
for !inShutdown || len(clients) > 0 {
|
|
processOneEvent()
|
|
}
|
|
}
|