-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathinference_chat.py
More file actions
47 lines (38 loc) · 1.69 KB
/
inference_chat.py
File metadata and controls
47 lines (38 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
import os
import re
import json
from PIL import Image
from tqdm import tqdm
import torch
from g2vlm_utils import load_model_and_tokenizer, build_transform, process_conversation
if __name__ == '__main__':
    # Single-image demo: load G2VLM, build one prompt, and run
    # chat_with_recon to get a free-form text answer.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str, default='InternRobotics/G2VLM-2B-MoT')
    parser.add_argument('--image-path', type=str, default='examples/25_0.jpg')
    parser.add_argument('--question', type=str, default='')
    args = parser.parse_args()

    # Load model + tokenizer + transforms (project helper; reads args.model_path).
    model, tokenizer, new_token_ids, vit_image_transform, dino_transform = load_model_and_tokenizer(args)
    image_transform = build_transform(pixel=768)

    # Report the parameter count in billions as a sanity check.
    total_params = sum(p.numel() for p in model.parameters()) / 1e9
    print(f'[test] total_params: {total_params}B')

    # BUG FIX: honor the --image-path flag instead of a hard-coded path,
    # which previously made the flag a silent no-op.
    img_path = args.image_path

    # Default demo question, used only when --question is not supplied.
    question = "If the table (red point) is positioned at 2.6 meters, estimate the depth of the clothes (blue point). Calculate or judge based on the 3D center points of these objects. The unit is meter. Submit your response as one numeric value only."
    post_prompt = "Please answer the question using a single word or phrase."
    templated_question = "\n" + question + "\n" + post_prompt

    # BUG FIX: --question defaults to '' (not None), so the original
    # `is not None` check was always true and clobbered the templated
    # prompt with an empty string. Only override on a non-empty value.
    if args.question:
        templated_question = args.question

    # Print the prompt actually sent (the original printed the hard-coded
    # default even when the user overrode it).
    print(templated_question)

    images = [Image.open(img_path).convert('RGB')]
    images, conversation = process_conversation(images, templated_question)
    response = model.chat_with_recon(
        tokenizer,
        new_token_ids,
        image_transform,
        dino_transform,
        images=images,
        prompt=conversation,
        max_length=100,
    )
    print('answer: ', response)