Commit 6ea1e6a

Support MMMU benchmark for InternVL (sgl-project#5968)
1 parent 3409aaa commit 6ea1e6a

File tree: 2 files changed, +139 −12

benchmark/mmmu/bench_hf.py
benchmark/mmmu/internvl_utils.py

benchmark/mmmu/bench_hf.py  (+45 −12)
@@ -17,6 +17,13 @@
 @torch.no_grad()
 def eval_mmmu(args):
     eval_args = EvalArgs.from_cli_args(args)
+
+    sampling_params = get_sampling_params(eval_args)
+    generation_config = GenerationConfig(
+        max_new_tokens=sampling_params["max_new_tokens"],
+        do_sample=False,
+    )
+
     try:
         from transformers import AutoModelForImageTextToText

@@ -27,12 +34,28 @@ def eval_mmmu(args):
         )
     except Exception as first_exception:
         try:
-            model = AutoModel.from_pretrained(
-                args.model_path,
-                torch_dtype="auto",
-                trust_remote_code=True,
-                init_tts=False,
-            )
+            # check whether the model belongs to the InternVL family
+            if "InternVL" in args.model_path:
+                from internvl_utils import load_image
+                from transformers import AutoTokenizer
+
+                tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+                model = AutoModel.from_pretrained(
+                    args.model_path,
+                    torch_dtype="auto",
+                    trust_remote_code=True,
+                )
+                generation_config_internvl = dict(
+                    max_new_tokens=sampling_params["max_new_tokens"], do_sample=False
+                )
+
+            else:
+                model = AutoModel.from_pretrained(
+                    args.model_path,
+                    torch_dtype="auto",
+                    trust_remote_code=True,
+                    init_tts=False,
+                )
         except Exception as second_exception:
             raise RuntimeError(
                 f"Failed to load model: First attempt failed with {first_exception}, "
@@ -48,19 +71,29 @@ def eval_mmmu(args):
     samples = prepare_samples(eval_args)
     out_samples = dict()

-    sampling_params = get_sampling_params(eval_args)
-    generation_config = GenerationConfig(
-        max_new_tokens=sampling_params["max_new_tokens"],
-        do_sample=False,
-    )
-
     answer_dict = {}
     for sample in tqdm(samples):
         prompt = sample["final_input_prompt"]
         image = sample["image"]
         prefix = prompt.split("<")[0]
         suffix = prompt.split(">")[1]
         assert image is not None
+
+        if "InternVL" in args.model_path:
+            pixel_values = load_image(sample["image_path"]).to(torch.bfloat16).cuda()
+            contents = ""
+            if prefix:
+                contents += prefix
+            contents += "<image>\n"
+            if suffix:
+                contents += suffix
+            response = model.chat(
+                tokenizer, pixel_values, contents, generation_config_internvl
+            )
+            print(f"response: {response}")
+            process_result(response, sample, answer_dict, out_samples)
+            continue
+
         contents = []
         if prefix:
             contents += [{"type": "text", "text": prefix}]
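
Taken together, the new branch gives InternVL checkpoints their own inference path: the image is tiled and normalized by internvl_utils.load_image, the prompt is rebuilt around an "<image>\n" placeholder, and generation goes through the model's chat() method instead of the processor-based pipeline. Below is a minimal standalone sketch of that path for reference; the model name, image file, and prompt text are placeholders (not taken from this commit), and the .cuda()/trust_remote_code usage follows the common InternVL model-card pattern rather than the benchmark code above.

# Standalone sketch of the InternVL path added in bench_hf.py above.
# Model path, image file, and prompt are placeholders.
import torch
from transformers import AutoModel, AutoTokenizer

from internvl_utils import load_image

model_path = "OpenGVLab/InternVL3-1B"  # placeholder InternVL checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = (
    AutoModel.from_pretrained(model_path, torch_dtype="auto", trust_remote_code=True)
    .eval()
    .cuda()
)

# Tile and normalize the image the same way the benchmark does.
pixel_values = load_image("example.jpg").to(torch.bfloat16).cuda()

# The benchmark builds the prompt as prefix + "<image>\n" + suffix.
question = "Answer the question about the figure: " + "<image>\n" + "Keep it short."
generation_config = dict(max_new_tokens=64, do_sample=False)
response = model.chat(tokenizer, pixel_values, question, generation_config)
print(response)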

benchmark/mmmu/internvl_utils.py  (new file, +94)
# copy from https://huggingface.co/OpenGVLab/InternVL3-1B
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD),
        ]
    )
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(
    image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert("RGB")
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(
        image, image_size=input_size, use_thumbnail=True, max_num=max_num
    )
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
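
For reference, here is what this preprocessing produces on a concrete input. The numbers are worked out from the code above, not from the commit: an 800x600 image has aspect ratio 4:3, so find_closest_aspect_ratio picks the (4, 3) grid, dynamic_preprocess resizes to 1792x1344 and crops twelve 448x448 tiles, and the thumbnail adds a thirteenth, so load_image returns a [13, 3, 448, 448] tensor. The file name below is an illustrative stand-in.

# Worked example using a synthetic 800x600 image as a stand-in for an MMMU image.
import torch
from PIL import Image

from internvl_utils import dynamic_preprocess, load_image

img = Image.new("RGB", (800, 600))
tiles = dynamic_preprocess(img, image_size=448, use_thumbnail=True, max_num=12)
print(len(tiles))  # 13 = 4 * 3 tiles + 1 thumbnail

# load_image wraps the same preprocessing and stacks the normalized tensors;
# it expects a file path, so the stand-in image is saved first.
img.save("example_800x600.jpg")
pixel_values = load_image("example_800x600.jpg")
print(pixel_values.shape)  # torch.Size([13, 3, 448, 448])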
