From 9a511cd88a439ec5e958d3bffdaefe52751e18a7 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Thu, 22 May 2025 13:41:27 -0700 Subject: [PATCH] fix anti-pattern for cudagraph --- torchbenchmark/models/yolov3/yolo_models.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchbenchmark/models/yolov3/yolo_models.py b/torchbenchmark/models/yolov3/yolo_models.py index 667544a18a..f4e12a06ec 100755 --- a/torchbenchmark/models/yolov3/yolo_models.py +++ b/torchbenchmark/models/yolov3/yolo_models.py @@ -210,8 +210,7 @@ def forward(self, p, out): i, n = self.index, self.nl # index in layers, number of layers p = out[self.layers[i]] bs, _, ny, nx = p.shape # bs, 255, 13, 13 - if (self.nx, self.ny) != (nx, ny): - self.create_grids((nx, ny), p.device) + self.create_grids((nx, ny), p.device) # outputs and weights # w = F.softmax(p[:, -n:], 1) # normalized weights @@ -233,8 +232,7 @@ def forward(self, p, out): bs = 1 # batch size else: bs, _, ny, nx = p.shape # bs, 255, 13, 13 - if (self.nx, self.ny) != (nx, ny): - self.create_grids((nx, ny), p.device) + self.create_grids((nx, ny), p.device) # p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 85) # (bs, anchors, grid, grid, classes + xywh) p = (