{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "import PIL\n",
    "from PIL import Image\n",
    "from IPython.display import HTML\n",
    "from pyramid_dit import PyramidDiTForVideoGeneration\n",
    "from IPython.display import Image as ipython_image\n",
    "from diffusers.utils import load_image, export_to_video, export_to_gif"
   ]
  },
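  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before loading the model, it can help to confirm which GPU is active and how much memory is free, since the DiT, VAE, and text encoder loaded below must fit in VRAM together. The next cell is an optional sketch using plain `torch.cuda` calls; it is not part of the Pyramid Flow API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (plain torch.cuda calls, not part of Pyramid Flow):\n",
    "# report the active GPU and its free/total memory before loading the checkpoint.\n",
    "if torch.cuda.is_available():\n",
    "    free, total = torch.cuda.mem_get_info()\n",
    "    print(f\"GPU: {torch.cuda.get_device_name()}\")\n",
    "    print(f\"memory: {free / 1e9:.1f} GB free / {total / 1e9:.1f} GB total\")\n",
    "else:\n",
    "    print(\"No CUDA device found; this notebook requires a GPU.\")"
   ]
  },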
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant='diffusion_transformer_768p'         # For high resolution\n",
    "# variant='diffusion_transformer_384p'       # For low resolution\n",
    "\n",
    "model_path = \"/home/jinyang06/models/pyramid-flow\"   # The downloaded checkpoint dir\n",
    "model_dtype = 'bf16'\n",
    "\n",
    "device_id = 0\n",
    "torch.cuda.set_device(device_id)\n",
    "\n",
    "model = PyramidDiTForVideoGeneration(\n",
    "    model_path,\n",
    "    model_dtype,\n",
    "    model_variant=variant,\n",
    ")\n",
    "\n",
    "model.vae.to(\"cuda\")\n",
    "model.dit.to(\"cuda\")\n",
    "model.text_encoder.to(\"cuda\")\n",
    "\n",
    "if model_dtype == \"bf16\":\n",
    "    torch_dtype = torch.bfloat16 \n",
    "elif model_dtype == \"fp16\":\n",
    "    torch_dtype = torch.float16\n",
    "else:\n",
    "    torch_dtype = torch.float32\n",
    "\n",
    "\n",
    "def show_video(ori_path, rec_path, width=\"100%\"):\n",
    "    html = ''\n",
    "    if ori_path is not None:\n",
    "        html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
    "        <source src=\"{ori_path}\" type=\"video/mp4\">\n",
    "        </video>\n",
    "        \"\"\"\n",
    "    \n",
    "    html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
    "    <source src=\"{rec_path}\" type=\"video/mp4\">\n",
    "    </video>\n",
    "    \"\"\"\n",
    "    return HTML(html)"
   ]
  },
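  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough footprint check, the sketch below counts the parameters of the three loaded components. It assumes `model.dit`, `model.vae`, and `model.text_encoder` are ordinary `torch.nn.Module` instances, as the `.to(\"cuda\")` calls above suggest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough VRAM footprint: parameter count per component (assumes each is a\n",
    "# standard torch.nn.Module, as the .to(\"cuda\") calls above suggest).\n",
    "for name in (\"dit\", \"vae\", \"text_encoder\"):\n",
    "    module = getattr(model, name)\n",
    "    n_params = sum(p.numel() for p in module.parameters())\n",
    "    print(f\"{name}: {n_params / 1e9:.2f}B params\")"
   ]
  },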
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Text-to-Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors\"\n",
    "\n",
    "# used for 384p model variant\n",
    "# width = 640\n",
    "# height = 384\n",
    "\n",
    "# used for 768p model variant\n",
    "width = 1280\n",
    "height = 768\n",
    "\n",
    "temp = 16   # temp in [1, 31] <=> frame in [1, 241] <=> duration in [0, 10s]\n",
    "\n",
    "model.vae.enable_tiling()\n",
    "\n",
    "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
    "    frames = model.generate(\n",
    "        prompt=prompt,\n",
    "        num_inference_steps=[20, 20, 20],\n",
    "        video_num_inference_steps=[10, 10, 10],\n",
    "        height=height,\n",
    "        width=width,\n",
    "        temp=temp,\n",
    "        guidance_scale=9.0,         # The guidance for the first frame\n",
    "        video_guidance_scale=5.0,   # The guidance for the other video latent\n",
    "        output_type=\"pil\",\n",
    "        save_memory=True,           # If you have enough GPU memory, set it to `False` to improve vae decoding speed\n",
    "    )\n",
    "\n",
    "export_to_video(frames, \"./text_to_video_sample.mp4\", fps=24)\n",
    "show_video(None, \"./text_to_video_sample.mp4\", \"70%\")"
   ]
  },
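  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`export_to_gif` is imported in the first cell but not used above; as an optional extra, the same frames can also be written out as an animated GIF and previewed inline (the `fps` keyword is assumed to be available in your diffusers version). GIFs are much larger than MP4s at the same resolution, so this is mainly useful for quick previews."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: also save the generated frames as a GIF and preview it inline.\n",
    "export_to_gif(frames, \"./text_to_video_sample.gif\", fps=24)\n",
    "ipython_image(\"./text_to_video_sample.gif\", width=480)"
   ]
  },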
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Image-to-Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = 'assets/the_great_wall.jpg'\n",
    "image = Image.open(image_path).convert(\"RGB\")\n",
    "\n",
    "width = 1280\n",
    "height = 768\n",
    "temp = 16\n",
    "\n",
    "image = image.resize((width, height))\n",
    "\n",
    "display(image)\n",
    "\n",
    "prompt = \"FPV flying over the Great Wall\"\n",
    "\n",
    "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
    "    frames = model.generate_i2v(\n",
    "        prompt=prompt,\n",
    "        input_image=image,\n",
    "        num_inference_steps=[10, 10, 10],\n",
    "        temp=temp,\n",
    "        guidance_scale=7.0,\n",
    "        video_guidance_scale=4.0,\n",
    "        output_type=\"pil\",\n",
    "        save_memory=True,         # If you have enough GPU memory, set it to `False` to improve vae decoding speed\n",
    "    )\n",
    "\n",
    "export_to_video(frames, \"./image_to_video_sample.mp4\", fps=24)\n",
    "show_video(None, \"./image_to_video_sample.mp4\", \"70%\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}