diff --git a/README.adoc b/README.adoc index 0d3d0de..5a87720 100644 --- a/README.adoc +++ b/README.adoc @@ -68,6 +68,9 @@ lxdrgen-mjs.awk:: lxdrgen-swift.awk:: LibertyXDR backend for the Swift programming language. +wdye:: + Compiled Lua-based Expect-like utility, intended purely for build checks. + Contributing and Support ------------------------ Use https://git.janouch.name/p/liberty to report any bugs, request features, diff --git a/tools/wdye/CMakeLists.txt b/tools/wdye/CMakeLists.txt new file mode 100644 index 0000000..9044078 --- /dev/null +++ b/tools/wdye/CMakeLists.txt @@ -0,0 +1,39 @@ +cmake_minimum_required (VERSION 3.18) +project (wdye VERSION 1 DESCRIPTION "What did you expect?" LANGUAGES C) + +set (CMAKE_C_STANDARD 99) +set (CMAKE_C_STANDARD_REQUIRED ON) +set (CMAKE_C_EXTENSIONS OFF) + +# -Wunused-function is pretty annoying here, as everything is static +set (options -Wall -Wextra -Wno-unused-function) +add_compile_options ("$<$:${options}>") +add_compile_options ("$<$:${options}>") + +set (CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/../../cmake") + +find_package (Curses) +find_package (PkgConfig REQUIRED) +pkg_search_module (lua REQUIRED + lua53 lua5.3 lua-5.3 lua54 lua5.4 lua-5.4 lua>=5.3) + +option (WITH_CURSES "Offer terminal sequences using Curses" "${CURSES_FOUND}") + +include_directories ("${PROJECT_BINARY_DIR}") +file (CONFIGURE OUTPUT "${PROJECT_BINARY_DIR}/config.h" CONTENT [[ +#define PROGRAM_NAME "${PROJECT_NAME}" +#define PROGRAM_VERSION "${PROJECT_VERSION}" +#cmakedefine WITH_CURSES +]]) + +add_executable (wdye wdye.c) +target_include_directories (wdye PUBLIC ${lua_INCLUDE_DIRS}) +target_link_directories (wdye PUBLIC ${lua_LIBRARY_DIRS}) +target_link_libraries (wdye PUBLIC ${lua_LIBRARIES}) +if (WITH_CURSES) + target_include_directories (wdye PUBLIC ${CURSES_INCLUDE_DIRS}) + target_link_libraries (wdye PUBLIC ${CURSES_LIBRARIES}) +endif () + +add_test (NAME simple COMMAND wdye "${PROJECT_SOURCE_DIR}/test.lua") +include (CTest) diff --git a/tools/wdye/test.lua b/tools/wdye/test.lua new file mode 100644 index 0000000..5e2fe36 --- /dev/null +++ b/tools/wdye/test.lua @@ -0,0 +1,24 @@ +for k, v in pairs(wdye) do _G[k] = v end + +-- The terminal echoes back, we don't want to read the same stuff twice. +local cat = spawn {"sh", "-c", "cat > /dev/null", environ={TERM="xterm"}} +assert(cat, "failed to spawn process") +assert(cat.term.key_left, "bad terminfo") + +cat:send("Hello\r") +local m = expect(cat:exact {"Hello\r", function (p) return p[0] end}) +assert(m == "Hello\r", "exact match failed, or value expansion mismatch") + +local t = table.pack(expect(timeout {.5, 42})) +assert(#t == 1 and t[1] == 42, "timeout match failed, or value mismatch") + +cat:send("abc123\r") +expect(cat:regex {"A(.*)3", nocase=true, function (p) + assert(p[0] == "abc123", "wrong regex group #0") + assert(p[1] == "bc12", "wrong regex group #1") +end}) + +-- Send EOF (^D), test method chaining. +cat:send("Closing...\r"):send("\004") +local v = expect(cat:eof {true}, + cat:default {.5, function (p) error "expected EOF, got a timeout" end}) diff --git a/tools/wdye/wdye.adoc b/tools/wdye/wdye.adoc new file mode 100644 index 0000000..6282af2 --- /dev/null +++ b/tools/wdye/wdye.adoc @@ -0,0 +1,126 @@ +wdye(1) +======= +:doctype: manpage +:manmanual: wdye Manual +:mansource: wdye {release-version} + +Name +---- +wdye - what did you expect: Lua-based Expect tool + +Synopsis +-------- +*wdye* _program.lua_ + +Description +----------- +*wdye* executes a Lua script, providing an *expect*(1)-like API targeted +at application testing. + +API +--- +This list is logically ordered. Uppercase names represent object types. + +wdye.spawn {file [, arg1, ...] [, environ=env]} +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Creates a new pseudoterminal, spawns the given program in it, +and returns a _process_ object. When *file* doesn't contain slashes, +the program will be searched for in _PATH_. + +The *env* map may be used to override environment variables, notably _TERM_. +Variables evaluating to _false_ will be removed from the environment. + +The program's whole process group receives SIGKILL when the _process_ +is garbage-collected. + +wdye.expect ([pattern1, ...]) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Waits until any pattern is ready, in order. +When no *timeout* (or *default*) patterns are included, one is added implicitly. + +The function returns the matching _pattern_'s values, while replacing +any included functions with the results of their immediate evaluation, +passing the matching _pattern_ as their sole argument. + +wdye.timeout {[timeout, ] [value1, ...]} +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Returns a new timeout _pattern_. When no *timeout* is given, which is specified +in seconds, a default timeout value is assumed. Any further values +are remembered to be later processed by *expect*. + +wdye.continue () +~~~~~~~~~~~~~~~~ +Raises a _nil_ error, which is interpreted by *expect* as a signal to restart +all processing. + +PROCESS.buffer +~~~~~~~~~~~~~~ +A string with the _process_' current read buffer contents. + +PROCESS.term +~~~~~~~~~~~~ +A table with the _process_' *terminfo*(5) capabilities, +notably containing all **key_...** codes. +This functionality may not be enabled, then this table will always be empty. + +PROCESS:send ([string, ...]) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Writes the given strings to the _process_' terminal slave, +and returns the _process_ for method chaining. + +Beware of echoing and deadlocks, as only *expect* can read from the _process_, +and thus consume the terminal slave's output queue. + +PROCESS:regex {pattern [, nocase=true] [, notransfer=true] [, value1, ...]} +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Returns a new regular expression _pattern_. The *pattern* is a POSIX +Extended Regular Expression. Whether it can match NUL bytes depends on your +system C library. + +When the *nocase* option is _true_, the expression will be matched +case-insensitively. + +Unless the *notransfer* option is _true_, all data up until the end of the match +will be erased from the _process_' read buffer upon a successful match. + +PROCESS:exact {literal [, nocase=true] [, notransfer=true] [, value1, ...]} +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Returns a new literal string _pattern_. This behaves as if the *literal* +had its ERE special characters quoted, and was then passed to *regex*. +This _pattern_ can always match NUL bytes. + +PROCESS:eof {[notransfer=true, ] [value1, ...]} +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Returns a new end-of-file _pattern_, which matches the entire read buffer +contents once the child process closes the terminal. + +PROCESS:default {[timeout, ] [notransfer=true, ] [value1, ...]} +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Returns a new _pattern_ combining *wdye.timeout* with *eof*. + +PATTERN.process +~~~~~~~~~~~~~~~ +A reference to the _pattern_'s respective process, or _nil_. + +PATTERN[group] +~~~~~~~~~~~~~~ +For patterns that can match data, the zeroth group will be the whole matched +input sequence. +For *regex* patterns, positive groups relate to regular expression subgroups. +Missing groups evaluate to _nil_. + +Example +------- + for k, v in pairs(wdye) do _G[k] = v end + local rot13 = spawn {"tr", "A-Za-z", "N-ZA-Mn-za-m", environ={TERM="dumb"}} + rot13:send "Hello\r" + expect(rot13:exact {"Uryyb\r"}) + +Reporting bugs +-------------- +Use https://git.janouch.name/p/liberty to report bugs, request features, +or submit pull requests. + +See also +-------- +*expect*(1), *terminfo*(5), *regex*(7) diff --git a/tools/wdye/wdye.c b/tools/wdye/wdye.c new file mode 100644 index 0000000..6ac5c0b --- /dev/null +++ b/tools/wdye/wdye.c @@ -0,0 +1,1231 @@ +/* + * wdye.c: what did you expect: Lua-based Expect tool + * + * Copyright (c) 2025, Přemysl Eric Janouch + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION + * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN + * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + */ + +#include "config.h" +#include "../../liberty.c" + +#include + +#include +#include +#include + +#include +#include +#ifdef SOLARIS +#include +#endif + +#ifdef WITH_CURSES +#include +#include +#endif + +static int64_t +clock_msec (void) +{ +#ifdef _POSIX_TIMERS + struct timespec tp; + hard_assert (clock_gettime (CLOCK_BEST, &tp) != -1); + return (int64_t) tp.tv_sec * 1000 + tp.tv_nsec / 1000000; +#else + struct timeval tp; + hard_assert (gettimeofday (&tp, NULL) != -1); + return (int64_t) tp.tv_sec * 1000 + *msec = tp.tv_usec / 1000; +#endif +} + +// execvpe is a GNU extension, reimplement it. +static int +execvpe (const char *file, char *const argv[], char *const envp[]) +{ + const char *path = getenv ("PATH"); + if (strchr (file, '/') || !path) + return execve (file, argv, envp); + + struct strv dirs = strv_make (); + cstr_split (path, ":", false, &dirs); + char *name = NULL; + for (size_t i = 0; i < dirs.len; i++) + { + cstr_set (&name, xstrdup_printf ("%s/%s", + *dirs.vector[i] ? dirs.vector[i] : ".", file)); + execve (name, argv, envp); + } + strv_free (&dirs); + return -1; +} + +// This is a particularly inefficient algorithm, but it can match binary data. +static const char * +memmem (const struct str *haystack, const struct str *needle, bool nocase) +{ + if (haystack->len < needle->len) + return NULL; + + char *xhaystack = xmalloc (haystack->len + 1); + char *xneedle = xmalloc (needle->len + 1); + if (nocase) + { + for (size_t i = 0; i <= haystack->len; i++) + xhaystack[i] = tolower ((uint8_t) haystack->str[i]); + for (size_t i = 0; i <= needle->len; i++) + xneedle[i] = tolower ((uint8_t) needle->str[i]); + } + else + { + memcpy (xhaystack, haystack->str, haystack->len + 1); + memcpy (xneedle, needle->str, needle->len + 1); + } + + const char *result = NULL; + for (size_t i = 0, end = haystack->len - needle->len; i <= end; i++) + if (!memcmp (xhaystack + i, xneedle, needle->len)) + { + result = haystack->str + i; + break; + } + + free (xhaystack); + free (xneedle); + return result; +} + +// --- Pseudoterminal ---------------------------------------------------------- +// This is largely taken from Advanced Programming in the UNIX® Environment, +// just without a bunch of bugs. + +static int +ptym_open (char **pts_name) +{ + int fdm = -1, err = 0; + if ((fdm = posix_openpt (O_RDWR | O_NOCTTY)) < 0) + return -1; + if (grantpt (fdm) < 0 + || unlockpt (fdm) < 0) + goto errout; + + char *ptr = NULL; + if ((ptr = ptsname (fdm)) == NULL) + goto errout; + + cstr_set (pts_name, xstrdup (ptr)); + return fdm; + +errout: + err = errno; + xclose (fdm); + errno = err; + return -1; +} + +static int +ptys_open (const char *pts_name) +{ + int fds = -1; +#ifdef SOLARIS + int err = 0, setup = 0; +#endif + if ((fds = open (pts_name, O_RDWR)) < 0) + return -1; +#ifdef SOLARIS + if ((setup = ioctl (fds, I_FIND, "ldterm")) < 0) + goto errout; + if (setup == 0) + { + if (ioctl (fds, I_PUSH, "ptem") < 0 + || ioctl (fds, I_PUSH, "ldterm") < 0) + goto errout; + + if (ioctl (fds, I_PUSH, "ttcompat") < 0) + { +errout: + err = errno; + xclose (fds); + errno = err; + return -1; + } + } +#endif + return fds; +} + +static pid_t +pty_fork (int *ptrfdm, char **slave_name, + const struct termios *slave_termios, const struct winsize *slave_winsize, + struct error **e) +{ + int fdm = -1, fds = -1; + + char *pts_name = NULL; + if ((fdm = ptym_open (&pts_name)) < 0) + { + error_set (e, "can’t open master pty: %s", strerror (errno)); + return -1; + } + if (slave_name != NULL) + cstr_set (slave_name, xstrdup (pts_name)); + + pid_t pid = fork (); + if (pid < 0) + { + error_set (e, "fork: %s", strerror (errno)); + xclose (fdm); + } + else if (pid != 0) + *ptrfdm = fdm; + else + { + if (setsid () < 0) + exit_fatal ("setsid: %s", strerror (errno)); + if ((fds = ptys_open (pts_name)) < 0) + exit_fatal ("can’t open slave pty: %s", strerror (errno)); + xclose (fdm); + +#if defined BSD + if (ioctl (fds, TIOCSCTTY, (char *) 0) < 0) + exit_fatal ("TIOCSCTTY: %s", strerror (errno)); +#endif + + if (slave_termios != NULL + && tcsetattr (fds, TCSANOW, slave_termios) < 0) + exit_fatal ("tcsetattr error on slave pty: %s", strerror (errno)); + if (slave_winsize != NULL + && ioctl (fds, TIOCSWINSZ, slave_winsize) < 0) + exit_fatal ("TIOCSWINSZ error on slave pty: %s", strerror (errno)); + + if (dup2 (fds, STDIN_FILENO) != STDIN_FILENO) + exit_fatal ("dup2 error to stdin"); + if (dup2 (fds, STDOUT_FILENO) != STDOUT_FILENO) + exit_fatal ("dup2 error to stdout"); + if (dup2 (fds, STDERR_FILENO) != STDERR_FILENO) + exit_fatal ("dup2 error to stderr"); + if (fds != STDIN_FILENO && fds != STDOUT_FILENO && fds != STDERR_FILENO) + xclose (fds); + } + free (pts_name); + return pid; +} + +// --- Global state ------------------------------------------------------------ + +static struct +{ + lua_State *L; ///< Lua state + lua_Number default_timeout; ///< Default expect timeout (s) +} +g = +{ + .default_timeout = 10., +}; + +static int +xlua_error_handler (lua_State *L) +{ + // Don't add tracebacks when there's already one, and pass nil through. + const char *string = luaL_optstring (L, 1, NULL); + if (string && !strchr (string, '\n')) + { + luaL_traceback (L, L, string, 1); + lua_remove (L, 1); + } + return 1; +} + +static bool +xlua_getfield (lua_State *L, int idx, const char *name, + int expected, bool optional) +{ + int found = lua_getfield (L, idx, name); + if (found == expected) + return true; + if (optional && found == LUA_TNIL) + return false; + + const char *message = optional + ? "invalid field \"%s\" (found: %s, expected: %s or nil)" + : "invalid or missing field \"%s\" (found: %s, expected: %s)"; + return luaL_error (L, message, name, + lua_typename (L, found), lua_typename (L, expected)); +} + +static void +xlua_newtablecopy (lua_State *L, int idx, int first, int last) +{ + int len = last - first + 1; + lua_createtable (L, len, 0); + if (idx < 0) + idx--; + + for (lua_Integer i = 0; i < len; i++) + { + lua_rawgeti (L, idx, first + i); + lua_rawseti (L, -2, 1 + i); + } +} + +// --- Patterns ---------------------------------------------------------------- + +#define XLUA_PATTERN_METATABLE "pattern" + +enum pattern_kind +{ + PATTERN_REGEX, ///< Regular expression match + PATTERN_EXACT, ///< Literal string match + PATTERN_TIMEOUT, ///< Timeout + PATTERN_EOF, ///< EOF condition + PATTERN_DEFAULT, ///< Either timeout or EOF condition +}; + +struct pattern +{ + enum pattern_kind kind; ///< Tag + int ref_process; ///< Process for all except TIMEOUT + struct process *process; ///< Weak pointer to the process + regex_t *regex; ///< Regular expression for REGEX + struct str exact; ///< Exact match literal for EXACT + lua_Number timeout; ///< Timeout for TIMEOUT/DEFAULT (s) + bool nocase; ///< Case insensitive search + bool notransfer; ///< Do not consume process buffer + int ref_values; ///< Return values as a table reference + + // Patterns are constructed in place, used once, and forgotten, + // so we can just shove anything extra in here. + + struct error *e; ///< Error buffer + struct str input; ///< Matched input + regmatch_t *matches; ///< Match indexes within the input + bool eof; ///< End of file seen +}; + +static struct pattern * +pattern_new (lua_State *L, enum pattern_kind kind, int idx_process) +{ + struct pattern *self = lua_newuserdata (L, sizeof *self); + luaL_setmetatable (L, XLUA_PATTERN_METATABLE); + memset (self, 0, sizeof *self); + + self->kind = kind; + self->ref_process = LUA_NOREF; + self->exact = str_make (); + self->timeout = -1.; + self->ref_values = LUA_NOREF; + self->input = str_make (); + + if (idx_process) + { + lua_pushvalue (L, idx_process); + self->process = lua_touserdata (L, -1); + self->ref_process = luaL_ref (L, LUA_REGISTRYINDEX); + } + return self; +} + +static int +xlua_pattern_gc (lua_State *L) +{ + struct pattern *self = luaL_checkudata (L, 1, XLUA_PATTERN_METATABLE); + luaL_unref (L, LUA_REGISTRYINDEX, self->ref_process); + if (self->regex) + regex_free (self->regex); + str_free (&self->exact); + luaL_unref (L, LUA_REGISTRYINDEX, self->ref_values); + if (self->e) + error_free (self->e); + str_free (&self->input); + free (self->matches); + return 0; +} + +static int +xlua_pattern_index (lua_State *L) +{ + struct pattern *self = luaL_checkudata (L, 1, XLUA_PATTERN_METATABLE); + if (!lua_isinteger (L, 2)) + { + const char *key = luaL_checkstring (L, 2); + if (!strcmp (key, "process")) + lua_rawgeti (L, LUA_REGISTRYINDEX, self->ref_process); + else + return luaL_argerror (L, 2, "not a readable property"); + return 1; + } + + lua_Integer group = lua_tointeger (L, 2); + switch (self->kind) + { + case PATTERN_REGEX: + { + const regmatch_t *m = self->matches + group; + if (group < 0 || group > self->regex->re_nsub + || m->rm_so < 0 || m->rm_eo < 0 || m->rm_eo > self->input.len) + lua_pushnil (L); + else + lua_pushlstring (L, + self->input.str + m->rm_so, m->rm_eo - m->rm_so); + return 1; + } + case PATTERN_EXACT: + case PATTERN_EOF: + case PATTERN_DEFAULT: + if (group != 0) + lua_pushnil (L); + else + lua_pushlstring (L, self->input.str, self->input.len); + return 1; + default: + return luaL_argerror (L, 2, "indexing unavailable for this pattern"); + } +} + +static bool +pattern_readtimeout (struct pattern *self, lua_State *L, int idx) +{ + lua_rawgeti (L, idx, 1); + bool ok = lua_isnumber (L, -1); + lua_Number v = lua_tonumber (L, -1); + lua_pop (L, 1); + if (v != v) + luaL_error (L, "timeout is not a number"); + if (ok) + self->timeout = v; + return ok; +} + +static void +pattern_readflags (struct pattern *self, lua_State *L, int idx) +{ + lua_getfield (L, idx, "nocase"); + self->nocase = lua_toboolean (L, -1); + lua_getfield (L, idx, "notransfer"); + self->notransfer = lua_toboolean (L, -1); + lua_pop (L, 2); +} + +static luaL_Reg xlua_pattern_table[] = +{ + { "__gc", xlua_pattern_gc }, + { "__index", xlua_pattern_index }, + { NULL, NULL } +}; + +// --- Process ----------------------------------------------------------------- + +#define XLUA_PROCESS_METATABLE "process" + +struct process +{ + int terminal_fd; ///< Process stdin/stdout/stderr + pid_t pid; ///< Process ID + int ref_term; ///< Terminal information + struct str buffer; ///< Terminal input buffer +}; + +static struct process * +process_new (lua_State *L) +{ + struct process *self = lua_newuserdata (L, sizeof *self); + luaL_setmetatable (L, XLUA_PROCESS_METATABLE); + memset (self, 0, sizeof *self); + + self->terminal_fd = -1; + self->pid = -1; + self->ref_term = LUA_NOREF; + self->buffer = str_make (); + return self; +} + +static int +xlua_process_gc (lua_State *L) +{ + struct process *self = luaL_checkudata (L, 1, XLUA_PROCESS_METATABLE); + if (self->terminal_fd != -1) + xclose (self->terminal_fd); + if (self->pid != -1) + // The slave is in its own process group. + kill (-self->pid, SIGKILL); + luaL_unref (L, LUA_REGISTRYINDEX, self->ref_term); + str_free (&self->buffer); + return 0; +} + +static int +xlua_process_index (lua_State *L) +{ + struct process *self = luaL_checkudata (L, 1, XLUA_PROCESS_METATABLE); + const char *key = luaL_checkstring (L, 2); + if (*key != '_' && luaL_getmetafield (L, 1, key)) + return 1; + + if (!strcmp (key, "buffer")) + lua_pushlstring (L, self->buffer.str, self->buffer.len); + else if (!strcmp (key, "term")) + lua_rawgeti (L, LUA_REGISTRYINDEX, self->ref_term); + else + return luaL_argerror (L, 2, "not a readable property"); + return 1; +} + +static int +xlua_process_send (lua_State *L) +{ + struct process *self = luaL_checkudata (L, 1, XLUA_PROCESS_METATABLE); + int nargs = lua_gettop (L); + for (int i = 2; i <= nargs; i++) + if (!lua_isstring (L, i)) + return luaL_argerror (L, i, "need string arguments"); + + for (int i = 2; i <= nargs; i++) + { + size_t len = 0; + const char *arg = lua_tolstring (L, i, &len); + ssize_t written = write (self->terminal_fd, arg, len); + if (written == -1) + return luaL_error (L, "write failed: %s", strerror (errno)); + else if (written != len) + return luaL_error (L, "write failed: %s", "short write"); + } + lua_pushvalue (L, 1); + return 1; +} + +static int +xlua_process_regex (lua_State *L) +{ + struct process *self = luaL_checkudata (L, 1, XLUA_PROCESS_METATABLE); + luaL_checktype (L, 2, LUA_TTABLE); + if (lua_gettop (L) != 2) + return luaL_error (L, "too many arguments"); + + struct pattern *pattern = pattern_new (L, PATTERN_REGEX, 1); + pattern_readflags (pattern, L, 2); + + int flags = REG_EXTENDED; + if (pattern->nocase) + flags |= REG_ICASE; + + lua_rawgeti (L, 2, 1); + if (!lua_isstring (L, -1)) + return luaL_error (L, "expected regular expression"); + + size_t len = 0; + const char *re = lua_tolstring (L, -1, &len); + if (!(pattern->regex = regex_compile (re, flags, &pattern->e))) + return luaL_error (L, "%s", pattern->e->message); + lua_pop (L, 1); + + pattern->matches = + xcalloc (pattern->regex->re_nsub + 1, sizeof *pattern->matches); + + xlua_newtablecopy (L, 2, 2, lua_rawlen (L, 2)); + pattern->ref_values = luaL_ref (L, LUA_REGISTRYINDEX); + return 1; +} + +static int +xlua_process_exact (lua_State *L) +{ + struct process *self = luaL_checkudata (L, 1, XLUA_PROCESS_METATABLE); + luaL_checktype (L, 2, LUA_TTABLE); + if (lua_gettop (L) != 2) + return luaL_error (L, "too many arguments"); + + struct pattern *pattern = pattern_new (L, PATTERN_EXACT, 1); + pattern_readflags (pattern, L, 2); + + lua_rawgeti (L, 2, 1); + if (!lua_isstring (L, -1)) + return luaL_error (L, "expected string literal"); + + size_t len = 0; + const char *literal = lua_tolstring (L, -1, &len); + str_append_data (&pattern->exact, literal, len); + lua_pop (L, 1); + + xlua_newtablecopy (L, 2, 2, lua_rawlen (L, 2)); + pattern->ref_values = luaL_ref (L, LUA_REGISTRYINDEX); + return 1; +} + +static int +xlua_process_eof (lua_State *L) +{ + struct process *self = luaL_checkudata (L, 1, XLUA_PROCESS_METATABLE); + luaL_checktype (L, 2, LUA_TTABLE); + if (lua_gettop (L) != 2) + return luaL_error (L, "too many arguments"); + + struct pattern *pattern = pattern_new (L, PATTERN_EOF, 1); + pattern_readflags (pattern, L, 2); + + xlua_newtablecopy (L, 2, 1, lua_rawlen (L, 2)); + pattern->ref_values = luaL_ref (L, LUA_REGISTRYINDEX); + return 1; +} + +static int +xlua_process_default (lua_State *L) +{ + struct process *self = luaL_checkudata (L, 1, XLUA_PROCESS_METATABLE); + luaL_checktype (L, 2, LUA_TTABLE); + if (lua_gettop (L) != 2) + return luaL_error (L, "too many arguments"); + + struct pattern *pattern = pattern_new (L, PATTERN_DEFAULT, 1); + pattern_readflags (pattern, L, 2); + + int first = 1, last = lua_rawlen (L, 2); + if (pattern_readtimeout (pattern, L, 2)) + first++; + + xlua_newtablecopy (L, 2, first, last); + pattern->ref_values = luaL_ref (L, LUA_REGISTRYINDEX); + return 1; +} + +static bool +process_feed (struct process *self) +{ + // Let's do this without O_NONBLOCK for now. + char buf[BUFSIZ] = ""; + ssize_t n = read (self->terminal_fd, buf, sizeof buf); + if (n < 0) + { + if (errno == EINTR) + return true; +#if __linux__ + // https://unix.stackexchange.com/a/538271 + if (errno == EIO) + return false; +#endif + + print_warning ("read: %s", strerror (errno)); + return false; + } + + // TODO(p): Add match_max processing, limiting the buffer size. + str_append_data (&self->buffer, buf, n); + return n > 0; +} + +static luaL_Reg xlua_process_table[] = +{ + { "__gc", xlua_process_gc }, + { "__index", xlua_process_index }, + { "send", xlua_process_send }, + { "regex", xlua_process_regex }, + { "exact", xlua_process_exact }, + { "eof", xlua_process_eof }, + { "default", xlua_process_default }, + { NULL, NULL } +}; + +// --- Terminal ---------------------------------------------------------------- + +#ifdef WITH_CURSES + +static bool +load_terminfo (const char *term, struct str_map *strings) +{ + // TODO(p): See if we can get away with using a bogus FD. + int err = 0; + TERMINAL *saved_term = set_curterm (NULL); + if (setupterm ((char *) term, STDOUT_FILENO, &err) != OK) + { + set_curterm (saved_term); + return false; + } + + for (size_t i = 0; strfnames[i]; i++) + { + const char *value = tigetstr (strnames[i]); + if (value && value != (char *) -1) + str_map_set (strings, strfnames[i], xstrdup (value)); + } + del_curterm (set_curterm (saved_term)); + return true; +} + +#endif + +// --- Library ----------------------------------------------------------------- + +struct spawn_context +{ + struct str_map env; ///< Subprocess environment map + struct str_map term; ///< terminfo database strings + struct strv envv; ///< Subprocess environment vector + struct strv argv; ///< Subprocess argument vector + + struct error *error; ///< Error +}; + +static struct spawn_context +spawn_context_make (void) +{ + struct spawn_context self = {}; + self.env = str_map_make (free); + self.term = str_map_make (free); + + // XXX: It might make sense to enable starting from an empty environment. + for (char **p = environ; *p; p++) + { + const char *equals = strchr (*p, '='); + if (!equals) + continue; + + char *key = xstrndup (*p, equals - *p); + str_map_set (&self.env, key, xstrdup (equals + 1)); + free (key); + } + self.envv = strv_make (); + self.argv = strv_make (); + return self; +} + +static void +spawn_context_free (struct spawn_context *self) +{ + str_map_free (&self->env); + str_map_free (&self->term); + strv_free (&self->envv); + strv_free (&self->argv); + + if (self->error) + error_free (self->error); +} + +// -0, +0, e +static void +environ_map_update (struct str_map *env, lua_State *L) +{ + lua_pushnil (L); + while (lua_next (L, -2)) + { + if (lua_type (L, -2) != LUA_TSTRING) + luaL_error (L, "environment maps must be keyed by strings"); + + const char *value = lua_tostring (L, -1); + str_map_set (env, lua_tostring (L, -2), + value ? xstrdup (value) : NULL); + lua_pop (L, 1); + } +} + +// The environment will get pseudo-randomly reordered, +// which is fine by POSIX. +static void +environ_map_serialize (struct str_map *env, struct strv *envv) +{ + struct str_map_iter iter = str_map_iter_make (env); + const char *value; + while ((value = str_map_iter_next (&iter))) + strv_append_owned (envv, + xstrdup_printf ("%s=%s", iter.link->key, value)); +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +static int +spawn_protected (lua_State *L) +{ + struct spawn_context *ctx = lua_touserdata (L, 1); + + // Step 1: Prepare process environment. + if (xlua_getfield (L, 2, "environ", LUA_TTABLE, true)) + { + environ_map_update (&ctx->env, L); + lua_pop (L, 1); + } + char *term = str_map_find (&ctx->env, "TERM"); + if (!term) + { + print_debug ("setting a default TERM"); + str_map_set (&ctx->env, "TERM", (term = xstrdup ("dumb"))); + } + environ_map_serialize (&ctx->env, &ctx->envv); + +#ifdef WITH_CURSES + // Step 2: Load terminal information. + if (!load_terminfo (term, &ctx->term)) + luaL_error (L, "failed to initialize terminfo for %s", term); +#endif + + // Step 3: Prepare process command line. + size_t argc = lua_rawlen (L, 2); + for (size_t i = 1; i <= argc; i++) + { + lua_pushinteger (L, i); + lua_rawget (L, 2); + const char *arg = lua_tostring (L, -1); + if (!arg) + return luaL_error (L, "spawn arguments must be strings"); + + strv_append (&ctx->argv, arg); + lua_pop (L, 1); + } + if (ctx->argv.len < 1) + return luaL_error (L, "missing argument"); + + // Step 4: Create a process object. + // This will get garbage collected as appropriate on failure. + struct process *process = process_new (L); + + // This could be made into an object that can adjust winsize/termios. + lua_createtable (L, 0, ctx->term.len); + struct str_map_iter iter = str_map_iter_make (&ctx->term); + const char *value = NULL; + while ((value = str_map_iter_next (&iter))) + { + lua_pushstring (L, iter.link->key); + lua_pushstring (L, value); + lua_settable (L, -3); + } + process->ref_term = luaL_ref (L, LUA_REGISTRYINDEX); + + // Step 5: Spawn the process, which gets a new process group. + process->pid = + pty_fork (&process->terminal_fd, NULL, NULL, NULL, &ctx->error); + if (process->pid < 0) + { + return luaL_error (L, "failed to spawn %s: %s", + ctx->argv.vector[0], ctx->error->message); + } + if (!process->pid) + { + execvpe (ctx->argv.vector[0], ctx->argv.vector, ctx->envv.vector); + print_error ("failed to spawn %s: %s", + ctx->argv.vector[0], strerror (errno)); + _exit (EXIT_FAILURE); + } + + set_cloexec (process->terminal_fd); + return 1; +} + +static int +xlua_spawn (lua_State *L) +{ + luaL_checktype (L, 1, LUA_TTABLE); + + lua_pushcfunction (L, xlua_error_handler); + lua_pushcfunction (L, spawn_protected); + + // There are way too many opportunities for Lua to throw, + // so maintain a context to clean up in one go. + struct spawn_context ctx = spawn_context_make (); + lua_pushlightuserdata (L, &ctx); + lua_rotate (L, 1, -1); + int result = lua_pcall (L, 2, 1, -4); + spawn_context_free (&ctx); + if (result) + return lua_error (L); + + // Remove the error handler ("good programming practice"). + lua_remove (L, -2); + return 1; +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +struct expect_context +{ + size_t patterns_len; ///< Number of patterns + struct pattern **patterns; ///< Pattern array + size_t pfds_len; ///< Number of distinct poll FDs + struct pollfd *pfds; ///< Distinct poll FDs + + lua_Number first_timeout; ///< Nearest timeout value + lua_Number timeout; ///< Actually used timeout value +}; + +static void +expect_context_free (struct expect_context *self) +{ + free (self->patterns); + free (self->pfds); +} + +static bool +expect_has_fd (struct expect_context *ctx, int fd) +{ + for (size_t i = 0; i < ctx->pfds_len; i++) + if (ctx->pfds[i].fd == fd) + return true; + return false; +} + +static struct process * +expect_fd_to_process (struct expect_context *ctx, int fd) +{ + for (size_t i = 0; i < ctx->patterns_len; i++) + { + struct pattern *p = ctx->patterns[i]; + if (p->process + && p->process->terminal_fd == fd) + return p->process; + } + return NULL; +} + +static void +expect_set_fd_eof (struct expect_context *ctx, int fd) +{ + for (size_t i = 0; i < ctx->patterns_len; i++) + { + struct pattern *p = ctx->patterns[i]; + if (p->process + && p->process->terminal_fd == fd) + p->eof = true; + } +} + +static void +expect_prepare_pattern (struct expect_context *ctx, struct pattern *p) +{ + str_reset (&p->input); + if (p->kind == PATTERN_REGEX) + for (int i = 0; i <= p->regex->re_nsub; i++) + p->matches[i] = (regmatch_t) { .rm_so = -1, .rm_eo = -1 }; + + if (p->kind == PATTERN_REGEX + || p->kind == PATTERN_EXACT + || p->kind == PATTERN_EOF + || p->kind == PATTERN_DEFAULT) + { + p->eof = false; + if (!expect_has_fd (ctx, p->process->terminal_fd)) + ctx->pfds[ctx->pfds_len++] = (struct pollfd) + { .fd = p->process->terminal_fd, .events = POLLIN }; + } + if (p->kind == PATTERN_TIMEOUT + || p->kind == PATTERN_DEFAULT) + { + lua_Number v = p->timeout >= 0 ? p->timeout : g.default_timeout; + if (ctx->first_timeout != ctx->first_timeout) + ctx->first_timeout = v; + else + ctx->first_timeout = MIN (ctx->first_timeout, v); + } +} + +static void +expect_prepare (struct expect_context *ctx) +{ + // The liberty poller is not particularly appropriate for this use case. + ctx->pfds_len = 0; + ctx->pfds = xcalloc (ctx->patterns_len, sizeof *ctx->pfds); + + ctx->first_timeout = NAN; + for (size_t i = 0; i < ctx->patterns_len; i++) + expect_prepare_pattern (ctx, ctx->patterns[i]); + + // There is always at least a default timeout. + ctx->timeout = g.default_timeout; + if (ctx->first_timeout == ctx->first_timeout) + ctx->timeout = ctx->first_timeout; +} + +static struct pattern * +expect_match_timeout (struct expect_context *ctx) +{ + for (size_t i = 0; i < ctx->patterns_len; i++) + { + struct pattern *p = ctx->patterns[i]; + if (p->kind != PATTERN_TIMEOUT + && p->kind != PATTERN_DEFAULT) + continue; + + if (p->timeout <= ctx->first_timeout) + return p; + } + return NULL; +} + +static bool +pattern_match (struct pattern *self) +{ + struct process *process = self->process; + struct str *buffer = process ? &process->buffer : NULL; + + str_reset (&self->input); + switch (self->kind) + { + case PATTERN_EOF: + case PATTERN_DEFAULT: + { + if (!self->eof) + return false; + + str_append_str (&self->input, &process->buffer); + if (!self->notransfer) + str_reset (&process->buffer); + return true; + } + case PATTERN_REGEX: + { + int flags = 0; +#ifdef REG_STARTEND + self->matches[0] = (regmatch_t) { .rm_so = 0, .rm_eo = buffer->len }; + flags |= REG_STARTEND; +#endif + if (regexec (self->regex, buffer->str, + self->regex->re_nsub + 1, self->matches, flags)) + { + for (int i = 0; i <= self->regex->re_nsub; i++) + self->matches[i] = (regmatch_t) { .rm_so = -1, .rm_eo = -1 }; + return false; + } + + str_append_data (&self->input, buffer->str, self->matches[0].rm_eo); + if (!self->notransfer) + str_remove_slice (buffer, 0, self->matches[0].rm_eo); + return true; + } + case PATTERN_EXACT: + { + const char *match = memmem (buffer, &self->exact, self->nocase); + if (!match) + return false; + + str_append_data (&self->input, match, self->exact.len); + if (!self->notransfer) + str_remove_slice (buffer, 0, match - buffer->str + self->exact.len); + return true; + } + default: + return false; + } +} + +static struct pattern * +expect_match_data (struct expect_context *ctx) +{ + for (size_t i = 0; i < ctx->patterns_len; i++) + { + struct pattern *p = ctx->patterns[i]; + if (pattern_match (p)) + return p; + } + return NULL; +} + +static int +expect_protected (lua_State *L) +{ + struct expect_context *ctx = lua_touserdata (L, lua_upvalueindex (1)); + ctx->patterns_len = lua_gettop (L); + ctx->patterns = xcalloc (ctx->patterns_len, sizeof *ctx->patterns); + for (size_t i = 0; i < ctx->patterns_len; i++) + ctx->patterns[i] = luaL_checkudata (L, i + 1, XLUA_PATTERN_METATABLE); + expect_prepare (ctx); + + int64_t deadline = 0; + struct pattern *match = NULL; +restart: + // A "continue" statement means we start anew with a new timeout. + // TODO(p): We should detect deadline > INT64_MAX, and wait indefinitely. + deadline = clock_msec () + ctx->timeout * 1000; + + // First, check if anything matches already, + // so that we don't need to wait for /even more/ data. + match = expect_match_data (ctx); + + while (!match) + { + int64_t until_deadline = deadline - clock_msec (); + int n = poll (ctx->pfds, ctx->pfds_len, MAX (0, until_deadline)); + if (n < 0) + return luaL_error (L, "poll: %s", strerror (errno)); + + for (int i = 0; i < n; i++) + { + struct pollfd *pfd = ctx->pfds + i; + hard_assert (!(pfd->revents & POLLNVAL)); + if (!(pfd->revents & (POLLIN | POLLHUP | POLLERR))) + continue; + + struct process *process = expect_fd_to_process (ctx, pfd->fd); + hard_assert (process != NULL); + if (!process_feed (process)) + { + expect_set_fd_eof (ctx, pfd->fd); + // Otherwise we would loop around this descriptor. + pfd->fd = -1; + } + } + + if (n > 0) + match = expect_match_data (ctx); + else if (!(match = expect_match_timeout (ctx))) + return 0; + } + + // Resolve the matching pattern back to its Lua full userdata. + int match_idx = 0; + for (size_t i = 0; i < ctx->patterns_len; i++) + if (ctx->patterns[i] == match) + match_idx = i + 1; + + // Filter the values table by executing any functions with the pattern. + lua_rawgeti (L, LUA_REGISTRYINDEX, match->ref_values); + int values_idx = lua_gettop (L); + int values_len = lua_rawlen (L, values_idx); + lua_checkstack (L, values_len); + + lua_pushcfunction (L, xlua_error_handler); + int handler_idx = lua_gettop (L); + for (int i = 1; i <= values_len; i++) + { + lua_rawgeti (L, values_idx, i); + if (!lua_isfunction (L, -1)) + continue; + + lua_pushvalue (L, match_idx); + if (!lua_pcall (L, 1, LUA_MULTRET, handler_idx)) + continue; + if (!lua_isnil (L, -1)) + return lua_error (L); + + lua_pop (L, lua_gettop (L) - values_idx + 1); + goto restart; + } + return lua_gettop (L) - handler_idx; +} + +static int +xlua_expect (lua_State *L) +{ + lua_pushcfunction (L, xlua_error_handler); + lua_insert (L, 1); + + struct expect_context ctx = {}; + lua_pushlightuserdata (L, &ctx); + lua_pushcclosure (L, expect_protected, 1); + lua_insert (L, 2); + + int result = lua_pcall (L, lua_gettop (L) - 2, LUA_MULTRET, 1); + expect_context_free (&ctx); + if (result) + return lua_error (L); + + // Remove the error handler ("good programming practice"). + lua_remove (L, 1); + return lua_gettop (L); +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +static int +xlua_timeout (lua_State *L) +{ + luaL_checktype (L, 1, LUA_TTABLE); + if (lua_gettop (L) != 1) + return luaL_error (L, "too many arguments"); + + struct pattern *pattern = pattern_new (L, PATTERN_TIMEOUT, 0); + + int first = 1, last = lua_rawlen (L, 1); + if (pattern_readtimeout (pattern, L, 1)) + first++; + + xlua_newtablecopy (L, 1, first, last); + pattern->ref_values = luaL_ref (L, LUA_REGISTRYINDEX); + return 1; +} + +static int +xlua_continue (lua_State *L) +{ + // xlua_expect() handles this specially. + lua_pushnil (L); + return lua_error (L); +} + +static luaL_Reg xlua_library[] = +{ + { "spawn", xlua_spawn }, + { "expect", xlua_expect }, + { "timeout", xlua_timeout }, + { "continue", xlua_continue }, + { NULL, NULL } +}; + +// --- Initialisation, event handling ------------------------------------------ + +static void * +xlua_alloc (void *ud, void *ptr, size_t o_size, size_t n_size) +{ + (void) ud; + (void) o_size; + + if (n_size) + return realloc (ptr, n_size); + + free (ptr); + return NULL; +} + +static int +xlua_panic (lua_State *L) +{ + print_fatal ("Lua panicked: %s", lua_tostring (L, -1)); + lua_close (L); + exit (EXIT_FAILURE); + return 0; +} + +int +main (int argc, char *argv[]) +{ + if (argc != 2) + { + fprintf (stderr, "Usage: %s program.lua\n", argv[0]); + return 1; + } + + if (!(g.L = lua_newstate (xlua_alloc, NULL))) + exit_fatal ("Lua initialization failed"); + lua_atpanic (g.L, xlua_panic); + luaL_openlibs (g.L); + luaL_checkversion (g.L); + + luaL_newlib (g.L, xlua_library); + lua_setglobal (g.L, PROGRAM_NAME); + + luaL_newmetatable (g.L, XLUA_PROCESS_METATABLE); + luaL_setfuncs (g.L, xlua_process_table, 0); + lua_pop (g.L, 1); + + luaL_newmetatable (g.L, XLUA_PATTERN_METATABLE); + luaL_setfuncs (g.L, xlua_pattern_table, 0); + lua_pop (g.L, 1); + + const char *path = argv[1]; + lua_pushcfunction (g.L, xlua_error_handler); + if (luaL_loadfile (g.L, path) + || lua_pcall (g.L, 0, 0, -2)) + { + print_error ("%s: %s", path, lua_tostring (g.L, -1)); + lua_pop (g.L, 1); + lua_close (g.L); + return 1; + } + lua_close (g.L); + return 0; +}