Compare commits
	
		
			5 Commits
		
	
	
		
			054078908a
			...
			77e988365d
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 77e988365d | |||
| 8df76dbaab | |||
| 819d2d80e0 | |||
| 36f6612603 | |||
| b4f28814b7 | 
							
								
								
									
										2
									
								
								LICENSE
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								LICENSE
									
									
									
									
									
								
							| @ -1,4 +1,4 @@ | |||||||
| Copyright (c) 2023, Přemysl Eric Janouch <p@janouch.name> | Copyright (c) 2023 - 2024, Přemysl Eric Janouch <p@janouch.name> | ||||||
| 
 | 
 | ||||||
| Permission to use, copy, modify, and/or distribute this software for any | Permission to use, copy, modify, and/or distribute this software for any | ||||||
| purpose with or without fee is hereby granted. | purpose with or without fee is hereby granted. | ||||||
|  | |||||||
							
								
								
									
										20
									
								
								deeptagger/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								deeptagger/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,20 @@ | |||||||
# Ubuntu 20.04 LTS
cmake_minimum_required (VERSION 3.16)
project (deeptagger VERSION 0.0.1 LANGUAGES CXX)

# Hint: set ONNXRuntime_ROOT to a directory with a pre-built GitHub release.
# (Useful for development, otherwise you may need to adjust the rpath.)
# The project directory is appended, so that any module path given
# on the command line or by an enclosing project stays in effect.
list (APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}")

find_package (ONNXRuntime REQUIRED)
find_package (PkgConfig REQUIRED)
# deeptagger.cpp includes <thread>, link the platform thread library explicitly.
find_package (Threads REQUIRED)
# The imported target carries the include directories, resolved libraries,
# as well as any other compiler and linker flags reported by pkg-config.
pkg_check_modules (GM REQUIRED IMPORTED_TARGET GraphicsMagick++)

add_executable (deeptagger deeptagger.cpp)
target_compile_features (deeptagger PRIVATE cxx_std_17)
target_include_directories (deeptagger PRIVATE
	${ONNXRuntime_INCLUDE_DIRS})
target_link_libraries (deeptagger PRIVATE
	PkgConfig::GM ${ONNXRuntime_LIBRARIES} Threads::Threads)
							
								
								
									
										11
									
								
								deeptagger/FindONNXRuntime.cmake
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								deeptagger/FindONNXRuntime.cmake
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,11 @@ | |||||||
# Public Domain
#
# Find module for ONNX Runtime.
#
# Result variables:
#   ONNXRuntime_FOUND         - set when both the header and library were found
#   ONNXRuntime_INCLUDE_DIRS  - directory containing onnxruntime_c_api.h
#   ONNXRuntime_LIBRARIES     - path of the onnxruntime library
#
# Imported target (only defined when found):
#   ONNXRuntime::ONNXRuntime

find_path (ONNXRuntime_INCLUDE_DIRS onnxruntime_c_api.h
	PATH_SUFFIXES onnxruntime)
find_library (ONNXRuntime_LIBRARIES NAMES onnxruntime)

include (FindPackageHandleStandardArgs)
find_package_handle_standard_args (ONNXRuntime DEFAULT_MSG
	ONNXRuntime_INCLUDE_DIRS ONNXRuntime_LIBRARIES)

mark_as_advanced (ONNXRuntime_LIBRARIES ONNXRuntime_INCLUDE_DIRS)

# Offer a target for consumers that prefer it over the plain variables.
if (ONNXRuntime_FOUND AND NOT TARGET ONNXRuntime::ONNXRuntime)
	add_library (ONNXRuntime::ONNXRuntime UNKNOWN IMPORTED)
	set_target_properties (ONNXRuntime::ONNXRuntime PROPERTIES
		IMPORTED_LOCATION "${ONNXRuntime_LIBRARIES}"
		INTERFACE_INCLUDE_DIRECTORIES "${ONNXRuntime_INCLUDE_DIRS}")
endif ()
							
								
								
									
										130
									
								
								deeptagger/README.adoc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								deeptagger/README.adoc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,130 @@ | |||||||
|  | deeptagger | ||||||
|  | ========== | ||||||
|  | 
 | ||||||
|  | This is an automatic image tagger/classifier written in C++, | ||||||
|  | primarily targeting various anime models. | ||||||
|  | 
 | ||||||
|  | Unfortunately, you will still need Python 3, as well as some luck, to prepare | ||||||
|  | the models, achieved by running download.sh.  You will need about 20 gigabytes | ||||||
|  | of space for this operation. | ||||||
|  | 
 | ||||||
|  | "WaifuDiffusion v1.4" models are officially distributed with ONNX model exports | ||||||
|  | that do not support symbolic batch sizes.  The script attempts to fix this | ||||||
|  | by running custom exports. | ||||||
|  | 
 | ||||||
|  | You're invited to change things to suit your particular needs. | ||||||
|  | 
 | ||||||
|  | Getting it to work | ||||||
|  | ------------------ | ||||||
|  | To build the evaluator, install a C++ compiler, CMake, and development packages | ||||||
|  | of GraphicsMagick and ONNX Runtime. | ||||||
|  | 
 | ||||||
|  | Prebuilt ONNX Runtime can be most conveniently downloaded from | ||||||
|  | https://github.com/microsoft/onnxruntime/releases[GitHub releases]. | ||||||
|  | Remember to also install CUDA packages, such as _nvidia-cudnn_ on Debian, | ||||||
|  | if you plan on using the GPU-enabled options. | ||||||
|  | 
 | ||||||
|  |  $ cmake -DONNXRuntime_ROOT=/path/to/onnxruntime -B build | ||||||
|  |  $ cmake --build build | ||||||
|  |  $ ./download.sh | ||||||
|  |  $ build/deeptagger models/deepdanbooru-v3-20211112-sgd-e28.model image.jpg | ||||||
|  | 
 | ||||||
|  | Very little effort is made to make the project compatible with non-POSIX | ||||||
|  | systems. | ||||||
|  | 
 | ||||||
|  | Options | ||||||
|  | ------- | ||||||
|  | --batch 1:: | ||||||
|  | 	This program makes use of batches by decoding and preparing multiple images | ||||||
|  | 	in parallel before sending them off to models. | ||||||
|  | 	Batching requires appropriate models. | ||||||
|  | --cpu:: | ||||||
|  | 	Force CPU inference, which is usually extremely slow. | ||||||
|  | --debug:: | ||||||
|  | 	Increase verbosity. | ||||||
|  | --options "CUDAExecutionProvider;device_id=0":: | ||||||
|  | 	Set various ONNX Runtime execution provider options. | ||||||
|  | --pipe:: | ||||||
|  | 	Take input filenames from the standard input. | ||||||
|  | --threshold 0.1:: | ||||||
|  | 	Output weight threshold.  Needs to be set very high on ML-Danbooru models. | ||||||
|  | 
 | ||||||
|  | Model benchmarks | ||||||
|  | ---------------- | ||||||
|  | These were measured on a machine with GeForce RTX 4090 (24G), | ||||||
|  | and Ryzen 9 7950X3D (32 threads), on a sample of 704 images, | ||||||
|  | which took over eight hours. | ||||||
|  | 
 | ||||||
|  | There is room for further performance tuning. | ||||||
|  | 
 | ||||||
