Make consistent batches a simple edit

This commit is contained in:
Přemysl Eric Janouch 2024-01-18 09:38:46 +01:00
parent 819d2d80e0
commit 8df76dbaab
Signed by: p
GPG Key ID: A0420B94F92B9493
1 changed files with 6 additions and 6 deletions

View File

@ -255,7 +255,9 @@ static void
run(std::vector<Magick::Image> &images, const Config &config,
Ort::Session &session, std::vector<int64_t> shape)
{
auto batch = shape[0] = images.size();
// For consistency, this value may be bumped to always be g.batch,
// but it does not seem to have an effect on anything.
shape[0] = images.size();
Ort::AllocatorWithDefaultOptions allocator;
auto tensor = Ort::Value::CreateTensor<float>(
@ -263,7 +265,7 @@ run(std::vector<Magick::Image> &images, const Config &config,
auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data;
for (int64_t i = 0; i < batch; i++) {
for (int64_t i = 0; i < images.size(); i++) {
switch (config.shape) {
case Config::Shape::NCHW:
pi = image_to_nchw(pi, images.at(i), config.channels);
@ -296,12 +298,12 @@ run(std::vector<Magick::Image> &images, const Config &config,
auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
auto output_data = outputs.front().GetTensorData<float>(), po = output_data;
if (output_len != batch * config.tags.size()) {
if (output_len != shape[0] * config.tags.size()) {
fprintf(stderr, "Tags don't match the output\n");
return;
}
for (size_t i = 0; i < batch; i++) {
for (size_t i = 0; i < images.size(); i++) {
for (size_t t = 0; t < config.tags.size(); t++) {
float value = *po++;
if (config.sigmoid)
@ -616,8 +618,6 @@ infer(Ort::Env &env, const char *path, const std::vector<std::string> &images)
ctx.output_cv.wait(output_lock,
[&]{ return ctx.output.size() == g.batch || ctx.done == workers; });
// It would be possible to add dummy entries to the batch,
// so that the model doesn't need to be rebuilt.
if (!ctx.output.empty()) {
run(ctx.output, config, session, shape);
ctx.output.clear();