300 lines
6.3 KiB
Go
300 lines
6.3 KiB
Go
// Copyright (c) 2022, Přemysl Eric Janouch <p@janouch.name>
|
|
// SPDX-License-Identifier: 0BSD
|
|
|
|
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"html/template"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
"nhooyr.io/websocket"
|
|
)
|
|
|
|
// debug turns on verbose logging of the events and commands being proxied.
var debug = flag.Bool("debug", false, "enable debug output")

// Addresses taken from the command line; main fills them in.
var (
	addressBind    string // where the HTTP server listens
	addressConnect string // relay address, dialled over TCP for each client
	addressWS      string // WebSocket URI handed to the page, unless empty
)
|
|
|
|
// -----------------------------------------------------------------------------
|
|
|
|
func relayReadFrame(r io.Reader) []byte {
|
|
var length uint32
|
|
if err := binary.Read(r, binary.BigEndian, &length); err != nil {
|
|
log.Println("Event receive failed: " + err.Error())
|
|
return nil
|
|
}
|
|
b := make([]byte, length)
|
|
if _, err := io.ReadFull(r, b); err != nil {
|
|
log.Println("Event receive failed: " + err.Error())
|
|
return nil
|
|
}
|
|
|
|
if *debug {
|
|
log.Printf("<? %v\n", b)
|
|
|
|
var m RelayEventMessage
|
|
if after, ok := m.ConsumeFrom(b); !ok {
|
|
log.Println("Event deserialization failed")
|
|
return nil
|
|
} else if len(after) != 0 {
|
|
log.Println("Event deserialization failed: trailing data")
|
|
return nil
|
|
}
|
|
|
|
j, err := m.MarshalJSON()
|
|
if err != nil {
|
|
log.Println("Event marshalling failed: " + err.Error())
|
|
return nil
|
|
}
|
|
|
|
log.Printf("<- %s\n", j)
|
|
}
|
|
return b
|
|
}
|
|
|
|
func relayMakeReceiver(ctx context.Context, conn net.Conn) <-chan []byte {
|
|
// The usual event message rarely gets above 1 kilobyte,
|
|
// thus this is set to buffer up at most 1 megabyte or so.
|
|
p := make(chan []byte, 1000)
|
|
r := bufio.NewReaderSize(conn, 65536)
|
|
go func() {
|
|
defer close(p)
|
|
for {
|
|
j := relayReadFrame(r)
|
|
if j == nil {
|
|
return
|
|
}
|
|
select {
|
|
case p <- j:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
return p
|
|
}
|
|
|
|
func relayWriteJSON(conn net.Conn, j []byte) bool {
|
|
var m RelayCommandMessage
|
|
if err := json.Unmarshal(j, &m); err != nil {
|
|
log.Println("Command unmarshalling failed: " + err.Error())
|
|
return false
|
|
}
|
|
|
|
b, ok := m.AppendTo(make([]byte, 4))
|
|
if !ok {
|
|
log.Println("Command serialization failed")
|
|
return false
|
|
}
|
|
binary.BigEndian.PutUint32(b[:4], uint32(len(b)-4))
|
|
if _, err := conn.Write(b); err != nil {
|
|
log.Println("Command send failed: " + err.Error())
|
|
return false
|
|
}
|
|
|
|
if *debug {
|
|
log.Printf("-> %v\n", b)
|
|
}
|
|
return true
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
|
|
func clientReadJSON(ctx context.Context, ws *websocket.Conn) []byte {
|
|
t, j, err := ws.Read(ctx)
|
|
if err != nil {
|
|
log.Println("Command receive failed: " + err.Error())
|
|
return nil
|
|
}
|
|
if t != websocket.MessageText {
|
|
log.Println(
|
|
"Command receive failed: " + "binary messages are not supported")
|
|
return nil
|
|
}
|
|
|
|
if *debug {
|
|
log.Printf("?> %s\n", j)
|
|
}
|
|
return j
|
|
}
|
|
|
|
func clientWriteBinary(ctx context.Context, ws *websocket.Conn, b []byte) bool {
|
|
if err := ws.Write(ctx, websocket.MessageBinary, b); err != nil {
|
|
log.Println("Event send failed: " + err.Error())
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func clientWriteError(ctx context.Context, ws *websocket.Conn, err error) bool {
|
|
b, ok := (&RelayEventMessage{
|
|
EventSeq: 0,
|
|
Data: RelayEventData{
|
|
Interface: RelayEventDataError{
|
|
Event: RelayEventError,
|
|
CommandSeq: 0,
|
|
Error: err.Error(),
|
|
},
|
|
},
|
|
}).AppendTo(nil)
|
|
if ok {
|
|
log.Println("Event serialization failed")
|
|
return false
|
|
}
|
|
return clientWriteBinary(ctx, ws, b)
|
|
}
|
|
|
|
func handleWS(w http.ResponseWriter, r *http.Request) {
|
|
ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
|
InsecureSkipVerify: true,
|
|
// Note that Safari can be broken with compression.
|
|
CompressionMode: websocket.CompressionContextTakeover,
|
|
// This is for the payload; set higher to avoid overhead.
|
|
CompressionThreshold: 64 << 10,
|
|
})
|
|
if err != nil {
|
|
log.Println("Client rejected: " + err.Error())
|
|
return
|
|
}
|
|
defer ws.Close(websocket.StatusGoingAway, "Goodbye")
|
|
|
|
ctx, cancel := context.WithCancel(r.Context())
|
|
defer cancel()
|
|
|
|
conn, err := net.Dial("tcp", addressConnect)
|
|
if err != nil {
|
|
log.Println("Connection failed: " + err.Error())
|
|
clientWriteError(ctx, ws, err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// To decrease latencies, events are received and decoded in parallel
|
|
// to their sending, and we try to batch them together.
|
|
relayFrames := relayMakeReceiver(ctx, conn)
|
|
batchFrames := func() []byte {
|
|
batch, ok := <-relayFrames
|
|
if !ok {
|
|
return nil
|
|
}
|
|
Batch:
|
|
for {
|
|
select {
|
|
case b, ok := <-relayFrames:
|
|
if !ok {
|
|
break Batch
|
|
}
|
|
batch = append(batch, b...)
|
|
default:
|
|
break Batch
|
|
}
|
|
}
|
|
return batch
|
|
}
|
|
|
|
// We don't need to intervene, so it's just two separate pipes so far.
|
|
go func() {
|
|
defer cancel()
|
|
for {
|
|
j := clientReadJSON(ctx, ws)
|
|
if j == nil {
|
|
return
|
|
}
|
|
relayWriteJSON(conn, j)
|
|
}
|
|
}()
|
|
go func() {
|
|
defer cancel()
|
|
for {
|
|
b := batchFrames()
|
|
if b == nil {
|
|
return
|
|
}
|
|
clientWriteBinary(ctx, ws, b)
|
|
}
|
|
}()
|
|
<-ctx.Done()
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
|
|
var (
	// staticHandler serves files from the current working directory.
	staticHandler = http.FileServer(http.Dir("."))

	// page is the HTML document served at the root path. It is executed with
	// the WebSocket URI as data, which html/template escapes for the
	// JavaScript string literal it is placed in.
	page = template.Must(template.New("/").Parse(`<!DOCTYPE html>
<html>
<head>
<title>xP</title>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" href="xP.css" />
</head>
<body>
<script src="mithril.js">
</script>
<script>
let proxy = '{{ . }}'
</script>
<script type="module" src="xP.js">
</script>
</body>
</html>`))
)
|
|
|
|
func handleDefault(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/" {
|
|
staticHandler.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
wsURI := addressWS
|
|
if wsURI == "" {
|
|
wsURI = fmt.Sprintf("ws://%s/ws", r.Host)
|
|
}
|
|
if err := page.Execute(w, wsURI); err != nil {
|
|
log.Println("Template execution failed: " + err.Error())
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
flag.Usage = func() {
|
|
fmt.Fprintf(flag.CommandLine.Output(),
|
|
"Usage: %s [OPTION...] BIND CONNECT [WSURI]\n\n", os.Args[0])
|
|
flag.PrintDefaults()
|
|
}
|
|
|
|
flag.Parse()
|
|
if flag.NArg() < 2 || flag.NArg() > 3 {
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
addressBind, addressConnect = flag.Arg(0), flag.Arg(1)
|
|
if flag.NArg() > 2 {
|
|
addressWS = flag.Arg(2)
|
|
}
|
|
|
|
http.Handle("/ws", http.HandlerFunc(handleWS))
|
|
http.Handle("/", http.HandlerFunc(handleDefault))
|
|
|
|
s := &http.Server{
|
|
Addr: addressBind,
|
|
ReadTimeout: 60 * time.Second,
|
|
WriteTimeout: 60 * time.Second,
|
|
MaxHeaderBytes: 32 << 10,
|
|
}
|
|
log.Fatal(s.ListenAndServe())
|
|
}
|