MobileWorldModel
MobileWorld-Diffusion is an image-to-image mobile GUI world model. Given the current screenshot and a candidate action, it renders the predicted next screenshot.
The model is built on a Qwen-Image-Edit-style pipeline: the current screenshot is passed as the edit/reference image, and the action is provided through the text prompt.
Provide:
- edit_image: the current GUI screenshot.
- prompt: the action-conditioned next-state rendering prompt, built from this template:

Predict the next page state via image from this current screenshot using action description "{action_desc}" and action target "{target_desc}" and relative coordinates "[{rx:.3f}, {ry:.3f}]".

Template fields:
- action_desc: natural-language action description, for example click, scroll down, input text: pizza, or open app: Gmail.
- target_desc: target UI element description, for example search input field, back button, or point(536, 1280).
- rx: normalized x coordinate in [0, 1].
- ry: normalized y coordinate in [0, 1].

For non-coordinate actions, use a reasonable default such as [0.500, 0.500] and put the main action information in action_desc, as in the prompt-building sketch below.
The expected output is an image representing the predicted next screen state after executing the action on the input screenshot.
Example prompt:

Predict the next page state via image from this current screenshot using action description "click" and action target "circular back button in the top-left corner" and relative coordinates "[0.060, 0.073]".
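To make the field semantics concrete, here is a minimal prompt-building sketch; build_prompt is a hypothetical helper name, and the [0.500, 0.500] default for non-coordinate actions follows the convention above.

def build_prompt(action_desc, target_desc, rx=0.500, ry=0.500):
    # Hypothetical helper: fills the action-conditioned template.
    # Actions without a meaningful coordinate keep the [0.500, 0.500] default.
    return (
        'Predict the next page state via image from this current screenshot '
        f'using action description "{action_desc}" and action target "{target_desc}" '
        f'and relative coordinates "[{rx:.3f}, {ry:.3f}]".'
    )

# Coordinate action:
build_prompt("click", "circular back button in the top-left corner", 0.060, 0.073)
# Non-coordinate action; the action text carries the information:
build_prompt("input text: pizza", "search input field")

The full inference example below loads the pipeline components with DiffSynth-Studio and renders a single step.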
import math
import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict
# Action-conditioned prompt template; placeholders are filled per action.
PROMPT_TEMPLATE = (
    'Predict the next page state via image from this current screenshot '
    'using action description "{action_desc}" and action target "{target_desc}" '
    'and relative coordinates "[{rx:.3f}, {ry:.3f}]".'
)
def target_hw(src_w, src_h, target_area=1024 * 1024, divisor=32):
    # Pick an output resolution that preserves the source aspect ratio,
    # covers roughly `target_area` pixels, and is divisible by `divisor`.
    ratio = src_w / src_h
    w = math.sqrt(target_area * ratio)
    h = w / ratio
    w = max(divisor, round(w / divisor) * divisor)
    h = max(divisor, round(h / divisor) * divisor)
    return int(h), int(w)
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        # Diffusion transformer weights from Qwen-Image-Edit-2511.
        ModelConfig(
            model_id="Qwen/Qwen-Image-Edit-2511",
            origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors",
        ),
        # Text encoder and VAE are shared with the base Qwen-Image model.
        ModelConfig(
            model_id="Qwen/Qwen-Image",
            origin_file_pattern="text_encoder/model*.safetensors",
        ),
        ModelConfig(
            model_id="Qwen/Qwen-Image",
            origin_file_pattern="vae/diffusion_pytorch_model.safetensors",
        ),
    ],
    tokenizer_config=None,
    processor_config=ModelConfig(
        model_id="Qwen/Qwen-Image-Edit",
        origin_file_pattern="processor/",
    ),
)
# Load the fine-tuned MobileWorld-Diffusion checkpoint if it is provided as a
# separate safetensors file in your local setup.
state_dict = load_state_dict("step-5060.safetensors")
pipe.dit.load_state_dict(state_dict, strict=False, assign=True)
# `assign=True` replaces the parameter tensors, so move the DiT back to the
# pipeline's device and dtype after loading.
pipe.dit.to(device=getattr(pipe, "device", "cuda"), dtype=getattr(pipe, "torch_dtype", torch.bfloat16))
src = Image.open("screenshot_0.png").convert("RGB")
# PIL's Image.size is (width, height); target_hw returns (height, width).
h, w = target_hw(src.size[0], src.size[1])
prompt = PROMPT_TEMPLATE.format(
    action_desc="click",
    target_desc="circular back button in the top-left corner",
    rx=0.060,
    ry=0.073,
)
out = pipe(
    prompt=prompt,
    edit_image=src,  # current screenshot as the edit/reference image
    seed=42,
    num_inference_steps=40,
    height=h,
    width=w,
    zero_cond_t=True,
)
out.save("rendered_next_screen.png")
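Because each output is itself a screenshot, a multi-step rollout is just a loop that feeds the rendered frame back in as the next edit_image. A minimal sketch, reusing pipe, PROMPT_TEMPLATE, and target_hw from above with a hypothetical action list; note that prediction errors can accumulate over long rollouts:

frame = Image.open("screenshot_0.png").convert("RGB")
actions = [  # hypothetical (action_desc, target_desc, rx, ry) tuples
    ("click", "circular back button in the top-left corner", 0.060, 0.073),
    ("scroll down", "settings list", 0.500, 0.500),
]
for i, (action_desc, target_desc, rx, ry) in enumerate(actions, start=1):
    h, w = target_hw(frame.size[0], frame.size[1])
    frame = pipe(
        prompt=PROMPT_TEMPLATE.format(
            action_desc=action_desc, target_desc=target_desc, rx=rx, ry=ry
        ),
        edit_image=frame,
        seed=42,
        num_inference_steps=40,
        height=h,
        width=w,
        zero_cond_t=True,
    )
    frame.save(f"rendered_step_{i}.png")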
Coordinates are relative to the input screenshot:
rx = x / image_width
ry = y / image_height
For datasets that already store coordinates in normalized 0..1000 space, use:
rx = x / 1000
ry = y / 1000
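A small helper matching these formulas; to_relative is a hypothetical name, and the coordinate values are illustrative:

def to_relative(x, y, image=None, scale_1000=False):
    # Map coordinates into the [0, 1] range expected by the prompt template.
    if scale_1000:
        # Dataset coordinates already normalized to 0..1000.
        return x / 1000, y / 1000
    w, h = image.size  # PIL's Image.size is (width, height)
    return x / w, y / h

rx, ry = to_relative(536, 1280, src)             # pixel coordinates on src
rx, ry = to_relative(500, 640, scale_1000=True)  # 0..1000-normalized dataset coords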