|
import torch |
|
from diffusers import DiffusionPipeline |
|
from PIL import Image |
|
from utils import create_image_grid |
|
|
|
class MultiViewDiffusion:
    """Thin wrapper around a multi-view diffusion pipeline loaded via diffusers.

    The pipeline is loaded once in the constructor and reused by every
    ``generate_views`` call.
    """

    def __init__(self, device: str = "cuda", model_id: str = "dylanebert/mvdream"):
        """Load the pipeline in float16 and move it to *device*.

        Args:
            device: Device string handed to ``.to()`` on the loaded pipeline.
            model_id: Model identifier passed to
                ``DiffusionPipeline.from_pretrained``.
        """
        self.device = device
        # NOTE: trust_remote_code=True lets diffusers execute the pipeline
        # code published in the custom_pipeline repository; only point this
        # at sources you trust.
        self.pipeline = DiffusionPipeline.from_pretrained(
            model_id,
            custom_pipeline="dylanebert/multi-view-diffusion",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).to(self.device)

    def generate_views(
        self,
        prompt: str,
        num_views: int = 4,
        guidance_scale: float = 5,
        num_inference_steps: int = 30,
        elevation: float = 0,
    ):
        """Run the pipeline for *prompt* and return the generated views.

        Args:
            prompt: Text prompt, passed positionally to the pipeline.
            num_views: Number of views to request; forwarded to the pipeline
                as ``num_frames`` (previously accepted but ignored).
            guidance_scale: Forwarded unchanged as ``guidance_scale``.
            num_inference_steps: Forwarded unchanged as
                ``num_inference_steps``.
            elevation: Forwarded unchanged as ``elevation`` (presumably the
                camera elevation in degrees -- TODO confirm against the
                pipeline).

        Returns:
            The ``images`` attribute of the pipeline output when it has one,
            otherwise the pipeline output itself.
        """
        output = self.pipeline(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            elevation=elevation,
            num_frames=num_views,
        )
        # Custom pipelines may hand back the image batch directly instead of
        # an output object with `.images`; accept either form rather than
        # raising AttributeError.
        return getattr(output, "images", output)
|
|
|
if __name__ == "__main__":
    # Demo entry point: generate the views for one sample prompt, tile them
    # with the project's create_image_grid helper, and write the result to
    # disk through the returned object's save() method.
    generator = MultiViewDiffusion()
    views = generator.generate_views("A futuristic city")
    create_image_grid(views).save("multi_view_output.png")