
Commit 0932494

Merge branch 'trintamaki/fix-packing-test' into 'main'

Fix packed sequence unit test

See merge request ADLR/megatron-lm!2548

2 parents e5793c0 + 66a3a00

1 file changed (+11 −5 lines)

tests/unit_tests/models/test_llava_model.py

@@ -320,21 +320,27 @@ def test_forward(self):
         # Try with labels and PackedSeqParams. Only micro batch size 1 is supported in this mode.
         packed_seq_params = PackedSeqParams(
             qkv_format="thd",
-            cu_seqlens_q=[0, 512, 1024, 1600],  # Just example values.
-            cu_seqlens_kv=[0, 512, 1024, 1600],
-            max_seqlen_q=[1600],
-            max_seqlen_kv=[1600],
+            cu_seqlens_q=torch.tensor(
+                [0, 512, 1024, 1600], dtype=torch.int32
+            ).cuda(),  # Just example values.
+            cu_seqlens_kv=torch.tensor([0, 512, 1024, 1600], dtype=torch.int32).cuda(),
+            max_seqlen_q=torch.tensor(1600, dtype=torch.int32).cuda(),
+            max_seqlen_kv=torch.tensor(1600, dtype=torch.int32).cuda(),
         )
 
+        # NOTE: Packing is only supported with BF16. Use BF16 here and switch back to default.
+        self.model.to(torch.bfloat16)
         loss, new_loss_mask = self.model.forward(
-            img[:1],
+            img[:1].to(torch.bfloat16),
             input_ids[:1],
             position_ids[:1],
             attention_mask,
             labels[:1],
             loss_mask[:1],
             num_image_tiles=num_image_tiles[:1],
+            packed_seq_params=packed_seq_params,
         )
+        self.model.to(torch.float32)
 
         # 1600 = 577 (img_seq_len) + 1024 (text tokens in the first sample) - 1 (image token).
         assert loss.shape == new_loss_mask.shape == torch.Size((1, 1600))
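
For readers skimming the change: the fix replaces plain Python lists in PackedSeqParams with int32 CUDA tensors and wraps the packed forward pass in BF16 casts. Below is a minimal standalone sketch of the fixed construction; the megatron.core.packed_seq_params import path and the rationale in the comments are assumptions based on how the test uses these fields, and the lengths are just the test's example values.

# Minimal sketch of the fixed PackedSeqParams construction. Assumes the
# dataclass is importable from megatron.core.packed_seq_params and a CUDA
# device is available; sequence lengths are the test's example values.
import torch

from megatron.core.packed_seq_params import PackedSeqParams

# Cumulative end offsets of the packed samples: three sequences ending at
# tokens 512, 1024, and 1600. In "thd" (packed) format these are passed as
# int32 device tensors rather than Python lists, which is what the fix changes.
cu_seqlens = torch.tensor([0, 512, 1024, 1600], dtype=torch.int32).cuda()

packed_seq_params = PackedSeqParams(
    qkv_format="thd",
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=torch.tensor(1600, dtype=torch.int32).cuda(),
    max_seqlen_kv=torch.tensor(1600, dtype=torch.int32).cuda(),
)

The BF16 casts around the forward call follow the test's own NOTE that packing is only supported with BF16: the model and the image batch are cast to torch.bfloat16 for the packed pass, then the model is restored to torch.float32 so the rest of the test runs at the default dtype.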
