| import os |
| import json |
| import argparse |
| import numpy as np |
| from tqdm import tqdm |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.utils.data import DataLoader |
| from torchvision import transforms as T |
|
|
| from data.dataset_for_clean_descrip import PoseHICODetDataset |
| from data.convsersation import Conversation_For_Action_Pharse as Conversation |
|
|
| import re |
| from dataclasses import dataclass |
|
|
| from tools.vlm_backend import build_batch_tensors, decode_generated_text, load_model_and_processor |
|
|
def disable_torch_init():
    """
    Skip PyTorch's default weight initialization to speed up model creation.

    Replaces ``reset_parameters`` on ``nn.Linear`` and ``nn.LayerNorm`` with
    no-ops; the subsequent checkpoint load overwrites the weights anyway.
    """
    noop = lambda self: None
    torch.nn.Linear.reset_parameters = noop
    torch.nn.LayerNorm.reset_parameters = noop
|
|
| import os, json |
| import torch |
|
|
class StreamingJsonArrayWriter:
    """Context manager that streams JSON-serializable items into a JSON array.

    Items are written one at a time and flushed immediately, so a partial
    (crash-interrupted) run still leaves most results on disk. The file is
    valid JSON only after a clean ``__exit__`` writes the closing bracket.
    """

    def __init__(self, output_path):
        # Path of the JSON file to create; opened lazily in __enter__.
        self.output_path = output_path
        self.file = None
        # Tracks whether a separating comma is needed before the next item.
        self.is_first = True

    def __enter__(self):
        self.file = open(self.output_path, "w", encoding="utf-8")
        self.file.write("[\n")
        self.file.flush()
        return self

    def write(self, item):
        """Append one item to the array and flush so progress survives a crash."""
        if not self.is_first:
            self.file.write(",\n")
        json.dump(item, self.file, ensure_ascii=False, indent=2)
        self.file.flush()
        self.is_first = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Fix: always close the handle, even if writing the terminator fails;
        # previously a failed final write leaked the open file object.
        if self.file is not None:
            try:
                self.file.write("\n]\n")
            finally:
                self.file.close()
                self.file = None
|
|
class DataCollatorForSupervisedDataset:
    """Collates dataset samples into model-ready batch tensors.

    Builds one prompt per sample via the Conversation template, then hands
    prompts + images to ``build_batch_tensors``. Returns the batch tensors
    together with each sample's ``meta`` dict so generation results can be
    matched back to their inputs.

    Note: the original class carried a ``@dataclass`` decorator that was
    fully overridden by the hand-written ``__init__``; it has been dropped.
    """

    def __init__(self, processor, data_path):
        self.processor = processor
        self.conv = Conversation(
            system='',
            data_path=data_path
        )

    def __call__(self, data_dicts):
        """Collate examples for supervised fine-tuning."""
        batch_prompts = [self.conv.get_prompt(d['meta']) for d in data_dicts]
        batch_images = [d['image'] for d in data_dicts]
        result_meta = [d['meta'] for d in data_dicts]

        # NOTE(review): a hand-built chat ``messages`` structure previously
        # assembled here was never used (and iterated ``zip(batch_prompts)``
        # incorrectly, yielding 1-tuples instead of strings); chat formatting
        # is delegated entirely to build_batch_tensors.
        batch_tensors = build_batch_tensors(
            processor=self.processor,
            prompts=batch_prompts,
            images=batch_images,
            system_prompt=self.conv.system,
        )
        return batch_tensors, result_meta
|
|
@torch.no_grad()
def worker(model, processor, dataset, args, output_dir):
    """Run greedy generation over this rank's shard of the dataset.

    Each rank processes every ``world_size``-th sample and streams its
    annotated metadata to ``labels_{rank}.json`` under ``output_dir``.

    Args:
        model: a loaded, CUDA-resident generation model.
        processor: the matching tokenizer/processor.
        dataset: indexable dataset yielding dicts with 'image' and 'meta'.
        args: parsed CLI namespace (``data_path`` is read here).
        output_dir: directory receiving the per-rank JSON output.
    """
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # NOTE(review): sharding by LOCAL_RANK assumes a single-node launch; for
    # multi-node runs the global RANK would be needed — confirm launcher setup.
    indices = list(range(rank, len(dataset), world_size))
    print("==>" + " Worker {} Started, responsible for {} images".format(rank, len(indices)))

    sub_dataset = torch.utils.data.Subset(dataset, indices)
    batch_size = 16
    data_loader = DataLoader(
        sub_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=DataCollatorForSupervisedDataset(processor, args.data_path),
    )
    # Fix: honor the ``output_dir`` parameter (previously ignored in favor of
    # args.output_dir; the sole call site passes the same value).
    output_path = os.path.join(output_dir, f'labels_{rank}.json')

    with StreamingJsonArrayWriter(output_path) as writer:
        for batch_tensors, result_meta in tqdm(data_loader):

            input_ids = batch_tensors['input_ids'].cuda()
            batch_tensors = {k: v.cuda() for k, v in batch_tensors.items() if isinstance(v, torch.Tensor)}
            with torch.inference_mode():
                # Fix: the original also requested output_scores/output_logits
                # but never consumed them; dropping those flags avoids keeping
                # per-step logit tensors alive for the whole generation.
                output_dict = model.generate(do_sample=False,
                                             return_dict_in_generate=True,
                                             max_new_tokens=1600,
                                             **batch_tensors,)

            output_ids = output_dict['sequences']

            for input_id, output_id, meta in zip(input_ids, output_ids, result_meta):
                input_token_len = input_id.shape[0]
                # Sanity check: the generated sequence should begin with the prompt.
                n_diff_input_output = (input_id != output_id[:input_token_len]).sum().item()
                if n_diff_input_output > 0:
                    print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
                output = decode_generated_text(processor, output_id, input_id)
                meta['action_description'] = output
                writer.write(meta)
|
def eval_model(args):
    """Initialize the distributed process group, load the model, and run the worker."""
    dist.init_process_group(backend='nccl')
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print('Init process group: world_size: {}, rank: {}'.format(world_size, rank))
    # Bind this process to its GPU before any CUDA allocation.
    torch.cuda.set_device(rank)

    disable_torch_init()
    backend_name, model, processor = load_model_and_processor(
        model_path=args.model_path,
        backend=args.model_backend,
        torch_dtype=args.torch_dtype,
        trust_remote_code=True,
    )
    print(f'Using model backend: {backend_name}')
    model = model.cuda().eval()

    image_folder = os.path.join(args.data_path, 'Images/images/train2015')
    multimodal_cfg = dict(
        image_folder=image_folder,
        data_augmentation=False,
        image_size=336,
    )
    dataset = PoseHICODetDataset(data_path=args.data_path, multimodal_cfg=multimodal_cfg)
    worker(model, processor, dataset, args, args.output_dir)
|
|
if __name__ == "__main__":
    # CLI entry point: all options are plain strings with defaults.
    parser = argparse.ArgumentParser()
    for flag, default in [
        ("--model-path", "facebook/opt-350m"),
        ("--data-path", ""),
        ("--output-dir", ""),
        ("--model-backend", "auto"),
        ("--torch-dtype", "bfloat16"),
    ]:
        parser.add_argument(flag, type=str, default=default)
    args = parser.parse_args()

    eval_model(args)
| |
|
|