Training Diffusion Models with Reinforcement Learning
Paper: arXiv:2305.13301
```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "kvablack/ddpo-alignment", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")  # switch to "mps" for Apple devices

prompt = "a horse playing chess"
image = pipe(prompt).images[0]
```

This model was finetuned from Stable Diffusion v1-4 using DDPO and a reward function that uses LLaVA to measure prompt-image alignment. See the project website for more details.
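To illustrate the shape of a prompt-image alignment reward, here is a minimal, self-contained sketch. The function name `alignment_reward` is hypothetical, and `difflib` string similarity is only a runnable placeholder standing in for the actual setup, which scored a LLaVA-generated description of the image against the prompt:

```python
import difflib


def alignment_reward(prompt: str, caption: str) -> float:
    """Toy stand-in for a prompt-image alignment reward: score how well a
    caption of the generated image matches the prompt.

    The real finetuning compared a LLaVA description of the image to the
    prompt; difflib here is only a self-contained placeholder for that
    text-similarity step.
    """
    return difflib.SequenceMatcher(None, prompt.lower(), caption.lower()).ratio()


# a caption that restates the prompt scores higher than an unrelated one
good = alignment_reward("a horse playing chess", "a horse playing chess on a board")
bad = alignment_reward("a horse playing chess", "a photo of a city street")
assert good > bad
```

In DDPO, a scalar reward like this is computed per sample and used as the policy-gradient signal for the denoising model; only the similarity backend would change.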
The model was finetuned for 200 iterations with a batch size of 256 samples per iteration. During finetuning, we used prompts of the form "a(n) <animal> <activity>", with the animal and activity drawn from the lists below, so try those for the best results. We also observed limited generalization to other prompts.
Activities:
Animals:
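A sketch of how such prompts can be assembled. The list entries here are hypothetical placeholders, not the actual lists used during finetuning:

```python
import random

# hypothetical placeholder entries; substitute the actual lists above
animals = ["horse", "otter", "elephant"]
activities = ["playing chess", "riding a bike", "washing dishes"]


def make_prompt(animal: str, activity: str) -> str:
    # choose "a" vs "an" from the animal's leading letter
    article = "an" if animal[0] in "aeiou" else "a"
    return f"{article} {animal} {activity}"


prompt = make_prompt(random.choice(animals), random.choice(activities))
```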