218 lines
4.7 KiB
Go
218 lines
4.7 KiB
Go
// Copyright (c) 2022, Přemysl Eric Janouch <p@janouch.name>
|
|
// SPDX-License-Identifier: 0BSD
|
|
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"html/template"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
"nhooyr.io/websocket"
|
|
)
|
|
|
|
// addressBind is the local address the HTTP server listens on; main fills it
// from the BIND command-line argument.
var addressBind string

// addressConnect is the TCP address that handleWS dials for every WebSocket
// client; main fills it from the CONNECT command-line argument.
var addressConnect string

// addressWS is the optional WSURI command-line argument. When non-empty it is
// the WebSocket URI embedded in the served page; when empty, handleDefault
// derives one from the request's Host header instead.
var addressWS string
|
|
|
|
// clientToRelay forwards a single command from the WebSocket client to the
// relay. It reads one text message from ws, decodes it from JSON into a
// RelayCommandMessage, serializes that behind a 4-byte big-endian length
// prefix, and writes the resulting frame to conn.
//
// It reports whether the whole exchange succeeded. Every failure is logged,
// and the caller is expected to stop forwarding once it returns false.
func clientToRelay(ctx context.Context, ws *websocket.Conn, conn net.Conn) bool {
	msgType, payload, err := ws.Read(ctx)
	switch {
	case err != nil:
		log.Println("Command receive failed: " + err.Error())
		return false
	case msgType != websocket.MessageText:
		log.Println("Command receive failed: " +
			"binary messages are not supported")
		return false
	}
	log.Printf("?> %s\n", payload)

	var cmd RelayCommandMessage
	if err := json.Unmarshal(payload, &cmd); err != nil {
		log.Println("Command unmarshalling failed: " + err.Error())
		return false
	}

	// Serialize after 4 reserved bytes, then back-fill them with the length
	// of everything that follows.
	frame, ok := cmd.AppendTo(make([]byte, 4))
	if !ok {
		log.Println("Command serialization failed")
		return false
	}
	binary.BigEndian.PutUint32(frame[:4], uint32(len(frame)-4))

	if _, err := conn.Write(frame); err != nil {
		log.Println("Command send failed: " + err.Error())
		return false
	}
	log.Printf("-> %v\n", frame)
	return true
}
|
|
|
|
// relayToClient forwards a single event from the relay to the WebSocket
// client. It reads a 4-byte big-endian length and then exactly that many bytes
// from conn. It decodes them into a RelayEventMessage, rejecting any bytes left
// over, and sends the message to ws as a JSON text message.
//
// It reports whether the whole exchange succeeded. Every failure is logged,
// and the caller is expected to stop forwarding once it returns false.
func relayToClient(ctx context.Context, ws *websocket.Conn, conn net.Conn) bool {
	var size uint32
	if err := binary.Read(conn, binary.BigEndian, &size); err != nil {
		log.Println("Event receive failed: " + err.Error())
		return false
	}
	// NOTE(review): size comes straight from the relay and is not bounded
	// before allocating; acceptable only as long as the relay is trusted.
	frame := make([]byte, size)
	if _, err := io.ReadFull(conn, frame); err != nil {
		log.Println("Event receive failed: " + err.Error())
		return false
	}
	log.Printf("<? %v\n", frame)

	var event RelayEventMessage
	rest, ok := event.ConsumeFrom(frame)
	switch {
	case !ok:
		log.Println("Event deserialization failed")
		return false
	case len(rest) != 0:
		log.Println("Event deserialization failed: trailing data")
		return false
	}

	encoded, err := json.Marshal(&event)
	if err != nil {
		log.Println("Event marshalling failed: " + err.Error())
		return false
	}
	if err := ws.Write(ctx, websocket.MessageText, encoded); err != nil {
		log.Println("Event send failed: " + err.Error())
		return false
	}
	log.Printf("<- %s\n", encoded)
	return true
}
|
|
|
|
// errorToClient sends err to the WebSocket client as a RelayEventError event.
// Both the event and the command sequence numbers are set to zero. The event
// is encoded as a JSON text message. It reports whether the message was sent;
// failures are logged.
//
// err must be non-nil: its Error method is called unconditionally.
func errorToClient(ctx context.Context, ws *websocket.Conn, err error) bool {
	event := &RelayEventMessage{
		EventSeq: 0,
		Data: RelayEventData{
			Interface: RelayEventDataError{
				Event:      RelayEventError,
				CommandSeq: 0,
				Error:      err.Error(),
			},
		},
	}

	// Distinct names keep the reported error apart from local failures.
	encoded, marshalErr := json.Marshal(event)
	if marshalErr != nil {
		log.Println("Event marshalling failed: " + marshalErr.Error())
		return false
	}
	if sendErr := ws.Write(ctx, websocket.MessageText, encoded); sendErr != nil {
		log.Println("Event send failed: " + sendErr.Error())
		return false
	}
	return true
}
|
|
|
|
// handleWS upgrades the request to a WebSocket and bridges it to a fresh TCP
// connection to the relay at addressConnect. Text messages from the client are
// forwarded to the relay by clientToRelay, and relay events are forwarded back
// by relayToClient. Each direction runs in its own goroutine.
//
// The handler returns as soon as either direction fails or the request context
// ends. The deferred calls then close the relay connection, cancel the context
// and close the WebSocket, which unblocks whichever goroutine is still running.
func handleWS(w http.ResponseWriter, r *http.Request) {
	ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
		InsecureSkipVerify: true,
		CompressionMode:    websocket.CompressionContextTakeover,
		// This is for the payload, and happens to trigger on all messages.
		CompressionThreshold: 16,
	})
	if err != nil {
		log.Println("Client rejected: " + err.Error())
		return
	}
	defer ws.Close(websocket.StatusGoingAway, "Goodbye")

	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	conn, err := net.Dial("tcp", addressConnect)
	if err != nil {
		log.Println("Relay connection failed: " + err.Error())
		errorToClient(ctx, ws, err)
		return
	}
	// Closing conn is what unblocks relayToClient's pending read once the
	// client side is gone; without it, both the goroutine and the socket leak
	// for every session.
	defer conn.Close()

	// We don't need to intervene, so it's just two separate pipes so far.
	go func() {
		for clientToRelay(ctx, ws, conn) {
		}
		cancel()
	}()
	go func() {
		for relayToClient(ctx, ws, conn) {
		}
		cancel()
	}()
	<-ctx.Done()
}
|
|
|
|
// staticHandler serves files from the process's working directory. It answers
// every request that handleDefault does not handle itself, i.e. any path other
// than "/".
var staticHandler http.Handler = http.FileServer(http.Dir("."))
|
|
|
|
// page is the HTML shell of the web client, served at "/" by handleDefault.
// It references xP.css, mithril.js and xP.js by relative path, so those
// requests fall through to staticHandler. The template is executed with the
// WebSocket URI. html/template escapes that value for the JavaScript string
// literal it is placed in and binds it to `proxy` (presumably read by xP.js;
// confirm there).
var page = template.Must(template.New("/").Parse(`<!DOCTYPE html>
|
|
<html>
|
|
<head>
|
|
<title>xP</title>
|
|
<meta charset="utf-8" />
|
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
|
<link rel="stylesheet" href="xP.css" />
|
|
</head>
|
|
<body>
|
|
<script src="mithril.js">
|
|
</script>
|
|
<script>
|
|
let proxy = '{{ . }}'
|
|
</script>
|
|
<script src="xP.js">
|
|
</script>
|
|
</body>
|
|
</html>`))
|
|
|
|
// handleDefault renders the application page for "/" and delegates every
// other path to staticHandler. The page is given the WebSocket URI to connect
// to. That is addressWS when it was supplied on the command line; otherwise
// it is built from the request's Host header as ws://<host>/ws.
func handleDefault(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path == "/" {
		uri := addressWS
		if len(uri) == 0 {
			uri = fmt.Sprintf("ws://%s/ws", r.Host)
		}
		if err := page.Execute(w, uri); err != nil {
			log.Println("Template execution failed: " + err.Error())
		}
		return
	}
	staticHandler.ServeHTTP(w, r)
}
|
|
|
|
// main expects the arguments BIND CONNECT [WSURI] and exits with a usage
// message otherwise. It registers handleWS at /ws and handleDefault at /, then
// serves HTTP on BIND until the server fails.
func main() {
	if argc := len(os.Args); argc < 3 || argc > 4 {
		log.Fatalf("usage: %s BIND CONNECT [WSURI]\n", os.Args[0])
	}

	addressBind = os.Args[1]
	addressConnect = os.Args[2]
	if len(os.Args) == 4 {
		addressWS = os.Args[3]
	}

	http.HandleFunc("/ws", handleWS)
	http.HandleFunc("/", handleDefault)

	server := &http.Server{
		Addr:           addressBind,
		ReadTimeout:    time.Minute,
		WriteTimeout:   time.Minute,
		MaxHeaderBytes: 32 * 1024,
	}
	log.Fatal(server.ListenAndServe())
}
|