Make consistent batches a simple edit

This commit is contained in:
Přemysl Eric Janouch 2024-01-18 09:38:46 +01:00
parent 819d2d80e0
commit 8df76dbaab
Signed by: p
GPG Key ID: A0420B94F92B9493
1 changed files with 6 additions and 6 deletions

View File

@ -255,7 +255,9 @@ static void
run(std::vector<Magick::Image> &images, const Config &config, run(std::vector<Magick::Image> &images, const Config &config,
Ort::Session &session, std::vector<int64_t> shape) Ort::Session &session, std::vector<int64_t> shape)
{ {
auto batch = shape[0] = images.size(); // For consistency, this value may be bumped to always be g.batch,
// but it does not seem to have an effect on anything.
shape[0] = images.size();
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
auto tensor = Ort::Value::CreateTensor<float>( auto tensor = Ort::Value::CreateTensor<float>(
@ -263,7 +265,7 @@ run(std::vector<Magick::Image> &images, const Config &config,
auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount(); auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data; auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data;
for (int64_t i = 0; i < batch; i++) { for (int64_t i = 0; i < images.size(); i++) {
switch (config.shape) { switch (config.shape) {
case Config::Shape::NCHW: case Config::Shape::NCHW:
pi = image_to_nchw(pi, images.at(i), config.channels); pi = image_to_nchw(pi, images.at(i), config.channels);
@ -296,12 +298,12 @@ run(std::vector<Magick::Image> &images, const Config &config,
auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
auto output_data = outputs.front().GetTensorData<float>(), po = output_data; auto output_data = outputs.front().GetTensorData<float>(), po = output_data;
if (output_len != batch * config.tags.size()) { if (output_len != shape[0] * config.tags.size()) {
fprintf(stderr, "Tags don't match the output\n"); fprintf(stderr, "Tags don't match the output\n");
return; return;
} }
for (size_t i = 0; i < batch; i++) { for (size_t i = 0; i < images.size(); i++) {
for (size_t t = 0; t < config.tags.size(); t++) { for (size_t t = 0; t < config.tags.size(); t++) {
float value = *po++; float value = *po++;
if (config.sigmoid) if (config.sigmoid)
@ -616,8 +618,6 @@ infer(Ort::Env &env, const char *path, const std::vector<std::string> &images)
ctx.output_cv.wait(output_lock, ctx.output_cv.wait(output_lock,
[&]{ return ctx.output.size() == g.batch || ctx.done == workers; }); [&]{ return ctx.output.size() == g.batch || ctx.done == workers; });
// It would be possible to add dummy entries to the batch,
// so that the model doesn't need to be rebuilt.
if (!ctx.output.empty()) { if (!ctx.output.empty()) {
run(ctx.output, config, session, shape); run(ctx.output, config, session, shape);
ctx.output.clear(); ctx.output.clear();