Make consistent batches a simple edit
This commit is contained in:
parent
819d2d80e0
commit
8df76dbaab
|
@ -255,7 +255,9 @@ static void
|
|||
run(std::vector<Magick::Image> &images, const Config &config,
|
||||
Ort::Session &session, std::vector<int64_t> shape)
|
||||
{
|
||||
auto batch = shape[0] = images.size();
|
||||
// For consistency, this value may be bumped to always be g.batch,
|
||||
// but it does not seem to have an effect on anything.
|
||||
shape[0] = images.size();
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto tensor = Ort::Value::CreateTensor<float>(
|
||||
|
@ -263,7 +265,7 @@ run(std::vector<Magick::Image> &images, const Config &config,
|
|||
|
||||
auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
|
||||
auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data;
|
||||
for (int64_t i = 0; i < batch; i++) {
|
||||
for (int64_t i = 0; i < images.size(); i++) {
|
||||
switch (config.shape) {
|
||||
case Config::Shape::NCHW:
|
||||
pi = image_to_nchw(pi, images.at(i), config.channels);
|
||||
|
@ -296,12 +298,12 @@ run(std::vector<Magick::Image> &images, const Config &config,
|
|||
|
||||
auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
|
||||
auto output_data = outputs.front().GetTensorData<float>(), po = output_data;
|
||||
if (output_len != batch * config.tags.size()) {
|
||||
if (output_len != shape[0] * config.tags.size()) {
|
||||
fprintf(stderr, "Tags don't match the output\n");
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < batch; i++) {
|
||||
for (size_t i = 0; i < images.size(); i++) {
|
||||
for (size_t t = 0; t < config.tags.size(); t++) {
|
||||
float value = *po++;
|
||||
if (config.sigmoid)
|
||||
|
@ -616,8 +618,6 @@ infer(Ort::Env &env, const char *path, const std::vector<std::string> &images)
|
|||
ctx.output_cv.wait(output_lock,
|
||||
[&]{ return ctx.output.size() == g.batch || ctx.done == workers; });
|
||||
|
||||
// It would be possible to add dummy entries to the batch,
|
||||
// so that the model doesn't need to be rebuilt.
|
||||
if (!ctx.output.empty()) {
|
||||
run(ctx.output, config, session, shape);
|
||||
ctx.output.clear();
|
||||
|
|
Loading…
Reference in New Issue