Skip to content

Commit c014761

Browse files
committed
add Jinja2MultimodalChatFormatter
(cherry picked from commit 4ba212f)
1 parent 07a71ae commit c014761

File tree

2 files changed

+255
-3
lines changed

2 files changed

+255
-3
lines changed

llama_cpp/llama_chat_format.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ class ChatFormatterResponse:
198198
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
199199
added_special: bool = False
200200

201+
medias: List[Union[str, bytes, bytearray]] = None
202+
media_types: List[str] = None
203+
201204

202205
class ChatFormatter(Protocol):
203206
"""Base Protocol for a chat formatter. A chat formatter is a function that

llama_cpp/mtmd.py

Lines changed: 252 additions & 3 deletions
Original file line numberDiff line numberDiff line change
import copy
import ctypes
import datetime
import os
import urllib.parse
from typing import Union, List, Optional, Any, Tuple

import jinja2
from jinja2.sandbox import ImmutableSandboxedEnvironment
import numpy as np
import numpy.typing as npt

import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types

from ._internals import LlamaContext, LlamaBatch
from .llama_chat_format import ChatFormatter, ChatFormatterResponse

1122
class TextChunk:
1223
def __init__(self, tokens: List[int]):
@@ -77,6 +88,242 @@ def close(self):
7788
def __del__(self):
7889
self.close()
7990

91+
# Marker string that mtmd_tokenize recognizes as a media slot (e.g. "<__media__>").
DEFAULT_MEDIA_MARKER = mtmd.mtmd_default_marker().decode('utf-8')


class Jinja2MultimodalChatFormatter(ChatFormatter):
    """Chat formatter that renders a jinja2 chat template and collects the
    media (images / audio) referenced by the messages.

    Every model-specific media placeholder in the rendered prompt is rewritten
    to the mtmd default marker so the prompt can be passed to ``mtmd_tokenize``
    together with the returned media payloads.
    """

    def __init__(
        self,
        template: str,
        eos_token: str,
        bos_token: str,
        add_generation_prompt: bool = True,
        stop_token_ids: Optional[List[int]] = None,
        placeholders: Optional[List[str]] = None,
    ):
        """Create a formatter for the given jinja2 chat template.

        Args:
            template: jinja2 chat template source.
            eos_token: end-of-sequence token string exposed to the template.
            bos_token: beginning-of-sequence token string exposed to the template.
            add_generation_prompt: whether the template should append the
                assistant generation prompt.
            stop_token_ids: optional token ids that terminate generation.
            placeholders: model-specific media placeholder strings that are
                rewritten to the mtmd default marker after rendering; a
                built-in list of known placeholders is used when omitted.
        """
        self.template = template
        self.eos_token = eos_token
        self.bos_token = bos_token
        self.add_generation_prompt = add_generation_prompt
        self.stop_token_ids = (
            set(stop_token_ids) if stop_token_ids is not None else None
        )

        # Sandboxed environment: chat templates ship with model files and must
        # not be able to mutate objects or reach arbitrary attributes.
        self.chat_template = ImmutableSandboxedEnvironment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(template)

        # Placeholder mapping, mtmd_tokenize requires <__media__>
        self.placeholders = placeholders if placeholders else [
            "<|vision_start|><|image_pad|><|vision_end|>",  # Qwen3-VL
            "<image>",  # LLaVA / Yi
            "<image_placeholder>",  # DeepSeek
        ]

    def __call__(
        self,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        """Render the template and return the prompt plus extracted media.

        Returns:
            ChatFormatterResponse carrying the rendered prompt, stop strings /
            stopping criteria, and the fetched media payloads and types.

        Raises:
            ValueError: from the template's ``raise_exception`` helper, or
                when a message references video input (not supported yet).
        """

        def raise_exception(message: str):
            # Exposed to the template as its standard error helper.
            raise ValueError(message)

        def strftime_now(format_string: str = "%Y-%m-%d %H:%M:%S") -> str:
            """Return the current local time formatted as a string."""
            return datetime.datetime.now().strftime(format_string)

        # Deep copy: split_media rewrites media parts to markers in place and
        # callers must not observe their messages being mutated.
        messages = copy.deepcopy(messages)
        media_urls, media_types = self.split_media(messages)
        medias = []

        for url, m_type in zip(media_urls, media_types):
            if m_type == "video":
                raise ValueError("Video input is not supported yet.")

            medias.append(self._fetch_media(url, m_type))

        prompt = self.chat_template.render(
            messages=messages,
            eos_token=self.eos_token,
            bos_token=self.bos_token,
            raise_exception=raise_exception,
            strftime_now=strftime_now,
            add_generation_prompt=self.add_generation_prompt,
            functions=functions,
            function_call=function_call,
            tools=tools,
            tool_choice=tool_choice,
        )

        # mtmd_tokenize only understands the default marker, so rewrite every
        # model-specific placeholder the template may have emitted.
        for p in self.placeholders:
            prompt = prompt.replace(p, DEFAULT_MEDIA_MARKER)

        stopping_criteria = None
        if self.stop_token_ids is not None:

            def stop_on_last_token(
                tokens: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
            ) -> bool:
                return tokens[-1] in self.stop_token_ids

            stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])

        return ChatFormatterResponse(
            prompt=prompt,
            stop=[self.eos_token],
            stopping_criteria=stopping_criteria,
            added_special=True,
            medias=medias,
            media_types=media_types,
        )

    @staticmethod
    def split_media(messages: List[llama_types.ChatCompletionRequestMessage]):
        """Extract media references from *messages*, mutating them in place.

        Every non-text content part of a user message has its payload
        (url / bytes) collected and replaced by the mtmd default marker.

        Returns:
            Tuple ``(media_urls, media_types)`` in message order, where each
            type is one of "image", "audio", "video".

        Raises:
            ValueError: on an unsupported content part type.
        """
        media_urls: List[Union[str, bytes, bytearray]] = []
        media_types: List[str] = []

        for message in messages:
            # Only structured (list-form) user content can carry media parts.
            if message.get("role") != "user" or not isinstance(message.get("content"), list):
                continue

            for content in message["content"]:
                if not (isinstance(content, dict) and "type" in content):
                    continue

                c_type = content["type"]
                if c_type == "text":
                    continue

                # "image" / "image_url" / "audio_url" / ... share a base type.
                base_type = c_type[:-4] if c_type.endswith("_url") else c_type
                if base_type not in ("image", "audio", "video"):
                    raise ValueError(f"Unsupported content type {c_type}")

                value = content[c_type]

                if isinstance(value, dict) and "url" in value:
                    # OpenAI-style {"image_url": {"url": ...}} shape.
                    media_urls.append(value["url"])
                    value["url"] = DEFAULT_MEDIA_MARKER
                else:
                    # Bare payload directly under the type key.
                    media_urls.append(value)
                    content[c_type] = DEFAULT_MEDIA_MARKER

                media_types.append(base_type)

        return media_urls, media_types

    @staticmethod
    def _fetch_media(media_input: Union[str, bytes], media_type: str) -> Union[str, bytes, bytearray]:
        """Fetch media (audio, image, video...) from local disk, memory, or internet.

        Args:
            media_input: raw bytes, a ``file://`` / ``data:`` / remote URL,
                or a plain filesystem path.
            media_type: media kind ("image", "audio", ...); currently unused,
                kept for future per-type handling.

        Returns:
            Raw media bytes, or an absolute filesystem path for local files.
        """
        # --- from_buffer fast path ---
        if isinstance(media_input, (bytes, bytearray)):
            return media_input

        if not isinstance(media_input, str):
            raise ValueError(f"Unsupported media input type: {type(media_input)}")

        # --- from_file fast path ---
        if media_input.startswith("file://"):
            parsed_path = urllib.parse.urlparse(media_input).path
            # unquote handles URL-encoded characters in the path
            abs_path = os.path.abspath(urllib.parse.unquote(parsed_path))
            if os.path.exists(abs_path):
                return abs_path
            raise FileNotFoundError(f"Local file not found: {abs_path}")

        # --- base64 / data URI or remote url ---
        raw_bytes = b""
        if media_input.startswith("data:"):
            import base64

            # RFC 2397: the payload starts after the first comma; the header
            # (mediatype plus optional ";base64") never contains one.
            comma_pos = media_input.find(",")
            if comma_pos == -1:
                raise ValueError("Invalid data URI: missing comma separator")

            header = media_input[5:comma_pos]
            if header.endswith(";base64"):
                raw_bytes = base64.b64decode(media_input[comma_pos + 1:])
            else:
                # Non-base64 data URIs carry percent-encoded data.
                raw_bytes = urllib.parse.unquote_to_bytes(media_input[comma_pos + 1:])
        elif "://" in media_input:
            import urllib.request
            from urllib.error import URLError, HTTPError

            # Some hosts reject requests without a browser-like User-Agent.
            headers = {"User-Agent": "Mozilla/5.0"}
            req = urllib.request.Request(media_input, headers=headers)

            try:
                with urllib.request.urlopen(req, timeout=15) as f:
                    raw_bytes = f.read()
            except (URLError, HTTPError) as e:
                raise ConnectionError(f"Failed to fetch media from {media_input}: {e}") from e

        else:
            # try direct path
            if os.path.exists(media_input):
                return os.path.abspath(media_input)
            raise ValueError("Unrecognized media string format")

        if not raw_bytes:
            raise ValueError("Empty data received")

        return raw_bytes

    @staticmethod
    def _compress_image(image_bytes: bytes) -> bytes:
        """Re-encode *image_bytes* as an RGB JPEG, flattening transparency.

        NOTE(review): currently not called by ``__call__``; kept for callers
        that want to shrink oversized inputs before tokenization.

        Raises:
            ImportError: if Pillow is not installed.
        """
        try:
            from PIL import Image, ImageStat
        except ImportError:
            raise ImportError("Pillow is required for image processing. Install with: pip install pillow")

        import io

        image = Image.open(io.BytesIO(image_bytes))

        # Handle transparency (RGBA, LA, P with transparency, etc.)
        if image.mode in ("RGBA", "LA", "PA") or (image.mode == "P" and "transparency" in image.info):
            # Use alpha channel as mask
            if image.mode == "P":
                image = image.convert("RGBA")

            alpha = image.split()[-1]  # Last channel is alpha
            # Compute average brightness of visible (non-transparent) pixels
            stat = ImageStat.Stat(image.convert("L"), mask=alpha)

            # Choose background: white for dark content, black for bright content
            bg_color = (255, 255, 255)  # white
            if stat.count[0] > 0 and stat.mean[0] > 127:
                bg_color = (0, 0, 0)  # black

            background = Image.new("RGB", image.size, bg_color)
            background.paste(image, mask=alpha)
            image = background

        # Ensure RGB mode for formats like CMYK, palette, etc.
        elif image.mode != "RGB":
            image = image.convert("RGB")

        # Save as high-quality JPEG, suitable for most vision models.
        output = io.BytesIO()
        image.save(output, format="JPEG", quality=95, optimize=True, progressive=True)
        return output.getvalue()
80327

81328
# Simple FNV-1a hash implementation to match fnv_hash in C++
82329
def fnv_hash(data: bytes) -> str:
@@ -89,12 +336,12 @@ def fnv_hash(data: bytes) -> str:
89336
def mtmd_tokenize(
90337
mctx: mtmd.mtmd_context_p,
91338
prompt: str,
92-
files_data: list[bytes | str]) -> MultimodalTokenList:
339+
medias_data: list[Union[str, bytes, bytearray]]) -> MultimodalTokenList:
93340

94341
bitmaps = []
95342
do_hash = False
96343

97-
for data in files_data:
344+
for data in medias_data:
98345

99346
bmp = None
100347
if isinstance(data, str):
@@ -200,3 +447,5 @@ def mtmd_prefill(
200447
raise RuntimeError(f"MTMD eval error: {result}")
201448

202449
n_past = new_n_past.value
450+
451+
return n_past

0 commit comments

Comments
 (0)