//
// Copyright (c) 2018, Přemysl Eric Janouch <p@janouch.name>
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
//

// hnc is a netcat-alike that shuts down properly.
package main

import (
	"bytes"
	"crypto/tls"
	"flag"
	"fmt"
	"io"
	"net"
	"os"
)

// #include <unistd.h>
import "C"

// isatty reports whether the file descriptor fd refers to a terminal,
// as answered by the C library's isatty(3).
func isatty(fd uintptr) bool {
	result := C.isatty(C.int(fd))
	return result != 0
}

// log formats a message in the manner of fmt.Sprintf, appends a newline and
// writes it to standard error. When standard error is a terminal, the message
// is wrapped in ANSI escape sequences that render it in bold red and then
// reset the attributes. Write errors are ignored.
func log(format string, args ...interface{}) {
	text := fmt.Sprintf(format+"\n", args...)
	if !isatty(os.Stderr.Fd()) {
		os.Stderr.WriteString(text)
		return
	}
	os.Stderr.WriteString("\x1b[0;1;31m" + text + "\x1b[m")
}

// Command-line options, registered on flag.CommandLine during package
// initialisation.
var (
	// flagTLS makes the connection use TLS instead of plain TCP.
	flagTLS = flag.Bool("tls", false, "connect using TLS")

	// flagCRLF makes every LF sent to the remote side go out as CRLF.
	flagCRLF = flag.Bool("crlf", false, "translate LF into CRLF")
)

// connCloseWriter is a network connection whose write half can be shut down
// on its own, leaving the read half open.
type connCloseWriter interface {
	net.Conn
	CloseWrite() error
}

// dial connects to address (in the host:port form accepted by net.Dial) over
// TCP, or over TLS when the -tls flag is set.
//
// NOTE(review): with -tls the server certificate is NOT verified
// (InsecureSkipVerify), so the connection is open to man-in-the-middle
// attacks. This is kept deliberately for a netcat-like debugging tool.
//
// On error the returned connection is always a nil interface value.
func dial(address string) (connCloseWriter, error) {
	if *flagTLS {
		conn, err := tls.Dial("tcp", address, &tls.Config{
			InsecureSkipVerify: true,
		})
		if err != nil {
			// Returning tls.Dial's result directly would wrap a nil
			// *tls.Conn in a non-nil interface; return a literal nil.
			return nil, err
		}
		return conn, nil
	}

	transport, err := net.Dial("tcp", address)
	if err != nil {
		return nil, err
	}
	conn, ok := transport.(connCloseWriter)
	if !ok {
		// Best effort: the connection is unusable for our purposes anyway.
		_ = transport.Close()
		return nil, fmt.Errorf("%T cannot shut down its write side", transport)
	}
	return conn, nil
}

// expand returns raw with every LF byte replaced by a CR LF pair when the
// -crlf flag is set, and returns raw unchanged otherwise.
//
// Each call sees a single chunk in isolation and does not look for CRs that
// are already present, so input that already uses CRLF comes out as CR CR LF.
func expand(raw []byte) []byte {
	if !*flagCRLF {
		return raw
	}
	// bytes.Replace counts matches first and allocates the result once,
	// instead of growing a slice through repeated appends.
	return bytes.Replace(raw, []byte("\n"), []byte("\r\n"), -1)
}

// readResult carries the outcome of one io.Reader.Read call across a channel.
type readResult struct {
	b   []byte // bytes returned by Read; may be empty
	err error  // error returned by the same Read call, if any
}

// read repeatedly reads from r in chunks of up to 8192 bytes and sends each
// chunk, together with the error from the same call, on ch. It stops after
// the first non-nil error (which is still delivered) and then closes ch.
// Every chunk gets a fresh buffer, so the receiver may keep the slice.
func read(r io.Reader, ch chan<- readResult) {
	defer close(ch)
	for {
		chunk := make([]byte, 8192)
		n, err := r.Read(chunk)
		ch <- readResult{b: chunk[:n], err: err}
		if err != nil {
			return
		}
	}
}

func main() {
	flag.Usage = func() {
		fmt.Fprintf(flag.CommandLine.Output(),
			"Usage: %s [OPTION]... HOST PORT\n"+
				"Connect to a remote host over TCP/IP.\n", os.Args[0])
		flag.PrintDefaults()
	}

	flag.Parse()
	if flag.NArg() != 2 {
		flag.Usage()
		os.Exit(2)
	}

	conn, err := dial(net.JoinHostPort(flag.Arg(0), flag.Arg(1)))
	if err != nil {
		log("dial: %s", err)
		os.Exit(1)
	}

	fromUser := make(chan readResult)
	go read(os.Stdin, fromUser)

	fromConn := make(chan readResult)
	go read(conn, fromConn)

	for fromUser != nil || fromConn != nil {
		select {
		case result := <-fromUser:
			if len(result.b) > 0 {
				if _, err := conn.Write(expand(result.b)); err != nil {
					log("remote: %s", err)
				}
			}
			if result.err != nil {
				log("stdin: %s", result.err)
				fromUser = nil
				if err := conn.CloseWrite(); err != nil {
					log("remote: %s", err)
				}
			}
		case result := <-fromConn:
			if len(result.b) > 0 {
				if _, err := os.Stdout.Write(result.b); err != nil {
					log("stdout: %s", err)
				}
			}
			if result.err != nil {
				log("remote: %s", result.err)
				fromConn = nil
			}
		}
	}
}