haven/hnc/main.go

151 lines
3.2 KiB
Go

//
// Copyright (c) 2018, Přemysl Eric Janouch <p@janouch.name>
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
//
// hnc is a netcat-alike that shuts down properly.
package main
import (
"crypto/tls"
"flag"
"fmt"
"io"
"net"
"os"
)
// #include <unistd.h>
import "C"
func isatty(fd uintptr) bool { return C.isatty(C.int(fd)) != 0 }
// log formats a message in the manner of fmt.Sprintf, appends a newline,
// and writes it to standard error. When standard error is a terminal,
// the whole line (newline included) is wrapped in ANSI escape sequences
// that render it in bold red.
func log(format string, args ...interface{}) {
	const (
		highlight = "\x1b[0;1;31m" // reset attributes, then bold + red
		reset     = "\x1b[m"       // back to default attributes
	)

	line := fmt.Sprintf(format+"\n", args...)
	if isatty(os.Stderr.Fd()) {
		line = highlight + line + reset
	}
	// There is nowhere left to report a failed write to stderr,
	// so the result is deliberately not inspected.
	os.Stderr.WriteString(line)
}
// Command-line flags; their values are valid once main has called flag.Parse.
var (
// flagTLS makes dial connect using TLS rather than plain TCP.
flagTLS = flag.Bool("tls", false, "connect using TLS")
// flagCRLF makes expand put a CR in front of every LF sent to the remote end.
flagCRLF = flag.Bool("crlf", false, "translate LF into CRLF")
)
// connCloseWriter is a network connection that can shut down its write end.
// It is what dial returns; main calls CloseWrite once reading from standard
// input has ended, and keeps receiving from the connection afterwards.
type connCloseWriter interface {
net.Conn
// CloseWrite shuts down the writing side of the connection.
CloseWrite() error
}
// dial opens a TCP connection to address (in "host:port" form), wrapped in
// TLS when the -tls flag is set, and returns it as a connCloseWriter.
//
// On failure the returned connection is always a true nil interface value,
// so callers may safely compare it against nil.
func dial(address string) (connCloseWriter, error) {
	if *flagTLS {
		// NOTE(review): certificate verification is switched off, so the
		// peer is not authenticated. Presumably deliberate for a debugging
		// tool; confirm before relying on it for anything sensitive.
		secured, err := tls.Dial("tcp", address, &tls.Config{
			InsecureSkipVerify: true,
		})
		if err != nil {
			// Returning tls.Dial's results directly would wrap its nil
			// *tls.Conn in a non-nil interface value.
			return nil, err
		}
		return secured, nil
	}

	transport, err := net.Dial("tcp", address)
	if err != nil {
		return nil, err
	}
	conn, ok := transport.(connCloseWriter)
	if !ok {
		// Not expected for "tcp", but report it instead of panicking.
		// The connection is useless to us; closing it is best effort.
		transport.Close()
		return nil, fmt.Errorf("%T cannot shut down its write end", transport)
	}
	return conn, nil
}
// expand prepares data read from the user for sending to the remote end.
// Without the -crlf flag it returns raw itself, untouched. With the flag it
// returns a new slice in which every LF byte is preceded by a CR byte;
// no attempt is made to detect a CR that is already there.
func expand(raw []byte) []byte {
	if !*flagCRLF {
		return raw
	}

	var out []byte
	start := 0 // beginning of the run that has not been copied yet
	for i := range raw {
		if raw[i] != '\n' {
			continue
		}
		// Flush everything up to the line feed, then emit CR LF for it.
		out = append(out, raw[start:i]...)
		out = append(out, '\r', '\n')
		start = i + 1
	}
	return append(out, raw[start:]...)
}
// readResult is the asynchronously delivered result of one io.Reader.Read
// call, as sent over a channel by read.
type readResult struct {
b []byte // bytes obtained by the call; may be empty
err error // error returned by the call, nil if there was none
}
// read pumps r into ch: the outcome of every Read call, data and error
// together, is sent as one readResult. The first result carrying a non-nil
// error is the last one sent; ch is closed once the function returns.
// It blocks on each send, so it is meant to run in its own goroutine.
func read(r io.Reader, ch chan<- readResult) {
	defer close(ch)
	for {
		// A fresh buffer per iteration: the receiver may still be using
		// the slice that was handed over in the previous round.
		chunk := make([]byte, 8192)
		n, err := r.Read(chunk)
		ch <- readResult{b: chunk[:n], err: err}
		if err != nil {
			return
		}
	}
}
// main parses the command line, connects to HOST PORT, and then relays data
// both ways: standard input goes to the connection (through expand) and
// whatever the connection delivers goes to standard output.
//
// It exits with status 2 on a usage error and with status 1 when the
// connection cannot be established; otherwise it returns once both
// directions have ended.
func main() {
	flag.Usage = func() {
		out := flag.CommandLine.Output()
		fmt.Fprintf(out,
			"Usage: %s [OPTION]... HOST PORT\n"+
				"Connect to a remote host over TCP/IP.\n", os.Args[0])
		flag.PrintDefaults()
	}

	flag.Parse()
	if flag.NArg() != 2 {
		flag.Usage()
		os.Exit(2)
	}

	host, port := flag.Arg(0), flag.Arg(1)
	conn, err := dial(net.JoinHostPort(host, port))
	if err != nil {
		log("dial: %s", err)
		os.Exit(1)
	}

	// One reader goroutine per direction, each feeding its own channel.
	stdinCh := make(chan readResult)
	go read(os.Stdin, stdinCh)
	remoteCh := make(chan readResult)
	go read(conn, remoteCh)

	// A direction that has ended gets its channel set to nil; receiving
	// from a nil channel blocks forever, which takes that case out of the
	// select. The loop is over when both directions are done.
	for !(stdinCh == nil && remoteCh == nil) {
		select {
		case in := <-stdinCh:
			if data := in.b; len(data) > 0 {
				_, werr := conn.Write(expand(data))
				if werr != nil {
					log("remote: %s", werr)
				}
			}
			if in.err == nil {
				continue
			}
			log("stdin: %s", in.err)
			stdinCh = nil
			// Nothing more to send: shut down our write end only, so the
			// remote side's remaining output can still be received.
			if cerr := conn.CloseWrite(); cerr != nil {
				log("remote: %s", cerr)
			}

		case in := <-remoteCh:
			if data := in.b; len(data) > 0 {
				_, werr := os.Stdout.Write(data)
				if werr != nil {
					log("stdout: %s", werr)
				}
			}
			if in.err == nil {
				continue
			}
			log("remote: %s", in.err)
			remoteCh = nil
		}
	}
}