Rewrite RPC handling for wider usability
This commit is contained in:
		
							parent
							
								
									013e7eba28
								
							
						
					
					
						commit
						eda0f22f07
					
				
							
								
								
									
										100
									
								
								acid.go
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								acid.go
									
									
									
									
									
								
							| @ -22,6 +22,7 @@ import ( | ||||
| 	"os/signal" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"syscall" | ||||
| 	ttemplate "text/template" | ||||
| @ -361,7 +362,23 @@ func handlePush(w http.ResponseWriter, r *http.Request) { | ||||
| 
 | ||||
| const rpcHeaderSignature = "X-ACID-Signature" | ||||
| 
 | ||||
| func rpcRestart(w io.Writer, ids []int64) { | ||||
| var errWrongUsage = errors.New("wrong usage") | ||||
| 
 | ||||
| func rpcRestart(ctx context.Context, | ||||
| 	w io.Writer, fs *flag.FlagSet, args []string) error { | ||||
| 	if err := fs.Parse(args); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	ids := []int64{} | ||||
| 	for _, arg := range fs.Args() { | ||||
| 		id, err := strconv.ParseInt(arg, 10, 64) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("%w: %s", errWrongUsage, err) | ||||
| 		} | ||||
| 		ids = append(ids, id) | ||||
| 	} | ||||
| 
 | ||||
| 	gRunningMutex.Lock() | ||||
| 	defer gRunningMutex.Unlock() | ||||
| 
 | ||||
| @ -373,7 +390,7 @@ func rpcRestart(w io.Writer, ids []int64) { | ||||
| 
 | ||||
| 		// The executor bumps to "running" after inserting into gRunning, | ||||
| 		// so we should not need to exclude that state here. | ||||
| 		result, err := gDB.ExecContext(context.Background(), `UPDATE task | ||||
| 		result, err := gDB.ExecContext(ctx, `UPDATE task | ||||
| 			SET state = ?, detail = '', notified = 0 WHERE id = ?`, | ||||
| 			taskStateNew, id) | ||||
| 		if err != nil { | ||||
| @ -384,6 +401,35 @@ func rpcRestart(w io.Writer, ids []int64) { | ||||
| 	} | ||||
| 	notifierAwaken() | ||||
| 	executorAwaken() | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - | ||||
| 
 | ||||
| var rpcCommands = map[string]struct { | ||||
| 	// handler must not write anything when returning an error. | ||||
| 	handler  func(context.Context, io.Writer, *flag.FlagSet, []string) error | ||||
| 	usage    string | ||||
| 	function string | ||||
| }{ | ||||
| 	"restart": {rpcRestart, "ID...", | ||||
| 		"Schedule tasks with the given IDs to be rerun."}, | ||||
| } | ||||
| 
 | ||||
| func rpcPrintCommands(w io.Writer) { | ||||
| 	// The alphabetic ordering is unfortunate, but tolerable. | ||||
| 	keys := []string{} | ||||
| 	for key := range rpcCommands { | ||||
| 		keys = append(keys, key) | ||||
| 	} | ||||
| 	sort.Strings(keys) | ||||
| 
 | ||||
| 	fmt.Fprintf(w, "Commands:\n") | ||||
| 	for _, key := range keys { | ||||
| 		cmd := rpcCommands[key] | ||||
| 		fmt.Fprintf(w, "  %s [OPTION...] %s\n    \t%s\n", | ||||
| 			key, cmd.usage, cmd.function) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func handleRPC(w http.ResponseWriter, r *http.Request) { | ||||
| @ -410,21 +456,43 @@ func handleRPC(w http.ResponseWriter, r *http.Request) { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// Our handling closely follows what the flag package does internally. | ||||
| 
 | ||||
| 	command, args := args[0], args[1:] | ||||
| 	switch command { | ||||
| 	case "restart": | ||||
| 		ids := []int64{} | ||||
| 		for _, arg := range args { | ||||
| 			id, err := strconv.ParseInt(arg, 10, 64) | ||||
| 			if err != nil { | ||||
| 				http.Error(w, err.Error(), http.StatusBadRequest) | ||||
| 				return | ||||
| 			} | ||||
| 			ids = append(ids, id) | ||||
| 		} | ||||
| 		rpcRestart(w, ids) | ||||
| 	default: | ||||
| 		http.Error(w, "Unknown command: "+command, http.StatusBadRequest) | ||||
| 	cmd, ok := rpcCommands[command] | ||||
| 	if !ok { | ||||
| 		http.Error(w, "unknown command: "+command, http.StatusBadRequest) | ||||
| 		rpcPrintCommands(w) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// If we redirected the FlagSet straight to the response, | ||||
| 	// we would be unable to set our own HTTP status. | ||||
| 	b := bytes.NewBuffer(nil) | ||||
| 
 | ||||
| 	fs := flag.NewFlagSet(command, flag.ContinueOnError) | ||||
| 	fs.SetOutput(b) | ||||
| 	fs.Usage = func() { | ||||
| 		fmt.Fprintf(fs.Output(), | ||||
| 			"Usage: %s [OPTION...] %s\n%s\n", | ||||
| 			fs.Name(), cmd.usage, cmd.function) | ||||
| 		fs.PrintDefaults() | ||||
| 	} | ||||
| 
 | ||||
| 	err = cmd.handler(r.Context(), w, fs, args) | ||||
| 
 | ||||
| 	// Wrap this error to make it as if fs.Parse discovered the issue. | ||||
| 	if errors.Is(err, errWrongUsage) { | ||||
| 		fmt.Fprintln(fs.Output(), err) | ||||
| 		fs.Usage() | ||||
| 	} | ||||
| 
 | ||||
| 	// The flag package first prints all errors that it returns. | ||||
| 	// If the buffer ends up not being empty, flush it into the request. | ||||
| 	if b.Len() != 0 { | ||||
| 		http.Error(w, strings.TrimSpace(b.String()), http.StatusBadRequest) | ||||
| 	} else if err != nil { | ||||
| 		http.Error(w, err.Error(), http.StatusUnprocessableEntity) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user