diff --git a/hid/main_test.go b/hid/main_test.go new file mode 100644 index 0000000..90a55f6 --- /dev/null +++ b/hid/main_test.go @@ -0,0 +1,139 @@ +// +// Copyright (c) 2015 - 2018, Přemysl Janouch +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY +// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION +// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +// + +package main + +import ( + "crypto/tls" + "net" + "os" + "reflect" + "syscall" + "testing" +) + +func TestSplitString(t *testing.T) { + var splitStringTests = []struct { + s, delims string + ignoreEmpty bool + result []string + }{ + {",a,,bc", ",", false, []string{"", "a", "", "bc"}}, + {",a,,bc", ",", true, []string{"a", "bc"}}, + {"a,;bc,", ",;", false, []string{"a", "", "bc", ""}}, + {"a,;bc,", ",;", true, []string{"a", "bc"}}, + {"", ",", false, []string{""}}, + {"", ",", true, nil}, + } + + for i, d := range splitStringTests { + got := splitString(d.s, d.delims, d.ignoreEmpty) + if !reflect.DeepEqual(got, d.result) { + t.Errorf("case %d: %v should be %v\n", i, got, d.result) + } + } +} + +func socketpair() (*os.File, *os.File, error) { + pair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + return nil, nil, err + } + + // See go #24331, this makes 1.11 use the internal poller + // while there wasn't a way to achieve that before. + if err := syscall.SetNonblock(int(pair[0]), true); err != nil { + return nil, nil, err + } + if err := syscall.SetNonblock(int(pair[1]), true); err != nil { + return nil, nil, err + } + + fa := os.NewFile(uintptr(pair[0]), "a") + if fa == nil { + return nil, nil, os.ErrInvalid + } + + fb := os.NewFile(uintptr(pair[1]), "b") + if fb == nil { + fa.Close() + return nil, nil, os.ErrInvalid + } + + return fa, fb, nil +} + +func TestDetectTLS(t *testing.T) { + detectTLSFromFunc := func(t *testing.T, writer func(net.Conn)) bool { + // net.Pipe doesn't use file descriptors, we need a socketpair. + sockA, sockB, err := socketpair() + if err != nil { + t.Fatal(err) + } + defer sockA.Close() + defer sockB.Close() + + fcB, err := net.FileConn(sockB) + if err != nil { + t.Fatal(err) + } + go writer(fcB) + + fcA, err := net.FileConn(sockA) + if err != nil { + t.Fatal(err) + } + sc, err := fcA.(syscall.Conn).SyscallConn() + if err != nil { + t.Fatal(err) + } + return detectTLS(sc) + } + + t.Run("SSL_2.0", func(t *testing.T) { + if !detectTLSFromFunc(t, func(fc net.Conn) { + // The obsolete, useless, unsupported SSL 2.0 record format. + _, _ = fc.Write([]byte{0x80, 0x01, 0x01}) + }) { + t.Error("could not detect SSL") + } + }) + t.Run("crypto_tls", func(t *testing.T) { + if !detectTLSFromFunc(t, func(fc net.Conn) { + conn := tls.Client(fc, &tls.Config{InsecureSkipVerify: true}) + _ = conn.Handshake() + }) { + t.Error("could not detect TLS") + } + }) + t.Run("text", func(t *testing.T) { + if detectTLSFromFunc(t, func(fc net.Conn) { + _, _ = fc.Write([]byte("ПРЕВЕД")) + }) { + t.Error("detected UTF-8 as TLS") + } + }) + t.Run("EOF", func(t *testing.T) { + type connCloseWriter interface { + net.Conn + CloseWrite() error + } + if detectTLSFromFunc(t, func(fc net.Conn) { + _ = fc.(connCloseWriter).CloseWrite() + }) { + t.Error("detected EOF as TLS") + } + }) +}