
Commit 1410ac3

Merge pull request #313 from white2018/cogvlm2
support cogvlm2 model
2 parents b1d319f + d94cd2f commit 1410ac3

3 files changed, +192 -0 lines


api/adapter/patcher.py (+3)

@@ -164,6 +164,9 @@ def patch_config(
 
 
 def patch_model(model: "PreTrainedModel") -> None:
+    if 'CogVLMForCausalLM' in model.config.architectures:
+        model.config.model_type = "cogvlm2"
+        return
     if model.config.model_type == "internvl_chat":
         return
     if model.config.model_type == "minicpmv":

api/engine/hf.py (+5)

@@ -37,6 +37,7 @@
 from api.templates.glm import generate_stream_chatglm, generate_stream_chatglm_v3
 from api.templates.minicpm import generate_stream_minicpm_v
 from api.templates.minimonkey import generate_stream_minimonkey
+from api.templates.cogvlm2 import generate_stream_cogvlm2
 from api.templates.stream import generate_stream
 from api.templates.utils import get_context_length
 from api.utils import create_error_response

@@ -81,6 +82,8 @@ def __init__(
             self.generate_stream_func = generate_stream_minicpm_v
         elif self.model.config.model_type == "internvl_chat":
             self.generate_stream_func = generate_stream_minimonkey
+        elif self.model.config.model_type == "cogvlm2":
+            self.generate_stream_func = generate_stream_cogvlm2
 
         logger.info(f"Using {self.model_name} Model for Chat!")
         logger.info(f"Using {self.template} for Chat!")

@@ -101,6 +104,8 @@ def _generate(self, params: Dict[str, Any]) -> Iterator[dict]:
         else:
             if self.model.config.model_type == "minicpmv":
                 inputs = prompt_or_messages
+            elif self.model.config.model_type == "cogvlm2":
+                inputs = prompt_or_messages
             elif self.model.config.model_type == "internvl_chat":
                 inputs = prompt_or_messages
             else:
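Both hunks follow the engine's existing pattern: the streaming function is picked by model_type, and for multimodal templates the raw chat messages are passed through as inputs so the template can build its own model inputs. A condensed, dict-based sketch of that dispatch (illustrative only; the real code keeps the if/elif chain shown above):

from api.templates.cogvlm2 import generate_stream_cogvlm2
from api.templates.minicpm import generate_stream_minicpm_v
from api.templates.minimonkey import generate_stream_minimonkey
from api.templates.stream import generate_stream

# Table-form equivalent of the model_type dispatch; the function names are
# the ones imported in api/engine/hf.py.
STREAM_FUNCS = {
    "minicpmv": generate_stream_minicpm_v,
    "internvl_chat": generate_stream_minimonkey,
    "cogvlm2": generate_stream_cogvlm2,
}

def pick_stream_func(model_type: str):
    # Assumption for this sketch: anything without a dedicated multimodal
    # template falls back to the generic text streamer.
    return STREAM_FUNCS.get(model_type, generate_stream)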

api/templates/cogvlm2.py (+184, new file)

from __future__ import annotations

import gc
import time
import uuid
from typing import (
    Any,
    Dict,
    List,
    Iterator,
    TYPE_CHECKING,
)

import torch

from api.protocol import ChatCompletionMessageParam

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, PreTrainedModel


import queue
from threading import Thread
import torchvision.transforms as T
import transformers
from torchvision.transforms.functional import InterpolationMode
from transformers import BitsAndBytesConfig, TextIteratorStreamer

transformers.logging.set_verbosity_error()

# Supported checkpoints:
# THUDM/cogvlm2-llama3-chat-19B
# THUDM/cogvlm2-llama3-chinese-chat-19B


@torch.inference_mode()
def generate_stream_cogvlm2(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    params: Dict[str, Any],
) -> Iterator:
    """
    Generates text in a streaming manner using the CogVLM2 model.

    Args:
        model: The pre-trained model.
        tokenizer: The tokenizer used for tokenizing the input.
        params: A dictionary containing the input parameters.

    Yields:
        A dictionary representing each generated text completion.
    """
    inputs = params["inputs"]
    model_name = params.get("model", "llm")

    # Flatten OpenAI-style chat messages into (query, history, images, system);
    # img_tok='' so no inline image placeholder is prepended to the user text.
    query, history, images, system_message = prompt_history_images_system_from_messages(inputs, img_tok='')

    # build_conversation_input_ids ships with the CogVLM2 remote code.
    input_by_model = model.build_conversation_input_ids(
        tokenizer, query=query, history=history, images=images, template_version='chat'
    )

    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(model.device),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(model.device),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(model.device),
        'images': [[input_by_model['images'][0].to(model.device).to(model.dtype)]] if images else None,
    }

    new_params = dict(
        temperature=float(params.get("temperature", 1.0)),
        max_new_tokens=int(params.get("max_tokens", 256)),
        repetition_penalty=float(params.get("repetition_penalty", 1.0)),
        top_p=float(params.get("top_p", 1.0)),
        top_k=int(params.get("top_k", 50)),
    )

    generation_kwargs = dict(
        **inputs,
        **new_params,
    )

    input_echo_len = 0
    generated_text, previous_text = "", ""
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    for i, new_text in enumerate(
        threaded_streaming_generator(generate=model.generate, tokenizer=tokenizer, generation_kwargs=generation_kwargs)
    ):
        # Truncate at the EOS token if it shows up in the streamed text.
        end = new_text.find(tokenizer.eos_token)
        if end != -1:
            new_text = new_text[:end]

        generated_text += new_text
        delta_text = generated_text[len(previous_text):]
        previous_text = generated_text
        yield {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "delta": delta_text,
            "text": generated_text,
            "logprobs": None,
            "finish_reason": None,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": i,
                "total_tokens": input_echo_len + i,
            },
        }

        if end != -1:
            break

    gc.collect()
    torch.cuda.empty_cache()


def prompt_history_images_system_from_messages(messages: list[ChatCompletionMessageParam], img_tok="<image>\n"):
    """Fold chat messages into the (query, history, images, system_prompt) shape
    expected by build_conversation_input_ids."""
    history = []
    images = []
    prompt = ''
    system_prompt = None

    for m in messages:
        if m['role'] == 'user':
            p = ''
            for c in m['content']:
                if c['type'] == 'image_url':
                    image = url_to_image(c['image_url']['url'])
                    images.extend([image])
                    p = img_tok + p
                if c['type'] == 'text':
                    p += c['text']

            prompt += p
        elif m['role'] == 'assistant':
            for c in m['content']:
                if c['type'] == 'text':
                    history.extend([(prompt, c['text'])])
                    prompt = ''
        elif m['role'] == 'system':
            for c in m['content']:
                if c['type'] == 'text':
                    system_prompt = c['text']

    return prompt, history, images, system_prompt


def url_to_image(image_url: str):
    """Load a PIL image from a data: URI or an http(s) URL."""
    from PIL import Image
    from io import BytesIO

    if image_url.startswith("data:"):
        import base64

        image_bytes = base64.b64decode(image_url.split(",")[1])
    else:
        import urllib.request

        with urllib.request.urlopen(image_url) as f:
            image_bytes = f.read()

    return Image.open(BytesIO(image_bytes)).convert("RGB")


def threaded_streaming_generator(generate, tokenizer, generation_kwargs):
    """Run generate() in a background thread and yield streamed text pieces;
    exceptions are forwarded to the consumer through a queue."""
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=60)

    generation_kwargs['streamer'] = streamer

    exq = queue.Queue()

    def wrapper():
        try:
            with torch.no_grad():
                generate(**generation_kwargs)

        except Exception as e:
            # logger.exception(e)
            exq.put(e)
            streamer.end()

    t = Thread(target=wrapper, daemon=True)
    t.start()

    for text in streamer:
        if text:
            yield text

    if not exq.empty():
        raise exq.get_nowait()
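For context, a hedged usage sketch of the new template outside the API server. Model and tokenizer loading follows the usual trust_remote_code pattern for the CogVLM2 checkpoints; the image URL, dtype choice, and sampling values below are illustrative only.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from api.templates.cogvlm2 import generate_stream_cogvlm2

MODEL_ID = "THUDM/cogvlm2-llama3-chat-19B"  # one of the checkpoints named in the module comment

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # illustrative; pick a dtype your GPU supports
    trust_remote_code=True,
).eval().to("cuda")

# OpenAI-style multimodal messages, shaped the way
# prompt_history_images_system_from_messages expects them.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            {"type": "text", "text": "Describe this image."},
        ],
    },
]

params = {
    "inputs": messages,
    "model": "cogvlm2",
    "temperature": 0.8,
    "max_tokens": 128,
}

# Each yielded chunk carries the incremental "delta" and the full "text" so far.
for chunk in generate_stream_cogvlm2(model, tokenizer, params):
    print(chunk["delta"], end="", flush=True)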
