| |
|
|
| import torch |
| import torchvision.transforms as T |
| from PIL import Image |
| import io |
| import json |
|
|
| |
# Ordered class labels for the segmentation output: the integer at each pixel
# of the returned mask indexes into this list.
# NOTE(review): assumes the model's output channel order matches this list —
# confirm against the training config.
CLASS_LABELS = [
    "glove_outline",
    "webbing",
    "thumb",
    "palm_pocket",
    "hand",
    "glove_exterior"
]
|
|
| |
| |
| |
def load_model():
    """Deserialize the full pickled model from ``pytorch_model.bin`` onto CPU.

    NOTE(review): ``torch.load`` on a whole-model checkpoint unpickles
    arbitrary objects — only safe if pytorch_model.bin comes from a
    trusted source.
    """
    net = torch.load("pytorch_model.bin", map_location="cpu")
    # Inference mode: disables dropout and freezes batch-norm statistics.
    net.eval()
    return net
|
|
| model = load_model() |
|
|
| |
| |
| |
# Preprocessing pipeline: resize to a fixed 720x1280 (height, width) and
# convert to a float tensor scaled to [0, 1].
# NOTE(review): presumably 720x1280 is the resolution the model was trained
# at — confirm against the training config.
transform = T.Compose([
    T.Resize((720, 1280)),
    T.ToTensor()
])
|
|
def preprocess(input_bytes):
    """Decode raw image bytes into a batched model-input tensor.

    Args:
        input_bytes: encoded image data (any format PIL can open).

    Returns:
        Float tensor of shape (1, 3, 720, 1280) produced by the
        module-level ``transform`` pipeline.
    """
    rgb_image = Image.open(io.BytesIO(input_bytes)).convert("RGB")
    batched = transform(rgb_image)
    return batched.unsqueeze(0)
|
|
| |
| |
| |
class DummyInput:
    """Minimal container mimicking the input object the segmentation model expects.

    Wraps a single image batch and fills every prompt-related field with
    empty or placeholder values (no point, box, or mask prompts).
    """

    def __init__(self, image_tensor):
        # image_tensor is assumed to be (B, C, H, W) — TODO confirm with model.
        batch, _, height, width = image_tensor.shape
        size = (height, width)

        self.images = image_tensor
        # One all-False boolean mask per frame.
        self.masks = [torch.zeros(batch, height, width, dtype=torch.bool)]
        self.num_frames = 1
        self.original_size = [size]
        self.target_size = [size]
        # No interactive prompts supplied.
        self.point_coords = [None]
        self.point_labels = [None]
        self.boxes = [None]
        self.mask_inputs = torch.zeros(batch, 1, height, width)
        self.video_mask = torch.zeros(batch, 1, height, width)
        self.flat_obj_to_img_idx = [[0]]
|
|
| |
| |
| |
def postprocess(output_tensor):
    """Convert raw model output to a per-pixel class-index mask.

    Args:
        output_tensor: either a dict containing logits under "masks", or a
            bare logits tensor shaped (B, num_classes, H, W).

    Returns:
        Nested Python list (H x W) of argmax class indices for the first
        batch element.
    """
    is_dict_output = isinstance(output_tensor, dict) and "masks" in output_tensor
    logits = output_tensor["masks"] if is_dict_output else output_tensor
    class_map = logits.argmax(dim=1)[0]
    return class_map.cpu().numpy().tolist()
|
|
| |
| |
| |
def infer(payload):
    """Run segmentation on a single image and return the mask plus labels.

    Args:
        payload: raw image bytes, or a dict carrying base64-encoded image
            data under the "inputs" key.

    Returns:
        Dict with "mask" (H x W nested list of class indices) and
        "classes" (the label list those indices point into).

    Raises:
        ValueError: if payload is neither bytes nor a dict with "inputs".
    """
    if isinstance(payload, dict) and "inputs" in payload:
        from base64 import b64decode
        raw_bytes = b64decode(payload["inputs"])
    elif isinstance(payload, bytes):
        raw_bytes = payload
    else:
        raise ValueError("Unsupported input format")

    image_tensor = preprocess(raw_bytes)

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        raw_output = model(DummyInput(image_tensor))

    return {
        "mask": postprocess(raw_output),
        "classes": CLASS_LABELS
    }
|
|