| import argparse | |
| import time | |
| import base64 | |
| import numpy as np | |
| import requests | |
| import os | |
| from urllib.parse import urlparse | |
| from tritonclient.http import InferenceServerClient, InferInput, InferRequestedOutput | |
def download_image(image_url):
    """Download the image at *image_url* into the current working directory.

    Args:
        image_url: HTTP(S) URL of the image to fetch.

    Returns:
        The local filename the image was written to.

    Raises:
        requests.HTTPError: if the server responds with a non-2xx status.
        requests.RequestException: on network failure or timeout.
    """
    parsed_url = urlparse(image_url)
    filename = os.path.basename(parsed_url.path)
    # A URL whose path ends in "/" yields an empty basename, which would
    # make open() fail below — fall back to a fixed name in that case.
    if not filename:
        filename = "downloaded_image"
    # Bound the request so a stalled server cannot hang the client forever.
    response = requests.get(image_url, timeout=30)
    # Raise a descriptive HTTPError (includes the status code) instead of
    # a bare Exception; HTTPError still subclasses Exception, so existing
    # broad handlers keep working.
    response.raise_for_status()
    with open(filename, 'wb') as img_file:
        img_file.write(response.content)
    return filename
def image_to_base64_data_uri(image_input):
    """Return the contents of the file at *image_input* as a
    base64-encoded ASCII string.

    Note: despite the name, the return value is the bare base64 payload,
    not a full ``data:`` URI.
    """
    with open(image_input, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode('utf-8')
def setup_argparse(argv=None):
    """Parse command-line options for the Triton client.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            which makes argparse read ``sys.argv[1:]`` — identical to the
            original behavior, so existing callers are unaffected. Passing
            an explicit list makes the function unit-testable.

    Returns:
        argparse.Namespace with ``image_path`` and ``prompt`` attributes.
    """
    parser = argparse.ArgumentParser(description="Client for Triton Inference Server")
    parser.add_argument("--image_path", type=str, required=True,
                        help="Path to the image or URL of the image to process")
    parser.add_argument("--prompt", type=str, required=True,
                        help="Prompt to be used for the inference")
    return parser.parse_args(argv)
| if __name__ == "__main__": | |
| args = setup_argparse() | |
| triton_client = InferenceServerClient(url="localhost:8000", verbose=False) | |
| if args.image_path.startswith('http://') or args.image_path.startswith('https://'): | |
| image_path = download_image(args.image_path) | |
| else: | |
| image_path = args.image_path | |
| image_data = image_to_base64_data_uri(image_path).encode('utf-8') | |
| image_data_np = np.array([image_data], dtype=object) | |
| prompt_np = np.array([args.prompt.encode('utf-8')], dtype=object) | |
| images_in = InferInput(name="IMAGES", shape=[1], datatype="BYTES") | |
| images_in.set_data_from_numpy(image_data_np, binary_data=True) | |
| prompt_in = InferInput(name="PROMPT", shape=[1], datatype="BYTES") | |
| prompt_in.set_data_from_numpy(prompt_np, binary_data=True) | |
| results_out = InferRequestedOutput(name="RESULTS", binary_data=False) | |
| start_time = time.time() | |
| response = triton_client.infer(model_name="spacellava", | |
| model_version="1", | |
| inputs=[prompt_in, images_in], | |
| outputs=[results_out]) | |
| results = response.get_response()["outputs"][0]["data"][0] | |
| print("--- %s seconds ---" % (time.time() - start_time)) | |
| print(results) | |