|  | GPU inference | ||||||
|  | ~~~~~~~~~~~~~ | ||||||
|  | [cols="<,>,>", options=header] | ||||||
|  | |=== | ||||||
|  | |Model|Batch size|Time | ||||||
|  | |ML-Danbooru Caformer dec-5-97527|16|OOM | ||||||
|  | |WD v1.4 ViT v2 (batch)|16|19 s | ||||||
|  | |DeepDanbooru|16|21 s | ||||||
|  | |WD v1.4 SwinV2 v2 (batch)|16|21 s | ||||||
|  | |WD v1.4 ViT v2 (batch)|4|27 s | ||||||
|  | |WD v1.4 SwinV2 v2 (batch)|4|30 s | ||||||
|  | |DeepDanbooru|4|31 s | ||||||
|  | |ML-Danbooru TResNet-D 6-30000|16|31 s | ||||||
|  | |WD v1.4 MOAT v2 (batch)|16|31 s | ||||||
|  | |WD v1.4 ConvNeXT v2 (batch)|16|32 s | ||||||
|  | |WD v1.4 ConvNeXTV2 v2 (batch)|16|36 s | ||||||
|  | |ML-Danbooru TResNet-D 6-30000|4|39 s | ||||||
|  | |WD v1.4 ConvNeXT v2 (batch)|4|39 s | ||||||
|  | |WD v1.4 MOAT v2 (batch)|4|39 s | ||||||
|  | |WD v1.4 ConvNeXTV2 v2 (batch)|4|43 s | ||||||
|  | |WD v1.4 ViT v2|1|43 s | ||||||
|  | |WD v1.4 ViT v2 (batch)|1|43 s | ||||||
|  | |ML-Danbooru Caformer dec-5-97527|4|48 s | ||||||
|  | |DeepDanbooru|1|53 s | ||||||
|  | |WD v1.4 MOAT v2|1|53 s | ||||||
|  | |WD v1.4 ConvNeXT v2|1|54 s | ||||||
|  | |WD v1.4 MOAT v2 (batch)|1|54 s | ||||||
|  | |WD v1.4 SwinV2 v2|1|54 s | ||||||
|  | |WD v1.4 SwinV2 v2 (batch)|1|54 s | ||||||
|  | |WD v1.4 ConvNeXT v2 (batch)|1|56 s | ||||||
|  | |WD v1.4 ConvNeXTV2 v2|1|56 s | ||||||
|  | |ML-Danbooru TResNet-D 6-30000|1|58 s | ||||||
|  | |WD v1.4 ConvNeXTV2 v2 (batch)|1|58 s | ||||||
|  | |ML-Danbooru Caformer dec-5-97527|1|73 s | ||||||
|  | |=== | ||||||
|  | 
 | ||||||
|  | CPU inference | ||||||
|  | ~~~~~~~~~~~~~ | ||||||
|  | [cols="<,>,>", options=header] | ||||||
|  | |=== | ||||||
|  | |Model|Batch size|Time | ||||||
|  | |DeepDanbooru|16|45 s | ||||||
|  | |DeepDanbooru|4|54 s | ||||||
|  | |DeepDanbooru|1|88 s | ||||||
|  | |ML-Danbooru TResNet-D 6-30000|4|139 s | ||||||
|  | |ML-Danbooru TResNet-D 6-30000|16|162 s | ||||||
|  | |ML-Danbooru TResNet-D 6-30000|1|167 s | ||||||
|  | |WD v1.4 ConvNeXT v2|1|208 s | ||||||
|  | |WD v1.4 ConvNeXT v2 (batch)|4|226 s | ||||||
|  | |WD v1.4 ConvNeXT v2 (batch)|16|238 s | ||||||
|  | |WD v1.4 ConvNeXTV2 v2|1|245 s | ||||||
|  | |WD v1.4 ConvNeXTV2 v2 (batch)|4|268 s | ||||||
|  | |WD v1.4 ViT v2 (batch)|16|270 s | ||||||
|  | |WD v1.4 ConvNeXT v2 (batch)|1|272 s | ||||||
|  | |WD v1.4 SwinV2 v2 (batch)|4|277 s | ||||||
|  | |WD v1.4 ViT v2 (batch)|4|277 s | ||||||
|  | |WD v1.4 ConvNeXTV2 v2 (batch)|16|294 s | ||||||
|  | |WD v1.4 SwinV2 v2 (batch)|1|300 s | ||||||
|  | |WD v1.4 SwinV2 v2|1|302 s | ||||||
|  | |WD v1.4 SwinV2 v2 (batch)|16|305 s | ||||||
|  | |WD v1.4 MOAT v2 (batch)|4|307 s | ||||||
|  | |WD v1.4 ViT v2|1|308 s | ||||||
|  | |WD v1.4 ViT v2 (batch)|1|311 s | ||||||
|  | |WD v1.4 ConvNeXTV2 v2 (batch)|1|312 s | ||||||
|  | |WD v1.4 MOAT v2|1|332 s | ||||||
|  | |WD v1.4 MOAT v2 (batch)|16|335 s | ||||||
|  | |WD v1.4 MOAT v2 (batch)|1|339 s | ||||||
|  | |ML-Danbooru Caformer dec-5-97527|4|637 s | ||||||
|  | |ML-Danbooru Caformer dec-5-97527|16|689 s | ||||||
|  | |ML-Danbooru Caformer dec-5-97527|1|829 s | ||||||
|  | |=== | ||||||
							
								
								
									
										51
									
								
								deeptagger/bench-interpret.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										51
									
								
								deeptagger/bench-interpret.sh
									
									
									
									
									
										Executable file
									
								
							| @ -0,0 +1,51 @@ | |||||||
|  | #!/bin/sh -e | ||||||
# Reduce the benchmark log to the best time of each measured configuration.
#
# Input:  tab-separated name, model path, options, batch size, seconds,
#         read from the file named by BENCH_LOG (bench.out by default).
# Output: tab-separated name, CPU flag (0 or 1), batch size, minimum seconds.
#
# Rows are only compared against the immediately preceding group,
# so repeated runs of one configuration must be on consecutive lines.
parse() {
	awk 'BEGIN {
		OFS = FS = "\t"
	} {
		name = $1
		path = $2
		cpu = $3 != ""
		batch = $4
		time = $5

		# Batch-capable exports get their own label; plain WD exports
		# are only reported for a batch size of one.
		if (path ~ "/batch-")
			name = name " (batch)"
		else if (name ~ /^WD / && batch > 1)
			next
	} {
		group = name FS cpu FS batch
		if (lastgroup != group) {
			if (lastgroup)
				print lastgroup, mintime

			lastgroup = group
			mintime = time
		} else {
			if (mintime > time)
				mintime = time
		}
	} END {
		# Nothing was read or kept: do not emit a row made of separators.
		if (lastgroup)
			print lastgroup, mintime
	}' "${BENCH_LOG:-bench.out}"
}
|  | 
 | ||||||
# Print the table rows of one device, sorted by time.
# $1 is 1 to select CPU measurements, and 0 to select GPU measurements.
rows() {
	parse | awk -F'\t' -v cpu="$1" 'BEGIN { OFS = "|" }
		(cpu && $2) || (!cpu && !$2) { print "", $1, $3, $4 " s" }' \
		| sort -t'|' -nk4
}

# Print one AsciiDoc section: $1 is its title, $2 the device flag for rows().
section() {
	printf '%s\n' "$1" '~~~~~~~~~~~~~' '[cols="<,>,>", options=header]' \
		'|===' '|Model|Batch size|Time' "$(rows "$2")" '|==='
}

section 'GPU inference' 0
echo
section 'CPU inference' 1
							
								
								
									
										38
									
								
								deeptagger/bench.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										38
									
								
								deeptagger/bench.sh
									
									
									
									
									
										Executable file
									
								
							| @ -0,0 +1,38 @@ | |||||||
