Description
Triton feeds the model with batches whose size is neither a preferred_batch_size nor max_batch_size, even though max_queue_delay_microseconds is not exceeded.
I am trying to improve the performance of my TensorRT model, which was exported from PyTorch. As is usual for a TensorRT model, all (or some) of the optimal shapes are listed in the preferred_batch_size section. While investigating, I noticed that some model configurations led to slow inference times. Digging deeper, I discovered that under sufficiently heavy load the Triton Inference Server tends to feed the model batches of an unexpected size (max(preferred_batch_size) + 1, specifically), which leads to slower inference in the TensorRT case, since that size does not match any of the model's Optimization Profiles.
It does not seem to be backend related (I have reproduced it with both LibTorch and TensorRT). It is also not frontend related: I originally used Locust with HTTP REST requests and system shared memory, but reproduced it with molotov using the Python grpc.aio client without shared memory as well. Increasing or decreasing the number of model instances changes nothing.
When I create enough load to keep the internal Triton queue always above 16 but below 256 requests, the model begins to process batches of size 17. max_queue_delay_microseconds is set to an extremely high value to make sure this is not timeout behavior.
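For reference, the relevant part of the model config looks roughly like the fragment below. The exact preferred sizes and the delay value are illustrative assumptions; only a maximum preferred size of 16, a max_batch_size of 256, and a very large queue delay follow from the description above.

    # config.pbtxt fragment (illustrative values)
    max_batch_size: 256
    dynamic_batching {
      preferred_batch_size: [ 4, 8, 16 ]
      max_queue_delay_microseconds: 30000000   # deliberately huge to rule out timeouts
    }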
I used --log-verbose=2 --log-info=true --log-file=/logs/log.txt so that I can see all the batch sizes the model processes. The output of grep "executing" log.txt is below.
I always send requests of size 1, so 17 requests mean a batch size of 17. (Using requests of size 2 leads to the same behavior of grouping 17 requests anyway.)
I can see the same batch sizes using the rate(nv_inference_request_success[15s])/rate(nv_inference_exec_count[15s]) Prometheus expression.
If I stop using max_queue_delay_microseconds, the Triton Inference Server seems to work normally. Setting max_batch_size equal to max(preferred_batch_size) leads to correct behavior as well. Adding a high enough value (e.g. max_batch_size - 1) to the preferred_batch_size section also prevents Triton from using the unexpected batch size (since the previous maximum value is no longer the maximum).
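A sketch of that last workaround, again with illustrative sizes:

    # config.pbtxt fragment with the extra preferred size added (illustrative values)
    max_batch_size: 256
    dynamic_batching {
      preferred_batch_size: [ 4, 8, 16, 255 ]  # 255 = max_batch_size - 1
      max_queue_delay_microseconds: 30000000
    }
    # Now max(preferred_batch_size) + 1 == 256 == max_batch_size, so an
    # off-by-one batch would still be an allowed size.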
Triton Information
I use nvcr.io/nvidia/tritonserver:25.02-py3 but 24.02-py3 shows the same behavior.
To Reproduce
1. Use a model and a config that are enough to reproduce; I suppose any model should fit.
2. Create sustained load so that the queue holds fewer requests than max_batch_size but more than max(preferred_batch_size); a minimal load-generation sketch follows these steps.
3. Check the logs with --log-verbose=2 or the Prometheus metrics for the max(preferred_batch_size) + 1 value.
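A minimal load-generation sketch with the async gRPC client is below. The model name, input name, shape, and dtype are placeholders, and the actual tests used Locust and molotov rather than this exact script.

    import asyncio

    import numpy as np
    import tritonclient.grpc.aio as grpcclient
    from tritonclient.grpc import InferInput

    MODEL = "my_model"                       # placeholder model name
    INPUT_NAME = "INPUT__0"                  # placeholder input name
    SHAPE, DTYPE = [1, 3, 224, 224], "FP32"  # placeholder shape/dtype, batch dim of 1


    async def one_request(client):
        # Each request carries a single sample, so batch size equals the number
        # of requests the dynamic batcher groups together.
        inp = InferInput(INPUT_NAME, SHAPE, DTYPE)
        inp.set_data_from_numpy(np.random.rand(*SHAPE).astype(np.float32))
        await client.infer(MODEL, [inp])


    async def main(concurrency=64, waves=1000):
        # Fire waves of `concurrency` size-1 requests so the server-side queue
        # fills well past max(preferred_batch_size) while staying below
        # max_batch_size (16 < 64 < 256 for the config fragment above).
        client = grpcclient.InferenceServerClient(url="localhost:8001")
        try:
            for _ in range(waves):
                await asyncio.gather(*(one_request(client) for _ in range(concurrency)))
        finally:
            await client.close()


    if __name__ == "__main__":
        asyncio.run(main())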
Expected behavior
Tritonserver does not process batches of a size that is neither listed in preferred_batch_size nor equal to max_batch_size unless the max_queue_delay_microseconds time is exceeded.