Make consistent batches a simple edit

parent 819d2d80e0
commit 8df76dbaab

@@ -255,7 +255,9 @@ static void
 run(std::vector<Magick::Image> &images, const Config &config,
     Ort::Session &session, std::vector<int64_t> shape)
 {
-	auto batch = shape[0] = images.size();
+	// For consistency, this value may be bumped to always be g.batch,
+	// but it does not seem to have an effect on anything.
+	shape[0] = images.size();
 
 	Ort::AllocatorWithDefaultOptions allocator;
 	auto tensor = Ort::Value::CreateTensor<float>(
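
The new comment above hints at forcing a constant batch dimension. A minimal sketch of that idea, not part of this commit, shown for the NCHW case only and reusing the names visible in run(); it assumes <algorithm> is available for std::fill:

	// Hypothetical: always allocate for g.batch images, copy the real
	// images in as run() already does, and zero-fill the unused tail so
	// the model keeps seeing the same batch size.
	shape[0] = g.batch;
	auto tensor = Ort::Value::CreateTensor<float>(
	    allocator, shape.data(), shape.size());
	auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
	auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data;
	for (size_t i = 0; i < images.size(); i++)
		pi = image_to_nchw(pi, images.at(i), config.channels);
	std::fill(pi, input_data + input_len, 0.f);  // padding slots
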
@@ -263,7 +265,7 @@ run(std::vector<Magick::Image> &images, const Config &config,
 
 	auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
 	auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data;
-	for (int64_t i = 0; i < batch; i++) {
+	for (int64_t i = 0; i < images.size(); i++) {
 		switch (config.shape) {
 		case Config::Shape::NCHW:
 			pi = image_to_nchw(pi, images.at(i), config.channels);
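
For orientation, NCHW here means a planar layout: all of channel 0's pixels, then channel 1's, and so on, per image. The project's image_to_nchw is not reproduced in this diff; the following is only a generic sketch of such a conversion, assuming three RGB channels and Magick++'s normalized float export, not the actual helper:

	// Generic NCHW packing sketch (NOT the project's image_to_nchw):
	// export interleaved float RGB, then transpose it into planes.
	// Assumes channels == 3, matching the "RGB" map.
	static float *
	sketch_to_nchw(float *p, Magick::Image &image, int64_t channels)
	{
		size_t w = image.columns(), h = image.rows();
		std::vector<float> hwc(w * h * channels);
		image.write(0, 0, w, h, "RGB", Magick::FloatPixel, hwc.data());
		for (int64_t c = 0; c < channels; c++)
			for (size_t i = 0; i < w * h; i++)
				*p++ = hwc[i * channels + c];
		return p;
	}
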
@@ -296,12 +298,12 @@ run(std::vector<Magick::Image> &images, const Config &config,
 
 	auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
 	auto output_data = outputs.front().GetTensorData<float>(), po = output_data;
-	if (output_len != batch * config.tags.size()) {
+	if (output_len != shape[0] * config.tags.size()) {
 		fprintf(stderr, "Tags don't match the output\n");
 		return;
 	}
 
-	for (size_t i = 0; i < batch; i++) {
+	for (size_t i = 0; i < images.size(); i++) {
 		for (size_t t = 0; t < config.tags.size(); t++) {
 			float value = *po++;
 			if (config.sigmoid)
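
The length check and the nested loops above treat the output buffer as a row-major matrix of shape[0] rows by config.tags.size() columns, one row of scores per image; an equivalent indexed read, for clarity only:

	// Same element the *po++ walk reaches for image i and tag t:
	float value = output_data[i * config.tags.size() + t];
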
@@ -616,8 +618,6 @@ infer(Ort::Env &env, const char *path, const std::vector<std::string> &images)
 		ctx.output_cv.wait(output_lock,
 			[&]{ return ctx.output.size() == g.batch || ctx.done == workers; });
 
-		// It would be possible to add dummy entries to the batch,
-		// so that the model doesn't need to be rebuilt.
 		if (!ctx.output.empty()) {
 			run(ctx.output, config, session, shape);
 			ctx.output.clear();
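
The comment removed above described an alternative: padding the final, smaller batch with dummy entries so the model would not need to be rebuilt for a different batch size. A minimal sketch of that idea, not part of this commit; it assumes results produced for the dummy images would simply be ignored:

		// Hypothetical padding of the final batch:
		if (!ctx.output.empty()) {
			while (ctx.output.size() < g.batch)
				ctx.output.push_back(ctx.output.back());  // dummy entries
			run(ctx.output, config, session, shape);
			ctx.output.clear();
		}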