xK/xP/xP.go

218 lines
4.7 KiB
Go
Raw Normal View History

// Copyright (c) 2022, Přemysl Eric Janouch <p@janouch.name>
// SPDX-License-Identifier: 0BSD
package main
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"html/template"
"io"
"log"
"net"
"net/http"
"os"
"time"
"nhooyr.io/websocket"
)
// addressBind is the TCP address the HTTP server listens on (BIND argument).
var addressBind string

// addressConnect is the TCP address of the relay that every WebSocket client
// gets bridged to (CONNECT argument).
var addressConnect string

// addressWS, when non-empty, is the WebSocket URI handed to the web page
// (optional WSURI argument); when empty, a ws:// URI is derived from the
// Host of the page request instead.
var addressWS string
// clientToRelay forwards exactly one command from the WebSocket client to the
// relay connection.
//
// It reads one message from ws, accepting only text messages, and decodes it
// from JSON into a RelayCommandMessage. It then serializes the command into a
// frame whose first 4 bytes are a big-endian length of the remainder (the
// prefix itself is not counted) and writes that frame to conn.
//
// Every failure is logged and reported by returning false, so callers can
// simply loop for as long as it returns true.
func clientToRelay(
	ctx context.Context, ws *websocket.Conn, conn net.Conn) bool {
	kind, payload, err := ws.Read(ctx)
	if err != nil {
		log.Println("Command receive failed: " + err.Error())
		return false
	}
	if kind != websocket.MessageText {
		log.Println("Command receive failed: " +
			"binary messages are not supported")
		return false
	}
	log.Printf("?> %s\n", payload)

	var m RelayCommandMessage
	if err = json.Unmarshal(payload, &m); err != nil {
		log.Println("Command unmarshalling failed: " + err.Error())
		return false
	}

	// Reserve 4 leading bytes for the length prefix; AppendTo presumably
	// appends the serialized command after them (confirm in its definition).
	frame, ok := m.AppendTo(make([]byte, 4))
	if !ok {
		log.Println("Command serialization failed")
		return false
	}
	binary.BigEndian.PutUint32(frame[:4], uint32(len(frame)-4))

	if _, err = conn.Write(frame); err != nil {
		log.Println("Command send failed: " + err.Error())
		return false
	}
	log.Printf("-> %v\n", frame)
	return true
}
// relayToClient forwards exactly one event from the relay connection to the
// WebSocket client.
//
// It reads a 4-byte big-endian length from conn, followed by that many bytes,
// and deserializes them as a RelayEventMessage. Frames that leave unconsumed
// trailing bytes are rejected. The event is then re-encoded as JSON and sent
// to ws as a text message.
//
// Every failure is logged and reported by returning false, so callers can
// simply loop for as long as it returns true.
func relayToClient(
	ctx context.Context, ws *websocket.Conn, conn net.Conn) bool {
	var size uint32
	if err := binary.Read(conn, binary.BigEndian, &size); err != nil {
		log.Println("Event receive failed: " + err.Error())
		return false
	}

	// NOTE(review): size comes straight off the wire and is not bounded, so a
	// misbehaving relay could make this allocate up to 4 GiB.
	frame := make([]byte, size)
	if _, err := io.ReadFull(conn, frame); err != nil {
		log.Println("Event receive failed: " + err.Error())
		return false
	}
	log.Printf("<? %v\n", frame)

	var m RelayEventMessage
	rest, ok := m.ConsumeFrom(frame)
	switch {
	case !ok:
		log.Println("Event deserialization failed")
		return false
	case len(rest) != 0:
		log.Println("Event deserialization failed: trailing data")
		return false
	}

	payload, err := json.Marshal(&m)
	if err != nil {
		log.Println("Event marshalling failed: " + err.Error())
		return false
	}
	if err = ws.Write(ctx, websocket.MessageText, payload); err != nil {
		log.Println("Event send failed: " + err.Error())
		return false
	}
	log.Printf("<- %s\n", payload)
	return true
}
// errorToClient reports err to the WebSocket client as a JSON-encoded
// RelayEventError event. Both the event and the command sequence numbers are
// set to zero.
//
// A failure to encode or send the event is logged and reported by returning
// false.
func errorToClient(ctx context.Context, ws *websocket.Conn, err error) bool {
	event := RelayEventMessage{
		EventSeq: 0,
		Data: RelayEventData{
			Interface: RelayEventDataError{
				Event:      RelayEventError,
				CommandSeq: 0,
				Error:      err.Error(),
			},
		},
	}

	// Distinct names keep the reported error separate from local failures.
	payload, marshalErr := json.Marshal(&event)
	if marshalErr != nil {
		log.Println("Event marshalling failed: " + marshalErr.Error())
		return false
	}
	writeErr := ws.Write(ctx, websocket.MessageText, payload)
	if writeErr != nil {
		log.Println("Event send failed: " + writeErr.Error())
		return false
	}
	return true
}
// handleWS upgrades the request to a WebSocket and bridges it to a fresh TCP
// connection to the relay at addressConnect.
//
// Commands flow client→relay via clientToRelay, and events flow relay→client
// via relayToClient, each in its own goroutine. The handler returns once
// either direction fails or the request context is cancelled. If the relay
// cannot be dialed, the error is logged and reported to the client.
func handleWS(w http.ResponseWriter, r *http.Request) {
	ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
		// Origin verification is disabled. NOTE(review): presumably so that
		// the page may point at a WebSocket on another host (addressWS), but
		// this also admits cross-origin connections from any site.
		InsecureSkipVerify: true,
		CompressionMode:    websocket.CompressionContextTakeover,
		// This is for the payload, and happens to trigger on all messages.
		CompressionThreshold: 16,
	})
	if err != nil {
		log.Println("Client rejected: " + err.Error())
		return
	}
	defer ws.Close(websocket.StatusGoingAway, "Goodbye")

	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	conn, err := net.Dial("tcp", addressConnect)
	if err != nil {
		log.Println("Relay connection failed: " + err.Error())
		errorToClient(ctx, ws, err)
		return
	}
	// Closing the relay connection on exit releases the socket and also
	// unblocks relayToClient. Without this, relayToClient would otherwise sit
	// in a read on conn indefinitely after the client has gone away.
	defer conn.Close()

	// We don't need to intervene, so it's just two separate pipes so far.
	// Whichever direction stops first cancels ctx, which ends this handler.
	go func() {
		for clientToRelay(ctx, ws, conn) {
		}
		cancel()
	}()
	go func() {
		for relayToClient(ctx, ws, conn) {
		}
		cancel()
	}()
	<-ctx.Done()
}
var (
	// staticHandler serves files relative to the process's current working
	// directory. It is used for every request path other than "/".
	staticHandler = http.FileServer(http.Dir("."))

	// page is the HTML shell served at "/". Its single template argument is
	// the WebSocket URI, which html/template escapes for the JavaScript
	// string context. The value is assigned to the script variable `proxy`,
	// presumably read by xP.js (confirm there).
	page = template.Must(template.New("/").Parse(`<!DOCTYPE html>
<html>
<head>
<title>xP</title>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" href="xP.css" />
</head>
<body>
<script src="mithril.js">
</script>
<script>
let proxy = '{{ . }}'
</script>
<script src="xP.js">
</script>
</body>
</html>`))
)
func handleDefault(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
staticHandler.ServeHTTP(w, r)
return
}
wsURI := addressWS
if wsURI == "" {
wsURI = fmt.Sprintf("ws://%s/ws", r.Host)
}
if err := page.Execute(w, wsURI); err != nil {
log.Println("Template execution failed: " + err.Error())
}
}
// main expects the arguments BIND CONNECT [WSURI]. It stores them in the
// package address variables and registers the /ws and default handlers on
// the default mux. It then serves HTTP on BIND, exiting through log.Fatal
// when the server stops.
func main() {
	argc := len(os.Args)
	if argc < 3 || argc > 4 {
		log.Fatalf("usage: %s BIND CONNECT [WSURI]\n", os.Args[0])
	}
	addressBind, addressConnect = os.Args[1], os.Args[2]
	if argc == 4 {
		addressWS = os.Args[3]
	}

	http.HandleFunc("/ws", handleWS)
	http.HandleFunc("/", handleDefault)

	server := &http.Server{
		Addr:           addressBind,
		ReadTimeout:    60 * time.Second,
		WriteTimeout:   60 * time.Second,
		MaxHeaderBytes: 32 << 10,
	}
	err := server.ListenAndServe()
	log.Fatal(err)
}