import math import torch import warnings import PIL.Image from torch.nn import functional as F from collections import UserDict, OrderedDict from typing import Union, Optional, Tuple, List, Dict, Any from transformers.image_utils import load_image from transformers.feature_extraction_utils import BatchFeature from .chat_template_utils import render_jinja_template from transformers.processing_utils import ProcessorMixin, AllKwargsForChatTemplate class DMLLMProcessor(ProcessorMixin): attributes = ["tokenizer", "image_processor"] optional_attributes = ['chat_template'] model_input_names = ['input_ids', 'attention_mask', 'pixel_values'] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" def __init__( self, tokenizer, image_processor, chat_template=None, image_size=512, patch_size=16, downsample_ratio=0.5, max_sub_img=6, min_sub_img=1, image_token='', image_start_token='', image_end_token='', special_tokens=['', '', ''], #'', '' **kwargs): if chat_template is None: chat_template = "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant.<|eot_id|>\n{% endif %}<|start_header_id|>{{ message['role'] }}<|end_header_id|>\n{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'][0]['text'] }}<|eot_id|>{% endgeneration %}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% elif content['type'] == 'video' or 'video' in content %}{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|eot_id|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>\n{% endif %}" super().__init__(tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template) if isinstance(image_size, list) or isinstance(image_size, tuple): image_size = image_size[0] self.num_image_token = int((image_size // patch_size) ** 2 * (downsample_ratio ** 2)) self.vision_token_share_pe = kwargs.get('vision_token_share_pe', True) self.image_token_len = kwargs.pop('image_token_len', 256) self.max_sub_img = max_sub_img self.min_sub_img = min_sub_img self.image_token = image_token self.image_start_token = image_start_token self.image_end_token = image_end_token self.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens}, replace_additional_special_tokens=False) self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token) self.image_start_token_id = self.tokenizer.convert_tokens_to_ids(self.image_start_token) self.image_end_token_id = self.tokenizer.convert_tokens_to_ids(self.image_end_token) if 'llada' in tokenizer.name_or_path.lower(): self._pad_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") elif 'dream' in tokenizer.name_or_path.lower(): self._pad_token_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>") elif 'sdar' in tokenizer.name_or_path.lower(): self._pad_token_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>") if isinstance(image_size, int): image_size = (image_size, image_size) else: image_size = image_size self.image_size = image_size assert image_size[0] == image_size[1] def apply_chat_template(self, conversation, chat_template = None, **kwargs) -> str: if chat_template is None: chat_template = self.chat_template processed_kwargs = { "mm_load_kwargs": {}, "template_kwargs": {}, } # for kwarg_type in processed_kwargs: # for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__.keys(): # kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type] # default_value = getattr(kwarg_type_defaults, key, None) # value = kwargs.pop(key, default_value) # if value is not None and not isinstance(value, dict): # processed_kwargs[kwarg_type][key] = value # Pass unprocessed custom kwargs processed_kwargs["template_kwargs"].update(kwargs) conversations = [conversation] prompt, generation_indices = render_jinja_template( conversations=conversations, chat_template=chat_template, **processed_kwargs["template_kwargs"], # different flags such as `return_assistant_mask` **self.tokenizer.special_tokens_map, # tokenizer special tokens are used by some templates ) return prompt, generation_indices def __call__(self, text=None, images=[], videos=None, generation_indices=None, **kwargs) ->BatchFeature: inputs = self.tokenizer(text, padding=False, truncation=False, return_attention_mask=False) assistant_masks = [] input_ids = inputs["input_ids"] for i in range(len(input_ids)): current_mask = [0] * len(input_ids[i]) if 'dream' in self.tokenizer.name_or_path.lower(): # 基于 Dream 模型的标记来定位 assistant 部分 # 查找 <|im_start|>assistant 和 <|im_end|> 之间的内容 im_start_assistant_pattern = self.tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False) im_end_pattern = self.tokenizer.encode("<|im_end|>", add_special_tokens=False) # 在 input_ids 中查找 assistant 段落 j = 0 while j < len(input_ids[i]) - len(im_start_assistant_pattern) + 1: # 检查是否匹配 <|im_start|>assistant if input_ids[i][j:j+len(im_start_assistant_pattern)] == im_start_assistant_pattern: start_token = j + len(im_start_assistant_pattern) # 查找对应的 <|im_end|> end_token = None for k in range(start_token, len(input_ids[i]) - len(im_end_pattern) + 1): if input_ids[i][k:k+len(im_end_pattern)] == im_end_pattern: end_token = k break # 标记 assistant 部分 if end_token is not None: for token_idx in range(start_token, end_token + len(im_end_pattern)): current_mask[token_idx] = 1 j = end_token + len(im_end_pattern) else: j += 1 else: j += 1 elif 'llada' in self.tokenizer.name_or_path.lower(): # Skip assistant mask computation if generation_indices is None/empty (e.g., GRPO prompt-only) if generation_indices is not None and i < len(generation_indices) and generation_indices[i]: for assistant_start_char, assistant_end_char in generation_indices[i]: start_token = inputs.char_to_token(i, assistant_start_char) end_token = inputs.char_to_token(i, assistant_end_char - 1) if start_token is None: # start_token is out of bounds maybe due to truncation. break for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])): current_mask[token_id] = 1 elif 'sdar' in self.tokenizer.name_or_path.lower(): # 为SDAR模型添加assistant识别逻辑 # SDAR使用 <|im_start|>assistant\n 和 <|im_end|> 格式 im_start_assistant_pattern = self.tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False) im_end_pattern = self.tokenizer.encode("<|im_end|>", add_special_tokens=False) # 在 input_ids 中查找 assistant 段落 j = 0 while j < len(input_ids[i]) - len(im_start_assistant_pattern) + 1: # 检查是否匹配 <|im_start|>assistant if input_ids[i][j:j+len(im_start_assistant_pattern)] == im_start_assistant_pattern: start_token = j + len(im_start_assistant_pattern) # 查找对应的 <|im_end|> end_token = None for k in range(start_token, len(input_ids[i]) - len(im_end_pattern) + 1): if input_ids[i][k:k+len(im_end_pattern)] == im_end_pattern: end_token = k break # 标记 assistant 部分(不包括结束token) if end_token is not None: for token_idx in range(start_token, end_token): current_mask[token_idx] = 1 j = end_token + len(im_end_pattern) else: j += 1 else: j += 1 assistant_masks.append(current_mask) inputs["assistant_masks"] = assistant_masks[0] inputs['input_ids'] = input_ids[0] truncation = kwargs.pop('truncation', False) max_length = kwargs.pop('max_length', 1024) padding = kwargs.pop('padding', False) inputs = self.process_images(images, inputs=inputs) if isinstance(inputs, UserDict): inputs = inputs.data if 'attention_mask' not in inputs: inputs['attention_mask'] = [1] * len(inputs['input_ids']) if 'assistant_masks' in inputs: inputs['prompt_mask'] = [1-x for x in inputs.pop('assistant_masks')] inputs = self.process_inputs(inputs) if truncation and len(inputs['input_ids']) > max_length: inputs = self.truncate(inputs, max_length) if padding and len(inputs['input_ids']) < max_length: inputs = self.padding(inputs, max_length) inputs = self.to_tensor(inputs) self.check(inputs) if self.vision_token_share_pe: position_ids = self.get_position_ids(inputs) position_ids = torch.tensor([position_ids], dtype=torch.long) inputs['position_ids'] = position_ids inputs.pop('sub_image_nums', None) return BatchFeature(inputs) def get_position_ids(self, inputs: Dict[str, Any]): input_ids = inputs['input_ids'][0] image_token_lens = self.get_image_token_length(inputs) position_ids = [] i, j = 0, 0 while len(position_ids) < len(input_ids): if input_ids[len(position_ids)] == self.image_token_id: image_token_len = image_token_lens[j] assert image_token_len % self.image_token_len == 0 num_views = image_token_len // self.image_token_len for _ in range(num_views): position_ids += [i] * self.image_token_len # 同一个图像的所有 token 共享相同的位置编码 i += 1 j += 1 else: position_ids.append(i) i += 1 assert j == len(image_token_lens) and len(position_ids) == len(input_ids), \ f"Wrong position_ids, {j} != {len(image_token_lens)} or {len(position_ids)} != {len(input_ids)}" return position_ids def process_images(self, images, inputs): images = [load_image(img) for img in images] if len(images) > 0: processed_images = [] sub_image_nums = [] for image in images: if len(images) > 1: # for multi images, remove the split strategy sub_images = dynamic_preprocess( image, min_num=1, max_num=1, image_size=self.image_size[0], use_thumbnail=True) else: sub_images = dynamic_preprocess( image, min_num=self.min_sub_img, max_num=self.max_sub_img, image_size=self.image_size[0], use_thumbnail=True) sub_image_nums.append(len(sub_images)) processed_images += sub_images # print([_img.size for _img in processed_images]) pixel_values = self.image_processor.preprocess( images=processed_images, return_tensors="pt" )["pixel_values"] # (N, c, h, w) else: pixel_values = torch.zeros(( 1, 3, self.image_size[0], self.image_size[1]), dtype=torch.float32 ) sub_image_nums = [] inputs['pixel_values'] = pixel_values inputs['sub_image_nums'] = sub_image_nums return inputs def truncate(self, inputs: Dict[str, Any], max_length: int): assert self.image_token_id not in inputs['input_ids'][max_length:], f"Truncate image token is not allowed." inputs['input_ids'] = inputs['input_ids'][:max_length] inputs['attention_mask'] = inputs['attention_mask'][:max_length] if 'prompt_mask' in inputs: inputs['prompt_mask'] = inputs['prompt_mask'][:max_length] return inputs def get_image_token_length(self, inputs: Dict[str, Any]) -> List[int]: sub_image_nums = inputs.get('sub_image_nums', None) if sub_image_nums is None or len(sub_image_nums) == 0: return [] image_token_lens = [_num * self.num_image_token for _num in sub_image_nums] return image_token_lens def process_inputs(self, inputs: Dict[str, Any]): graft_token_lens = self._get_graft_token_length(inputs) inputs['input_ids'] = self._graft_token(inputs['input_ids'], graft_token_lens, self.image_token_id) inputs['attention_mask'] = self._graft_token(inputs['attention_mask'], graft_token_lens, 'replicate') if 'prompt_mask' in inputs: inputs['prompt_mask'] = self._graft_token(inputs['prompt_mask'], graft_token_lens, 'replicate') return inputs def _graft_token(self, seq, graft_token_lens, value): if value == 'replicate': for i in reversed(graft_token_lens.keys()): seq[i:] = [seq[i]] * graft_token_lens[i] + seq[i+1:] else: for i in reversed(graft_token_lens.keys()): seq[i:] = [value] * graft_token_lens[i] + seq[i+1:] return seq def _get_graft_token_length(self, inputs: Dict[str, Any]) -> Dict[int, int]: image_token_pos = [i for i, x in enumerate(inputs['input_ids']) if x == self.image_token_id] image_token_lens = self.get_image_token_length(inputs) assert len(image_token_pos) == len(image_token_lens), \ "Wrong image token count, " \ f"image_token_count({len(image_token_pos)}) != image_count({len(image_token_lens)})" graft_token_lens = OrderedDict(item for item in zip(image_token_pos, image_token_lens)) return graft_token_lens def check(self, inputs: Dict[str, Any]): image_embed_token_count = torch.count_nonzero(inputs['input_ids'] == self.image_token_id).item() image_embed_count = sum(self.get_image_token_length(inputs)) assert image_embed_token_count == image_embed_count, \ "Wrong image embed token count, " \ f"image_embed_token_count({image_embed_token_count}) != image_embed_count({image_embed_count})" def padding(self, inputs: Dict[str, Any], max_length: int): padding_len = max_length - len(inputs['input_ids']) inputs['input_ids'] += [self.pad_token_id] * padding_len inputs['attention_mask'] += [0] * padding_len if 'prompt_mask' in inputs: inputs['prompt_mask'] += [0] * padding_len return inputs def decode(self, token_ids: Union[List[int], torch.Tensor], **kwargs): if isinstance(token_ids, torch.Tensor): token_ids = token_ids.tolist() text = self.tokenizer.decode(token_ids, **kwargs) return text def batch_decode(self, sequences: Union[List[List[int]], torch.Tensor], **kwargs): if isinstance(sequences, torch.Tensor): sequences = sequences.tolist() texts = self.tokenizer.batch_decode(sequences, **kwargs) return texts def to_tensor(self, inputs): inputs['input_ids'] = torch.tensor([inputs['input_ids']], dtype=torch.long) inputs['attention_mask'] = torch.tensor([inputs['attention_mask']], dtype=torch.bool) if 'prompt_mask' in inputs: inputs['prompt_mask'] = torch.tensor([inputs['prompt_mask']], dtype=torch.bool) return inputs @property def pad_token_id(self): return self._pad_token_id def __repr__(self): pass def __str__(self): return 'DMLLMProcessor' def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') return best_ratio def dynamic_preprocess(image, min_num=1, max_num=6, image_size=512, use_thumbnail=True): orig_width, orig_height = image.size aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set( (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ( (i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images