|  | #!/bin/sh -e | ||||||
# The first argument must be an executable evaluator,
# and it must be followed by at least one file to run it on.
if [ $# -lt 2 ] || ! [ -x "$1" ]
then
	# Diagnostics go to the standard error stream, so that they cannot
	# be mistaken for measurements when the output is captured.
	echo "Usage: $0 DEEPTAGGER FILE..." >&2
	echo "Run this after using download.sh, from the same directory." >&2
	exit 1
fi

runner=$1
shift
log=bench.out
# Start from an empty log; measurements are appended to it one by one.
: >"$log"
|  | 
 | ||||||
# Time three runs of the evaluator and log each of them.
#
# $1: additional evaluator options, subject to word splitting (may be empty)
# $2: batch size
# $3: path of the model configuration
# $4...: input files
#
# Reads the globals runner, name and log.  Evaluator failures are ignored,
# the elapsed wall-clock time is recorded regardless.
run() {
	opts=$1
	batch=$2
	model=$3
	shift 3

	for i in 1 2 3
	do
		start=$(date +%s)
		"$runner" $opts -b "$batch" -t 0.75 "$model" "$@" >/dev/null || true
		end=$(date +%s)

		elapsed=$((end - start))
		printf '%s\t%s\t%s\t%s\t%s\n' \
			"$name" "$model" "$opts" "$batch" "$elapsed" | tee -a $log
	done
}
|  | 
 | ||||||
# Benchmark every downloaded model: first with the default providers,
# then on the CPU, each with batch sizes of 1, 4 and 16.
for model in models/*.model
do
	name=$(sed -n 's/^name=//p' "$model")
	for device in "" --cpu
	do
		for size in 1 4 16
		do
			run "$device" "$size" "$model" "$@"
		done
	done
done
							
								
								
									
										744
									
								
								deeptagger/deeptagger.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										744
									
								
								deeptagger/deeptagger.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,744 @@ | |||||||
#include <getopt.h>
#include <Magick++.h>
#include <onnxruntime_cxx_api.h>
#ifdef __APPLE__
#include <coreml_provider_factory.h>
#endif

#include <algorithm>
#include <condition_variable>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <mutex>
#include <queue>
#include <regex>
#include <set>
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <vector>

#include <climits>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
|  | 
 | ||||||
// Global run-time settings.
static struct {
	// When set, no accelerated execution providers get registered.
	bool cpu{false};
	// Verbosity level; values above two also display prepared images.
	int debug{0};
	// Batch size (presumably set from the --batch option).
	long batch{1};
	// Only tags with a weight above this value are printed.
	float threshold{0.1f};

	// Execution provider name → Key → Value
	std::map<std::string, std::map<std::string, std::string>> options;
} g;
|  | 
 | ||||||
|  | // --- Configuration -----------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | // Arguably, input normalization could be incorporated into models instead.
 | ||||||
// Describes how a model expects its input, and how to read its output.
struct Config {
	enum class Shape {NHWC, NCHW};
	enum class Channels {RGB, BGR};
	enum class Pad {WHITE, EDGE, STRETCH};

	// Human-readable model name.
	std::string name;
	// Memory layout of the input tensor.
	Shape shape = Shape::NHWC;
	// Order of colour channels within the input tensor.
	Channels channels = Channels::RGB;
	// Whether input values are to be divided by 255.
	bool normalize = false;
	// How images are fitted into the model's input rectangle.
	Pad pad = Pad::WHITE;
	// Value of the "size" key (presumably the input edge length), -1 if unset.
	int size = -1;
	// Whether the logistic function is to be applied to raw outputs.
	bool sigmoid = false;

	// Tag names, in the order of model outputs.
	std::vector<std::string> tags;
};
|  | 
 | ||||||
// Append all lines of the file at the given path to the list of tags.
// A carriage return at the end of a line is stripped, empty lines are kept.
// Throws std::runtime_error when the file cannot be opened;
// unrecoverable stream errors surface as std::ios_base::failure.
static void
read_tags(const std::string &path, std::vector<std::string> &tags)
{
	std::ifstream stream(path);
	stream.exceptions(std::ifstream::badbit);
	if (stream.fail())
		throw std::runtime_error("cannot read tags");

	for (std::string line; std::getline(stream, line); ) {
		// Tolerate files with DOS line endings.
		if (!line.empty() && line.back() == '\r')
			line.pop_back();
		tags.push_back(line);
	}
}
|  | 
 | ||||||
|  | static void | ||||||
|  | read_field(Config &config, std::string key, std::string value) | ||||||
|  | { | ||||||
|  | 	if (key == "name") { | ||||||
|  | 		config.name = value; | ||||||
|  | 	} else if (key == "shape") { | ||||||
|  | 		if      (value == "nhwc")    config.shape = Config::Shape::NHWC; | ||||||
|  | 		else if (value == "nchw")    config.shape = Config::Shape::NCHW; | ||||||
|  | 		else throw std::invalid_argument("bad value for: " + key); | ||||||
|  | 	} else if (key == "channels") { | ||||||
|  | 		if      (value == "rgb")     config.channels = Config::Channels::RGB; | ||||||
|  | 		else if (value == "bgr")     config.channels = Config::Channels::BGR; | ||||||
|  | 		else throw std::invalid_argument("bad value for: " + key); | ||||||
|  | 	} else if (key == "normalize") { | ||||||
|  | 		if      (value == "true")    config.normalize = true; | ||||||
|  | 		else if (value == "false")   config.normalize = false; | ||||||
|  | 		else throw std::invalid_argument("bad value for: " + key); | ||||||
|  | 	} else if (key == "pad") { | ||||||
|  | 		if      (value == "white")   config.pad = Config::Pad::WHITE; | ||||||
|  | 		else if (value == "edge")    config.pad = Config::Pad::EDGE; | ||||||
|  | 		else if (value == "stretch") config.pad = Config::Pad::STRETCH; | ||||||
|  | 		else throw std::invalid_argument("bad value for: " + key); | ||||||
|  | 	} else if (key == "size") { | ||||||
|  | 		config.size = std::stoi(value); | ||||||
|  | 	} else if (key == "interpret") { | ||||||
|  | 		if      (value == "false")   config.sigmoid = false; | ||||||
|  | 		else if (value == "sigmoid") config.sigmoid = true; | ||||||
|  | 		else throw std::invalid_argument("bad value for: " + key); | ||||||
|  | 	} else { | ||||||
|  | 		throw std::invalid_argument("unsupported config key: " + key); | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static void | ||||||
|  | read_config(Config &config, const char *path) | ||||||
|  | { | ||||||
|  | 	std::ifstream f(path); | ||||||
|  | 	f.exceptions(std::ifstream::badbit); | ||||||
|  | 	if (!f) | ||||||
|  | 		throw std::runtime_error("cannot read configuration"); | ||||||
|  | 
 | ||||||
|  | 	std::regex re(R"(^\s*([^#=]+?)\s*=\s*([^#]*?)\s*(?:#|$))", | ||||||
|  | 		std::regex::optimize); | ||||||
|  | 	std::smatch m; | ||||||
|  | 
 | ||||||
|  | 	std::string line; | ||||||
|  | 	while (std::getline(f, line)) { | ||||||
|  | 		if (std::regex_match(line, m, re)) | ||||||
|  | 			read_field(config, m[1].str(), m[2].str()); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	read_tags( | ||||||
|  | 		std::filesystem::path(path).replace_extension("tags"), config.tags); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // --- Data preparation --------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | static float * | ||||||
|  | image_to_nhwc(float *data, Magick::Image &image, Config::Channels channels) | ||||||
|  | { | ||||||
|  | 	unsigned int width = image.columns(); | ||||||
|  | 	unsigned int height = image.rows(); | ||||||
|  | 
 | ||||||
|  | 	auto pixels = image.getConstPixels(0, 0, width, height); | ||||||
|  | 	switch (channels) { | ||||||
|  | 	case Config::Channels::RGB: | ||||||
|  | 		for (unsigned int y = 0; y < height; y++) { | ||||||
|  | 			for (unsigned int x = 0; x < width; x++) { | ||||||
|  | 				auto pixel = *pixels++; | ||||||
|  | 				*data++ = ScaleQuantumToChar(pixel.red); | ||||||
|  | 				*data++ = ScaleQuantumToChar(pixel.green); | ||||||
|  | 				*data++ = ScaleQuantumToChar(pixel.blue); | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		break; | ||||||
|  | 	case Config::Channels::BGR: | ||||||
|  | 		for (unsigned int y = 0; y < height; y++) { | ||||||
|  | 			for (unsigned int x = 0; x < width; x++) { | ||||||
|  | 				auto pixel = *pixels++; | ||||||
|  | 				*data++ = ScaleQuantumToChar(pixel.blue); | ||||||
|  | 				*data++ = ScaleQuantumToChar(pixel.green); | ||||||
|  | 				*data++ = ScaleQuantumToChar(pixel.red); | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return data; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static float * | ||||||
|  | image_to_nchw(float *data, Magick::Image &image, Config::Channels channels) | ||||||
|  | { | ||||||
|  | 	unsigned int width = image.columns(); | ||||||
|  | 	unsigned int height = image.rows(); | ||||||
|  | 
 | ||||||
|  | 	auto pixels = image.getConstPixels(0, 0, width, height), pp = pixels; | ||||||
|  | 	switch (channels) { | ||||||
|  | 	case Config::Channels::RGB: | ||||||
|  | 		for (unsigned int y = 0; y < height; y++) | ||||||
|  | 			for (unsigned int x = 0; x < width; x++) | ||||||
|  | 				*data++ = ScaleQuantumToChar((*pp++).red); | ||||||
|  | 		pp = pixels; | ||||||
|  | 		for (unsigned int y = 0; y < height; y++) | ||||||
|  | 			for (unsigned int x = 0; x < width; x++) | ||||||
|  | 				*data++ = ScaleQuantumToChar((*pp++).green); | ||||||
|  | 		pp = pixels; | ||||||
|  | 		for (unsigned int y = 0; y < height; y++) | ||||||
|  | 			for (unsigned int x = 0; x < width; x++) | ||||||
|  | 				*data++ = ScaleQuantumToChar((*pp++).blue); | ||||||
|  | 		break; | ||||||
|  | 	case Config::Channels::BGR: | ||||||
|  | 		for (unsigned int y = 0; y < height; y++) | ||||||
|  | 			for (unsigned int x = 0; x < width; x++) | ||||||
|  | 				*data++ = ScaleQuantumToChar((*pp++).blue); | ||||||
|  | 		pp = pixels; | ||||||
|  | 		for (unsigned int y = 0; y < height; y++) | ||||||
|  | 			for (unsigned int x = 0; x < width; x++) | ||||||
|  | 				*data++ = ScaleQuantumToChar((*pp++).green); | ||||||
|  | 		pp = pixels; | ||||||
|  | 		for (unsigned int y = 0; y < height; y++) | ||||||
|  | 			for (unsigned int x = 0; x < width; x++) | ||||||
|  | 				*data++ = ScaleQuantumToChar((*pp++).red); | ||||||
|  | 	} | ||||||
|  | 	return data; | ||||||
|  | } | ||||||
|  | 
 | ||||||
// Read an image from the given filename, and fit it into a width × height
// image according to config.pad, placed over a white background.
//
// GraphicsMagick warnings raised while reading are only printed in debug mode,
// and otherwise ignored; any other exceptions propagate to the caller.
// The returned image has its file name set to the filename argument.
static Magick::Image
load(const std::string filename,
	const Config &config, int64_t width, int64_t height)
{
	Magick::Image image;
	try {
		image.read(filename);
	} catch (const Magick::Warning &warning) {
		// Warnings are not fatal, continue with whatever has been decoded.
		if (g.debug)
			fprintf(stderr, "%s: %s\n", filename.c_str(), warning.what());
	}

	image.autoOrient();

	Magick::Geometry adjusted(width, height);
	switch (config.pad) {
	case Config::Pad::EDGE:
	case Config::Pad::WHITE:
		// Presumably the '>' geometry flag: only shrink larger images.
		adjusted.greater(true);
		break;
	case Config::Pad::STRETCH:
		// NOTE(review): aspect(false) appears to leave the '!' flag unset,
		// which would keep the aspect ratio — confirm this is intended.
		adjusted.aspect(false);
	}

	image.resize(adjusted, Magick::LanczosFilter);

	// The GraphicsMagick API doesn't offer any good options.
	if (config.pad == Config::Pad::EDGE) {
		MagickLib::SetImageVirtualPixelMethod(
			image.image(), MagickLib::EdgeVirtualPixelMethod);

		// Request a width × height window centred on the resized image;
		// the offsets are zero or negative when the image is smaller,
		// leaving the area outside of it to the virtual pixel method.
		auto x = (int64_t(image.columns()) - width) / 2;
		auto y = (int64_t(image.rows()) - height) / 2;
		auto source = image.getConstPixels(x, y, width, height);
		// Take a copy of the window before constructing another image.
		std::vector<MagickLib::PixelPacket>
			pixels(source, source + width * height);

		Magick::Image edged(Magick::Geometry(width, height), "black");
		edged.classType(Magick::DirectClass);
		auto target = edged.setPixels(0, 0, width, height);
		memcpy(target, pixels.data(), pixels.size() * sizeof pixels[0]);
		edged.syncPixels();

		image = edged;
	}

	// Center it in a square patch of white, removing any transparency.
	// image.extent() could probably be used to do the same thing.
	Magick::Image white(Magick::Geometry(width, height), "white");
	// These differences are computed in unsigned arithmetic,
	// so the image is assumed not to exceed the patch — TODO confirm.
	auto x = (white.columns() - image.columns()) / 2;
	auto y = (white.rows() - image.rows()) / 2;
	white.composite(image, x, y, Magick::OverCompositeOp);
	white.fileName(filename);

	// Very high debug levels show each prepared image on screen.
	if (g.debug > 2)
		white.display();

	return white;
}
|  | 
 | ||||||
|  | // --- Inference ---------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | static void | ||||||
|  | run(std::vector<Magick::Image> &images, const Config &config, | ||||||
|  | 	Ort::Session &session, std::vector<int64_t> shape) | ||||||
|  | { | ||||||
|  | 	// For consistency, this value may be bumped to always be g.batch,
 | ||||||
|  | 	// but it does not seem to have an effect on anything.
 | ||||||
|  | 	shape[0] = images.size(); | ||||||
|  | 
 | ||||||
|  | 	Ort::AllocatorWithDefaultOptions allocator; | ||||||
|  | 	auto tensor = Ort::Value::CreateTensor<float>( | ||||||
|  | 		allocator, shape.data(), shape.size()); | ||||||
|  | 
 | ||||||
|  | 	auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount(); | ||||||
|  | 	auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data; | ||||||
|  | 	for (int64_t i = 0; i < images.size(); i++) { | ||||||
|  | 		switch (config.shape) { | ||||||
|  | 		case Config::Shape::NCHW: | ||||||
|  | 			pi = image_to_nchw(pi, images.at(i), config.channels); | ||||||
|  | 			break; | ||||||
|  | 		case Config::Shape::NHWC: | ||||||
|  | 			pi = image_to_nhwc(pi, images.at(i), config.channels); | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if (config.normalize) { | ||||||
|  | 		pi = input_data; | ||||||
|  | 		for (size_t i = 0; i < input_len; i++) | ||||||
|  | 			*pi++ /= 255.0; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	std::string input_name = | ||||||
|  | 		session.GetInputNameAllocated(0, allocator).get(); | ||||||
|  | 	std::string output_name = | ||||||
|  | 		session.GetOutputNameAllocated(0, allocator).get(); | ||||||
|  | 
 | ||||||
|  | 	std::vector<const char *> input_names = {input_name.c_str()}; | ||||||
|  | 	std::vector<const char *> output_names = {output_name.c_str()}; | ||||||
|  | 
 | ||||||
|  | 	auto outputs = session.Run(Ort::RunOptions{}, | ||||||
|  | 		input_names.data(), &tensor, input_names.size(), | ||||||
|  | 		output_names.data(), output_names.size()); | ||||||
|  | 	if (outputs.size() != 1 || !outputs[0].IsTensor()) { | ||||||
|  | 		fprintf(stderr, "Wrong output\n"); | ||||||
|  | 		return; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); | ||||||
|  | 	auto output_data = outputs.front().GetTensorData<float>(), po = output_data; | ||||||
|  | 	if (output_len != shape[0] * config.tags.size()) { | ||||||
|  | 		fprintf(stderr, "Tags don't match the output\n"); | ||||||
|  | 		return; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for (size_t i = 0; i < images.size(); i++) { | ||||||
|  | 		for (size_t t = 0; t < config.tags.size(); t++) { | ||||||
|  | 			float value = *po++; | ||||||
|  | 			if (config.sigmoid) | ||||||
|  | 				value = 1 / (1 + std::exp(-value)); | ||||||
|  | 			if (value > g.threshold) { | ||||||
|  | 				printf("%s\t%.2f\t%s\n", images.at(i).fileName().c_str(), | ||||||
|  | 					value, config.tags.at(t).c_str()); | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
 | ||||||
|  | 
 | ||||||
|  | static void | ||||||
|  | parse_options(const std::string &options) | ||||||
|  | { | ||||||
|  | 	auto semicolon = options.find(";"); | ||||||
|  | 	auto name = options.substr(0, semicolon); | ||||||
|  | 	auto sequence = options.substr(semicolon); | ||||||
|  | 
 | ||||||
|  | 	std::map<std::string, std::string> kv; | ||||||
|  | 	std::regex re(R"(;*([^;=]+)=([^;=]+))", std::regex::optimize); | ||||||
|  | 	std::sregex_iterator it(sequence.begin(), sequence.end(), re), end; | ||||||
|  | 	for (; it != end; ++it) | ||||||
|  | 		kv[it->str(1)] = it->str(2); | ||||||
|  | 	g.options.insert_or_assign(name, std::move(kv)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static std::tuple<std::vector<const char *>, std::vector<const char *>> | ||||||
|  | unpack_options(const std::string &provider) | ||||||
|  | { | ||||||
|  | 	std::vector<const char *> keys, values; | ||||||
|  | 	if (g.options.count(provider)) { | ||||||
|  | 		for (const auto &kv : g.options.at(provider)) { | ||||||
|  | 			keys.push_back(kv.first.c_str()); | ||||||
|  | 			values.push_back(kv.second.c_str()); | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return {keys, values}; | ||||||
|  | } | ||||||
|  | 
 | ||||||
// Register accelerated execution providers with the session options.
//
// Each provider reported as available is configured with the options stored
// for it in g.options, if any, and then appended.  Failure to append
// a provider is reported on the standard error stream and otherwise ignored,
// whereas failure to create or update provider options throws.
// Nothing is registered when g.cpu is set.
static void
add_providers(Ort::SessionOptions &options)
{
	auto api = Ort::GetApi();
	auto v_providers = Ort::GetAvailableProviders();
	// A set makes for simple membership tests, and sorted debug output.
	std::set<std::string> providers(v_providers.begin(), v_providers.end());

	if (g.debug) {
		printf("Providers:");
		for (const auto &it : providers)
			printf(" %s", it.c_str());
		printf("\n");
	}

	// There is a string-based AppendExecutionProvider() method,
	// but it cannot be used with all providers.
	// TODO: Make it possible to disable providers.
	// TODO: Providers will deserve some performance tuning.

	// The list of providers above is still printed in this case.
	if (g.cpu)
		return;

#ifdef __APPLE__
	if (providers.count("CoreMLExecutionProvider")) {
		try {
			Ort::ThrowOnError(
				OrtSessionOptionsAppendExecutionProvider_CoreML(options, 0));
		} catch (const std::exception &e) {
			fprintf(stderr, "CoreML unavailable: %s\n", e.what());
		}
	}
#endif

#if TENSORRT
	// TensorRT should be the more performant execution provider, however:
	//  - it is difficult to set up (needs logging in to download),
	//  - with WD v1.4 ONNX models, one gets "Your ONNX model has been generated
	//    with INT64 weights, while TensorRT does not natively support INT64.
	//    Attempting to cast down to INT32." and that's not nice.
	if (providers.count("TensorrtExecutionProvider")) {
		OrtTensorRTProviderOptionsV2* tensorrt_options = nullptr;
		Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options));
		auto [keys, values] = unpack_options("TensorrtExecutionProvider");
		if (!keys.empty()) {
			Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(
				tensorrt_options, keys.data(), values.data(), keys.size()));
		}

		try {
			options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);
		} catch (const std::exception &e) {
			fprintf(stderr, "TensorRT unavailable: %s\n", e.what());
		}
		api.ReleaseTensorRTProviderOptions(tensorrt_options);
	}
#endif

	// See CUDA-ExecutionProvider.html for documentation.
	if (providers.count("CUDAExecutionProvider")) {
		OrtCUDAProviderOptionsV2* cuda_options = nullptr;
		Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));
		auto [keys, values] = unpack_options("CUDAExecutionProvider");
		if (!keys.empty()) {
			// NOTE(review): when this throws, cuda_options is not released.
			Ort::ThrowOnError(api.UpdateCUDAProviderOptions(
				cuda_options, keys.data(), values.data(), keys.size()));
		}

		try {
			options.AppendExecutionProvider_CUDA_V2(*cuda_options);
		} catch (const std::exception &e) {
			fprintf(stderr, "CUDA unavailable: %s\n", e.what());
		}
		api.ReleaseCUDAProviderOptions(cuda_options);
	}

	if (providers.count("ROCMExecutionProvider")) {
		// Unlike above, these options are a plain structure on the stack.
		OrtROCMProviderOptions rocm_options = {};
		auto [keys, values] = unpack_options("ROCMExecutionProvider");
		if (!keys.empty()) {
			Ort::ThrowOnError(api.UpdateROCMProviderOptions(
				&rocm_options, keys.data(), values.data(), keys.size()));
		}

		try {
			options.AppendExecutionProvider_ROCM(rocm_options);
		} catch (const std::exception &e) {
			fprintf(stderr, "ROCM unavailable: %s\n", e.what());
		}
	}

	// The CPU provider is the default fallback, if everything else fails.
}
|  | 
 | ||||||
|  | // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
 | ||||||
|  | 
 | ||||||
|  | struct Thumbnailing { | ||||||
|  | 	std::mutex input_mutex; | ||||||
|  | 	std::condition_variable input_cv; | ||||||
|  | 	std::queue<std::string> input;      // All input paths
 | ||||||
|  | 	int work = 0;                       // Number of images requested
 | ||||||
|  | 
 | ||||||
|  | 	std::mutex output_mutex; | ||||||
|  | 	std::condition_variable output_cv; | ||||||
|  | 	std::vector<Magick::Image> output;  // Processed images
 | ||||||
|  | 	int done = 0;                       // Finished worker threads
 | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | static void | ||||||
|  | thumbnail(const Config &config, int64_t width, int64_t height, | ||||||
|  | 	Thumbnailing &ctx) | ||||||
|  | { | ||||||
|  | 	while (true) { | ||||||
|  | 		std::unique_lock<std::mutex> input_lock(ctx.input_mutex); | ||||||
|  | 		ctx.input_cv.wait(input_lock, | ||||||
|  | 			[&]{ return ctx.input.empty() || ctx.work; }); | ||||||
|  | 		if (ctx.input.empty()) | ||||||
|  | 			break; | ||||||
|  | 
 | ||||||
|  | 		auto path = ctx.input.front(); | ||||||
|  | 		ctx.input.pop(); | ||||||
|  | 		ctx.work--; | ||||||
|  | 		input_lock.unlock(); | ||||||
|  | 
 | ||||||
|  | 		Magick::Image image; | ||||||
|  | 		try { | ||||||
|  | 			image = load(path, config, width, height); | ||||||
|  | 			if (height != image.rows() || width != image.columns()) | ||||||
|  | 				throw std::runtime_error("tensor mismatch"); | ||||||
|  | 
 | ||||||
|  | 			std::unique_lock<std::mutex> output_lock(ctx.output_mutex); | ||||||
|  | 			ctx.output.push_back(image); | ||||||
|  | 			output_lock.unlock(); | ||||||
|  | 			ctx.output_cv.notify_all(); | ||||||
|  | 		} catch (const std::exception &e) { | ||||||
|  | 			fprintf(stderr, "%s: %s\n", path.c_str(), e.what()); | ||||||
|  | 
 | ||||||
|  | 			std::unique_lock<std::mutex> input_lock(ctx.input_mutex); | ||||||
|  | 			ctx.work++; | ||||||
|  | 			input_lock.unlock(); | ||||||
|  | 			ctx.input_cv.notify_all(); | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	std::unique_lock<std::mutex> output_lock(ctx.output_mutex); | ||||||
|  | 	ctx.done++; | ||||||
|  | 	output_lock.unlock(); | ||||||
|  | 	ctx.output_cv.notify_all(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
 | ||||||
|  | 
 | ||||||
|  | static std::string | ||||||
|  | print_shape(const Ort::ConstTensorTypeAndShapeInfo &info) | ||||||
|  | { | ||||||
|  | 	std::vector<const char *> names(info.GetDimensionsCount()); | ||||||
|  | 	info.GetSymbolicDimensions(names.data(), names.size()); | ||||||
|  | 
 | ||||||
|  | 	auto shape = info.GetShape(); | ||||||
|  | 	std::string result; | ||||||
|  | 	for (size_t i = 0; i < shape.size(); i++) { | ||||||
|  | 		if (shape[i] < 0) | ||||||
|  | 			result.append(names.at(i)); | ||||||
|  | 		else | ||||||
|  | 			result.append(std::to_string(shape[i])); | ||||||
|  | 		result.append(" x "); | ||||||
|  | 	} | ||||||
|  | 	if (!result.empty()) | ||||||
|  | 		result.erase(result.size() - 3); | ||||||
|  | 	return result; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static void | ||||||
|  | print_shapes(const Ort::Session &session) | ||||||
|  | { | ||||||
|  | 	Ort::AllocatorWithDefaultOptions allocator; | ||||||
|  | 	for (size_t i = 0; i < session.GetInputCount(); i++) { | ||||||
|  | 		std::string name = session.GetInputNameAllocated(i, allocator).get(); | ||||||
|  | 		auto info = session.GetInputTypeInfo(i); | ||||||
|  | 		auto shape = print_shape(info.GetTensorTypeAndShapeInfo()); | ||||||
|  | 		printf("Input: %s: %s\n", name.c_str(), shape.c_str()); | ||||||
|  | 	} | ||||||
|  | 	for (size_t i = 0; i < session.GetOutputCount(); i++) { | ||||||
|  | 		std::string name = session.GetOutputNameAllocated(i, allocator).get(); | ||||||
|  | 		auto info = session.GetOutputTypeInfo(i); | ||||||
|  | 		auto shape = print_shape(info.GetTensorTypeAndShapeInfo()); | ||||||
|  | 		printf("Output: %s: %s\n", name.c_str(), shape.c_str()); | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static void | ||||||
|  | infer(Ort::Env &env, const char *path, const std::vector<std::string> &images) | ||||||
|  | { | ||||||
|  | 	Config config; | ||||||
|  | 	read_config(config, path); | ||||||
|  | 
 | ||||||
|  | 	Ort::SessionOptions session_options; | ||||||
|  | 	add_providers(session_options); | ||||||
|  | 
 | ||||||
|  | 	Ort::Session session = Ort::Session(env, | ||||||
|  | 		std::filesystem::path(path).replace_extension("onnx").c_str(), | ||||||
|  | 		session_options); | ||||||
|  | 
 | ||||||
|  | 	if (g.debug) | ||||||
|  | 		print_shapes(session); | ||||||
|  | 
 | ||||||
|  | 	if (session.GetInputCount() != 1 || session.GetOutputCount() != 1) { | ||||||
|  | 		fprintf(stderr, "Invalid input or output shape\n"); | ||||||
|  | 		exit(EXIT_FAILURE); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	auto input_info = session.GetInputTypeInfo(0); | ||||||
|  | 	auto shape = input_info.GetTensorTypeAndShapeInfo().GetShape(); | ||||||
|  | 	if (shape.size() != 4) { | ||||||
|  | 		fprintf(stderr, "Incompatible input tensor format\n"); | ||||||
|  | 		exit(EXIT_FAILURE); | ||||||
|  | 	} | ||||||
|  | 	if (shape.at(0) > 1) { | ||||||
|  | 		fprintf(stderr, "Fixed batching not supported\n"); | ||||||
|  | 		exit(EXIT_FAILURE); | ||||||
|  | 	} | ||||||
|  | 	if (shape.at(0) >= 0 && g.batch > 1) { | ||||||
|  | 		fprintf(stderr, "Requested batching for a non-batching model\n"); | ||||||
|  | 		exit(EXIT_FAILURE); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	int64_t *height = {}, *width = {}, *channels = {}; | ||||||
|  | 	switch (config.shape) { | ||||||
|  | 	case Config::Shape::NCHW: | ||||||
|  | 		channels = &shape[1]; | ||||||
|  | 		height = &shape[2]; | ||||||
|  | 		width = &shape[3]; | ||||||
|  | 		break; | ||||||
|  | 	case Config::Shape::NHWC: | ||||||
|  | 		height = &shape[1]; | ||||||
|  | 		width = &shape[2]; | ||||||
|  | 		channels = &shape[3]; | ||||||
|  | 		break; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Variable dimensions don't combine well with batches.
 | ||||||
|  | 	if (*height < 0) | ||||||
|  | 		*height = config.size; | ||||||
|  | 	if (*width < 0) | ||||||
|  | 		*width = config.size; | ||||||
|  | 	if (*channels != 3 || *height < 1 || *width < 1) { | ||||||
|  | 		fprintf(stderr, "Incompatible input tensor format\n"); | ||||||
|  | 		return; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// By only parallelizing image loads here during batching,
 | ||||||
|  | 	// they never compete for CPU time with inference.
 | ||||||
|  | 	Thumbnailing ctx; | ||||||
|  | 	for (const auto &path : images) | ||||||
|  | 		ctx.input.push(path); | ||||||
|  | 
 | ||||||
|  | 	auto workers = g.batch; | ||||||
|  | 	if (auto threads = std::thread::hardware_concurrency()) | ||||||
|  | 		workers = std::min(workers, long(threads)); | ||||||
|  | 	for (auto i = workers; i--; ) | ||||||
|  | 		std::thread(thumbnail, std::ref(config), *width, *height, | ||||||
|  | 			std::ref(ctx)).detach(); | ||||||
|  | 
 | ||||||
|  | 	while (true) { | ||||||
|  | 		std::unique_lock<std::mutex> input_lock(ctx.input_mutex); | ||||||
|  | 		ctx.work = g.batch; | ||||||
|  | 		input_lock.unlock(); | ||||||
|  | 		ctx.input_cv.notify_all(); | ||||||
|  | 
 | ||||||
|  | 		std::unique_lock<std::mutex> output_lock(ctx.output_mutex); | ||||||
|  | 		ctx.output_cv.wait(output_lock, | ||||||
|  | 			[&]{ return ctx.output.size() == g.batch || ctx.done == workers; }); | ||||||
|  | 
 | ||||||
|  | 		if (!ctx.output.empty()) { | ||||||
|  | 			run(ctx.output, config, session, shape); | ||||||
|  | 			ctx.output.clear(); | ||||||
|  | 		} | ||||||
|  | 		if (ctx.done == workers) | ||||||
|  | 			break; | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | int | ||||||
|  | main(int argc, char *argv[]) | ||||||
|  | { | ||||||
|  | 	auto invocation_name = argv[0]; | ||||||
|  | 	auto print_usage = [=] { | ||||||
|  | 		fprintf(stderr, | ||||||
|  | 			"Usage: %s [-b BATCH] [--cpu] [-d] [-o EP;KEY=VALUE...] " | ||||||
|  | 			"[-t THRESHOLD] MODEL { --pipe | [IMAGE...] }\n", invocation_name); | ||||||
|  | 	}; | ||||||
|  | 
 | ||||||
|  | 	static option opts[] = { | ||||||
|  | 		{"batch", required_argument, 0, 'b'}, | ||||||
|  | 		{"cpu", no_argument, 0, 'c'}, | ||||||
|  | 		{"debug", no_argument, 0, 'd'}, | ||||||
|  | 		{"help", no_argument, 0, 'h'}, | ||||||
|  | 		{"options", required_argument, 0, 'o'}, | ||||||
|  | 		{"pipe", no_argument, 0, 'p'}, | ||||||
|  | 		{"threshold", required_argument, 0, 't'}, | ||||||
|  | 		{nullptr, 0, 0, 0}, | ||||||
|  | 	}; | ||||||
|  | 
 | ||||||
|  | 	bool pipe = false; | ||||||
|  | 	while (1) { | ||||||
|  | 		int option_index = 0; | ||||||
|  | 		auto c = getopt_long(argc, const_cast<char *const *>(argv), | ||||||
|  | 			"b:cdho:pt:", opts, &option_index); | ||||||
|  | 		if (c == -1) | ||||||
|  | 			break; | ||||||
|  | 
 | ||||||
|  | 		char *end = nullptr; | ||||||
|  | 		switch (c) { | ||||||
|  | 		case 'b': | ||||||
|  | 			errno = 0, g.batch = strtol(optarg, &end, 10); | ||||||
|  | 			if (errno || *end || g.batch < 1 || g.batch > SHRT_MAX) { | ||||||
|  | 				fprintf(stderr, "Batch size must be a positive number\n"); | ||||||
|  | 				exit(EXIT_FAILURE); | ||||||
|  | 			} | ||||||
|  | 			break; | ||||||
|  | 		case 'c': | ||||||
|  | 			g.cpu = true; | ||||||
|  | 			break; | ||||||
|  | 		case 'd': | ||||||
|  | 			g.debug++; | ||||||
|  | 			break; | ||||||
|  | 		case 'h': | ||||||
|  | 			print_usage(); | ||||||
|  | 			return 0; | ||||||
|  | 		case 'o': | ||||||
|  | 			parse_options(optarg); | ||||||
|  | 			break; | ||||||
|  | 		case 'p': | ||||||
|  | 			pipe = true; | ||||||
|  | 			break; | ||||||
|  | 		case 't': | ||||||
|  | 			errno = 0, g.threshold = strtod(optarg, &end); | ||||||
|  | 			if (errno || *end || !std::isfinite(g.threshold) || | ||||||
|  | 				g.threshold < 0 || g.threshold > 1) { | ||||||
|  | 				fprintf(stderr, "Threshold must be a number within 0..1\n"); | ||||||
|  | 				exit(EXIT_FAILURE); | ||||||
|  | 			} | ||||||
|  | 			break; | ||||||
|  | 		default: | ||||||
|  | 			print_usage(); | ||||||
|  | 			return 1; | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	argv += optind; | ||||||
|  | 	argc -= optind; | ||||||
|  | 
 | ||||||
|  | 	// TODO: There's actually no need to slurp all the lines up front.
 | ||||||
|  | 	std::vector<std::string> paths; | ||||||
|  | 	if (pipe) { | ||||||
|  | 		if (argc != 1) { | ||||||
|  | 			print_usage(); | ||||||
|  | 			return 1; | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		std::string line; | ||||||
|  | 		while (std::getline(std::cin, line)) | ||||||
|  | 			paths.push_back(line); | ||||||
|  | 	} else { | ||||||
|  | 		if (argc < 1) { | ||||||
|  | 			print_usage(); | ||||||
|  | 			return 1; | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		paths.assign(argv + 1, argv + argc); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Load batched images in parallel (the first is for GM, the other for IM).
 | ||||||
|  | 	if (g.batch > 1) { | ||||||
|  | 		auto value = std::to_string( | ||||||
|  | 			std::max(std::thread::hardware_concurrency() / g.batch, 1L)); | ||||||
|  | 		setenv("OMP_NUM_THREADS", value.c_str(), true); | ||||||
|  | 		setenv("MAGICK_THREAD_LIMIT", value.c_str(), true); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// XXX: GraphicsMagick initializes signal handlers here,
 | ||||||
|  | 	// one needs to use MagickLib::InitializeMagickEx()
 | ||||||
|  | 	// with MAGICK_OPT_NO_SIGNAL_HANDER to prevent that.
 | ||||||
|  | 	//
 | ||||||
|  | 	// ImageMagick conveniently has the opposite default.
 | ||||||
|  | 	Magick::InitializeMagick(nullptr); | ||||||
|  | 
 | ||||||
|  | 	OrtLoggingLevel logging = g.debug > 1 | ||||||
|  | 		? ORT_LOGGING_LEVEL_VERBOSE | ||||||
|  | 		: ORT_LOGGING_LEVEL_WARNING; | ||||||
|  | 
 | ||||||
|  | 	// Creating an environment before initializing providers in order to avoid:
 | ||||||
|  | 	// "Attempt to use DefaultLogger but none has been registered."
 | ||||||
|  | 	Ort::Env env(logging, invocation_name); | ||||||
|  | 	infer(env, argv[0], paths); | ||||||
|  | 	return 0; | ||||||
|  | } | ||||||
							
								
								
									
										161
									
								
								deeptagger/download.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										161
									
								
								deeptagger/download.sh
									
									
									
									
									
										Executable file
									
								
							| @ -0,0 +1,161 @@ | |||||||
#!/bin/sh -e
# Requirements: Python ~ 3.11, curl, unzip, git-lfs, awk
#
# This script downloads a bunch of models into the models/ directory,
# after any necessary transformations to run them using the deeptagger binary.
#
# Once it succeeds, feel free to remove everything but *.{model,tags,onnx}
git lfs install
mkdir -p models
cd models

# Create a virtual environment for model conversion.
#
# If any of the Python stuff fails,
# retry from within a Conda environment with a different version of Python.
#
# The assignment is kept separate from the export, and quoted:
# older versions of dash field-split unquoted arguments to export,
# which would break in working directories with spaces in their name.
VIRTUAL_ENV="$(pwd)/venv"
export VIRTUAL_ENV
export TF_ENABLE_ONEDNN_OPTS=0

# The "ready" marker makes sure that packages are only installed once.
if ! [ -f "$VIRTUAL_ENV/ready" ]
then
	python3 -m venv "$VIRTUAL_ENV"
	#"$VIRTUAL_ENV/bin/pip3" install tensorflow[and-cuda]
	"$VIRTUAL_ENV/bin/pip3" install tf2onnx 'deepdanbooru[tensorflow]'
	touch "$VIRTUAL_ENV/ready"
fi
|  | 
 | ||||||
# Print all arguments as a single progress message in bold, prefixed by "-- ".
status() {
	local message="-- $*"
	echo "$(tput bold)${message}$(tput sgr0)"
}
|  | 
 | ||||||
# Using the deepdanbooru package makes it possible to use other models
# trained with the project.
#
# Usage: deepdanbooru NAME URL
#  - NAME is the human-readable model name, stored in the .model file,
#  - URL points to a ZIP archive with a DeepDanbooru project.
#
# Each step is skipped when its result already exists.
deepdanbooru() {
	# Quoted, because older versions of dash field-split arguments to local.
	local name="$1" url="$2"
	status "$name"

	local basename="$(basename "$url")"
	if ! [ -e "$basename" ]
	then curl -LO "$url"
	fi

	local modelname="${basename%%.*}"
	if ! [ -d "$modelname" ]
	then unzip -d "$modelname" "$basename"
	fi

	if ! [ -e "$modelname.tags" ]
	then ln "$modelname/tags.txt" "$modelname.tags"
	fi

	# Export the project as a TensorFlow SavedModel for tf2onnx to consume.
	if ! [ -d "$modelname.saved" ]
	then "$VIRTUAL_ENV/bin/python3" - "$modelname" "$modelname.saved" <<-'END'
		import sys
		import deepdanbooru.project as ddp
		model = ddp.load_model_from_project(
			project_path=sys.argv[1], compile_model=False)
		model.export(sys.argv[2])
	END
	fi

	if ! [ -e "$modelname.onnx" ]
	then "$VIRTUAL_ENV/bin/python3" -m tf2onnx.convert \
		--saved-model "$modelname.saved" --output "$modelname.onnx"
	fi

	cat > "$modelname.model" <<-END
		name=$name
		shape=nhwc
		channels=rgb
		normalize=true
		pad=edge
	END
}
|  | 
 | ||||||
# ONNX preconversions don't have a symbolic first dimension, thus doing our own.
#
# Usage: wd14 NAME REPOSITORY
#  - NAME is the human-readable model name, stored in the .model file,
#  - REPOSITORY is a Hugging Face repository in the USER/PROJECT form.
#
# This produces both MODEL.* files using the upstream ONNX export,
# and batch-MODEL.* files using our own conversion.
wd14() {
	# Quoted, because older versions of dash field-split arguments to local,
	# and model names passed to this function do contain spaces.
	local name="$1" repository="$2"
	status "$name"

	local modelname="$(basename "$repository")"
	if ! [ -d "$modelname" ]
	then git clone "https://huggingface.co/$repository"
	fi

	# Though link the original export as well.
	if ! [ -e "$modelname.onnx" ]
	then ln "$modelname/model.onnx" "$modelname.onnx"
	fi

	# Skip the CSV header, and only keep the second column.
	if ! [ -e "$modelname.tags" ]
	then awk -F, 'NR > 1 { print $2 }' "$modelname/selected_tags.csv" \
		> "$modelname.tags"
	fi

	cat > "$modelname.model" <<-END
		name=$name
		shape=nhwc
		channels=bgr
		normalize=false
		pad=white
	END

	if ! [ -e "batch-$modelname.onnx" ]
	then "$VIRTUAL_ENV/bin/python3" -m tf2onnx.convert \
		--saved-model "$modelname" --output "batch-$modelname.onnx"
	fi

	if ! [ -e "batch-$modelname.tags" ]
	then ln "$modelname.tags" "batch-$modelname.tags"
	fi

	if ! [ -e "batch-$modelname.model" ]
	then ln "$modelname.model" "batch-$modelname.model"
	fi
}
|  | 
 | ||||||
# These models are an undocumented mess, thus using ONNX preconversions.
#
# Usage: mldanbooru NAME BASENAME
#  - NAME is the human-readable model name, stored in the .model file,
#  - BASENAME is the filename of an ONNX model within the cloned
#    ml-danbooru-onnx repository.
mldanbooru() {
	# Quoted, because older versions of dash field-split arguments to local,
	# and model names passed to this function do contain spaces.
	local name="$1" basename="$2"
	status "$name"

	if ! [ -d ml-danbooru-onnx ]
	then git clone https://huggingface.co/deepghs/ml-danbooru-onnx
	fi

	# With a single operand, ln creates the link in the current directory.
	local modelname="${basename%%.*}"
	if ! [ -e "$basename" ]
	then ln "ml-danbooru-onnx/$basename"
	fi

	# Skip the CSV header, and only keep the first column.
	if ! [ -e "$modelname.tags" ]
	then awk -F, 'NR > 1 { print $1 }' ml-danbooru-onnx/tags.csv \
		> "$modelname.tags"
	fi

	cat > "$modelname.model" <<-END
		name=$name
		shape=nchw
		channels=rgb
		normalize=true
		pad=stretch
		size=640
		interpret=sigmoid
	END
}
|  | 
 | ||||||
# What follows is the list of models to fetch and convert.
# Entries that are commented out may be enabled as desired.
status "Downloading models, beware that git-lfs doesn't indicate progress"

deepdanbooru 'DeepDanbooru' 'https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip'

#wd14 'WD v1.4 ViT v1' 'SmilingWolf/wd-v1-4-vit-tagger'
wd14 'WD v1.4 ViT v2' 'SmilingWolf/wd-v1-4-vit-tagger-v2'
#wd14 'WD v1.4 ConvNeXT v1' 'SmilingWolf/wd-v1-4-convnext-tagger'
wd14 'WD v1.4 ConvNeXT v2' 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
wd14 'WD v1.4 ConvNeXTV2 v2' 'SmilingWolf/wd-v1-4-convnextv2-tagger-v2'
wd14 'WD v1.4 SwinV2 v2' 'SmilingWolf/wd-v1-4-swinv2-tagger-v2'
wd14 'WD v1.4 MOAT v2' 'SmilingWolf/wd-v1-4-moat-tagger-v2'

# As suggested by author https://github.com/IrisRainbowNeko/ML-Danbooru-webui
mldanbooru 'ML-Danbooru Caformer dec-5-97527' \
	'ml_caformer_m36_dec-5-97527.onnx'
mldanbooru 'ML-Danbooru TResNet-D 6-30000' \
	'TResnet-D-FLq_ema_6-30000.onnx'
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user