diff --git a/acid.go b/acid.go index 280323c..64d69dc 100644 --- a/acid.go +++ b/acid.go @@ -22,6 +22,7 @@ import ( "os/signal" "sort" "strconv" + "strings" "sync" "syscall" ttemplate "text/template" @@ -361,7 +362,23 @@ func handlePush(w http.ResponseWriter, r *http.Request) { const rpcHeaderSignature = "X-ACID-Signature" -func rpcRestart(w io.Writer, ids []int64) { +var errWrongUsage = errors.New("wrong usage") + +func rpcRestart(ctx context.Context, + w io.Writer, fs *flag.FlagSet, args []string) error { + if err := fs.Parse(args); err != nil { + return err + } + + ids := []int64{} + for _, arg := range fs.Args() { + id, err := strconv.ParseInt(arg, 10, 64) + if err != nil { + return fmt.Errorf("%w: %s", errWrongUsage, err) + } + ids = append(ids, id) + } + gRunningMutex.Lock() defer gRunningMutex.Unlock() @@ -373,7 +390,7 @@ func rpcRestart(w io.Writer, ids []int64) { // The executor bumps to "running" after inserting into gRunning, // so we should not need to exclude that state here. - result, err := gDB.ExecContext(context.Background(), `UPDATE task + result, err := gDB.ExecContext(ctx, `UPDATE task SET state = ?, detail = '', notified = 0 WHERE id = ?`, taskStateNew, id) if err != nil { @@ -384,6 +401,35 @@ func rpcRestart(w io.Writer, ids []int64) { } notifierAwaken() executorAwaken() + return nil +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +var rpcCommands = map[string]struct { + // handler must not write anything when returning an error. + handler func(context.Context, io.Writer, *flag.FlagSet, []string) error + usage string + function string +}{ + "restart": {rpcRestart, "ID...", + "Schedule tasks with the given IDs to be rerun."}, +} + +func rpcPrintCommands(w io.Writer) { + // The alphabetic ordering is unfortunate, but tolerable. + keys := []string{} + for key := range rpcCommands { + keys = append(keys, key) + } + sort.Strings(keys) + + fmt.Fprintf(w, "Commands:\n") + for _, key := range keys { + cmd := rpcCommands[key] + fmt.Fprintf(w, " %s [OPTION...] %s\n \t%s\n", + key, cmd.usage, cmd.function) + } } func handleRPC(w http.ResponseWriter, r *http.Request) { @@ -410,21 +456,43 @@ func handleRPC(w http.ResponseWriter, r *http.Request) { return } + // Our handling closely follows what the flag package does internally. + command, args := args[0], args[1:] - switch command { - case "restart": - ids := []int64{} - for _, arg := range args { - id, err := strconv.ParseInt(arg, 10, 64) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - ids = append(ids, id) - } - rpcRestart(w, ids) - default: - http.Error(w, "Unknown command: "+command, http.StatusBadRequest) + cmd, ok := rpcCommands[command] + if !ok { + http.Error(w, "unknown command: "+command, http.StatusBadRequest) + rpcPrintCommands(w) + return + } + + // If we redirected the FlagSet straight to the response, + // we would be unable to set our own HTTP status. + b := bytes.NewBuffer(nil) + + fs := flag.NewFlagSet(command, flag.ContinueOnError) + fs.SetOutput(b) + fs.Usage = func() { + fmt.Fprintf(fs.Output(), + "Usage: %s [OPTION...] %s\n%s\n", + fs.Name(), cmd.usage, cmd.function) + fs.PrintDefaults() + } + + err = cmd.handler(r.Context(), w, fs, args) + + // Wrap this error to make it as if fs.Parse discovered the issue. + if errors.Is(err, errWrongUsage) { + fmt.Fprintln(fs.Output(), err) + fs.Usage() + } + + // The flag package first prints all errors that it returns. + // If the buffer ends up not being empty, flush it into the request. + if b.Len() != 0 { + http.Error(w, strings.TrimSpace(b.String()), http.StatusBadRequest) + } else if err != nil { + http.Error(w, err.Error(), http.StatusUnprocessableEntity) } }