# TimeForge / multi_view.py
# Last commit f6b18bd ("Update multi_view.py") by Ryukijano (verified).
import inspect
import warnings

import torch
from diffusers import DiffusionPipeline
from PIL import Image

from utils import create_image_grid # Changed import
class MultiViewDiffusion:
    """Wrapper around a multi-view diffusion pipeline loaded from the Hub.

    The pipeline is created with ``DiffusionPipeline.from_pretrained`` using
    remote ``custom_pipeline`` code and is moved to ``device`` on construction.
    """

    # View count used when the caller does not override ``num_views``.
    # NOTE(review): also assumed to be how many views the pipeline produces
    # when it exposes no ``num_frames`` argument -- confirm against the
    # custom pipeline code.
    _DEFAULT_NUM_VIEWS = 4

    def __init__(
        self,
        device="cuda",
        model_id="dylanebert/mvdream",
        *,
        torch_dtype=None,
        custom_pipeline="dylanebert/multi-view-diffusion",
    ):
        """Load the pipeline and move it to ``device``.

        Args:
            device: Torch device string (or ``torch.device``) to run on.
            model_id: Hub repository id (or local path) with the model weights.
            torch_dtype: Dtype for the weights. ``None`` selects
                ``torch.float16`` on CUDA devices and ``torch.float32``
                otherwise.
            custom_pipeline: Hub id of the custom pipeline code to load.
        """
        self.device = device
        self.pipeline = DiffusionPipeline.from_pretrained(
            model_id,
            custom_pipeline=custom_pipeline,
            torch_dtype=self._resolve_dtype(device, torch_dtype),
            # Allows executing Python code fetched from the Hub; only point
            # this at repositories you trust.
            trust_remote_code=True,
        ).to(self.device)

    @staticmethod
    def _resolve_dtype(device, torch_dtype=None):
        """Return ``torch_dtype`` if given, else a default suited to ``device``."""
        if torch_dtype is not None:
            return torch_dtype
        # Half precision is the fast path on GPUs, but fp16 on CPU is slow or
        # unsupported for many ops, so fall back to fp32 off-GPU.
        if torch.device(device).type == "cuda":
            return torch.float16
        return torch.float32

    def _accepts_kwarg(self, name):
        """Return True if calling ``self.pipeline`` explicitly declares ``name``."""
        try:
            params = inspect.signature(self.pipeline).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: assume the argument is unsupported.
            return False
        return name in params

    def generate_views(
        self,
        prompt,
        num_views=_DEFAULT_NUM_VIEWS,
        guidance_scale=5,
        num_inference_steps=30,
        elevation=0,
    ):
        """Generate several views of ``prompt`` with the loaded pipeline.

        Args:
            prompt: Text prompt passed as the pipeline's first argument.
            num_views: Number of views to request. Forwarded as ``num_frames``
                when the pipeline's call signature declares that parameter.
                Otherwise it cannot be honoured, and a warning is emitted for
                non-default values.
            guidance_scale: Forwarded as ``guidance_scale``.
            num_inference_steps: Forwarded as ``num_inference_steps``.
            elevation: Forwarded as ``elevation``.

        Returns:
            The ``.images`` attribute of the pipeline output when present,
            otherwise the raw value the pipeline returned.
        """
        call_kwargs = {
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "elevation": elevation,
        }
        if self._accepts_kwarg("num_frames"):
            call_kwargs["num_frames"] = num_views
        elif num_views != self._DEFAULT_NUM_VIEWS:
            warnings.warn(
                f"pipeline does not accept 'num_frames'; ignoring num_views={num_views}",
                stacklevel=2,
            )
        output = self.pipeline(prompt, **call_kwargs)
        # Standard diffusers pipelines return an output object exposing
        # ``.images``; some custom pipelines return the image data directly.
        return getattr(output, "images", output)
if __name__ == "__main__":
    # Manual smoke run: render views for a sample prompt, combine them with
    # create_image_grid, and write the result to disk.
    mv_model = MultiViewDiffusion()
    views = mv_model.generate_views("A futuristic city")
    create_image_grid(views).save("multi_view_output.png")