# --- NEW ---
# Add necessary imports for new features
import tempfile
import cv2
from moviepy.editor import VideoFileClip, concatenate_videoclips
import spaces
import os
import torch
import subprocess
# Ensure flash-attn is uninstalled if it causes conflicts in the HF environment
try:
package_to_uninstall = "flash-attn"
print(f"Attempting to uninstall {package_to_uninstall}...")
command = ["python", "-m", "pip", "uninstall", "-y", package_to_uninstall]
    subprocess.run(command, check=True, capture_output=True, text=True)
print(f"Successfully uninstalled {package_to_uninstall}.")
except subprocess.CalledProcessError as e:
print(f"Could not uninstall {package_to_uninstall} (may not be installed): {e}")
import gradio as gr
import imageio
import time
import random
import gc
from PIL import Image
# Import necessary components from the cloned repository
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import Image2VideoPipeline, resizecrop
# --- Global Configuration & Model Loading ---
MODEL_ID = "Skywork/SkyReels-V2-I2V-14B-720P"
HEIGHT = 720
WIDTH = 720
OUTPUT_DIR = "video_out"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("Downloading and loading model... This may take a while.")
cached_model_path = download_model(MODEL_ID)
# 1. Initialize the pipeline with the downloaded weights
pipe = Image2VideoPipeline(
model_path=cached_model_path,
dit_path=cached_model_path,
use_usp=False,
offload=True
)
# 2. Manually convert the key model components to bfloat16
print("Converting model components to bfloat16...")
pipe.transformer = pipe.transformer.to(dtype=torch.bfloat16)
# The VAE might also benefit from conversion
if hasattr(pipe, 'vae'):
pipe.vae = pipe.vae.to(dtype=torch.bfloat16)
print("Model loaded and converted successfully.")
# --- NEW: Helper functions for new features ---
def use_last_frame_as_input(video_filepath):
"""
Efficiently extracts the last frame of a video using OpenCV.
"""
if not video_filepath or not os.path.exists(video_filepath):
gr.Warning("No video clip available to get the last frame from.")
return None
cap = None
try:
print(f"Reading last frame from {video_filepath} using OpenCV...")
cap = cv2.VideoCapture(video_filepath)
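        # CAP_PROP_FRAME_COUNT can be inaccurate for some containers; a non-positive value means the file is unreadable.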
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if frame_count < 1:
raise ValueError("Video file could not be read or contains no frames.")
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
ret, frame = cap.read()
if not ret or frame is None:
raise ValueError("Failed to read the last frame from the video.")
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(frame_rgb)
print("Successfully extracted last frame.")
return pil_image
except Exception as e:
print(f"Error extracting last frame with OpenCV: {e}")
        raise gr.Error(f"Failed to extract the last frame: {e}")
finally:
if cap is not None and cap.isOpened():
cap.release()
def stitch_videos(clips_list):
"""
Stitches a list of video file paths into a single video.
"""
if not clips_list or len(clips_list) < 2:
raise gr.Error("You need at least two clips to stitch them together!")
print(f"Stitching {len(clips_list)} clips...")
try:
video_clips = [VideoFileClip(clip_path) for clip_path in clips_list]
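        # method="compose" pads clips to a common frame size if their resolutions differ.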
final_clip = concatenate_videoclips(video_clips, method="compose")
temp_dir = tempfile.mkdtemp()
final_output_path = os.path.join(temp_dir, f"stitched_video_{random.randint(10000,99999)}.mp4")
final_clip.write_videofile(final_output_path, codec="libx264", audio=False, threads=4, preset='ultrafast')
for clip in video_clips:
clip.close()
print(f"Final video saved to {final_output_path}")
return final_output_path
except Exception as e:
print(f"Error during video stitching: {e}")
raise gr.Error(f"Failed to stitch videos: {e}")
def clear_clips():
"""
Resets the clip list and associated UI components.
"""
return [], "Clips created: 0", None, None
# --- MODIFIED: Main Inference Function ---
@spaces.GPU(duration=60)
def generate_video(input_image, prompt, guidance_scale, inference_steps, num_frames, fps, seed, clips_list):
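    """
    Generate one video clip from the input image and prompt, save it to OUTPUT_DIR,
    and append its file path to the running clips list used for stitching.
    """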
if input_image is None:
raise gr.Error("You must upload an initial image.")
if not prompt:
raise gr.Error("Prompt cannot be empty.")
if seed == -1:
seed = random.randint(0, 2**32 - 1)
generator = torch.Generator(device="cuda").manual_seed(int(seed))
image = Image.fromarray(input_image).convert("RGB")
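    # Resize and center-crop the input to the model's native 720x720 resolution.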
processed_image = resizecrop(image, HEIGHT, WIDTH)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, worst quality, low quality, JPEG compression residue, ugly, deformed."
kwargs = {
"image": processed_image,
"prompt": prompt,
"negative_prompt": negative_prompt,
"num_frames": int(num_frames),
"num_inference_steps": int(inference_steps),
"guidance_scale": guidance_scale,
"shift": 5.0, # Recommended value for I2V
"generator": generator,
"height": HEIGHT,
"width": WIDTH,
}
print(f"Generating video with seed: {seed}")
start_time = time.time()
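    # Run the denoising loop under bfloat16 autocast with gradients disabled to reduce memory use.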
    with torch.amp.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
video_frames = pipe(**kwargs)[0]
end_time = time.time()
print(f"Inference took {end_time - start_time:.2f} seconds.")
safe_prompt = "".join(c for c in prompt if c.isalnum() or c in " _-").strip()[:50]
output_filename = f"{safe_prompt}_{seed}.mp4"
output_path = os.path.join(OUTPUT_DIR, output_filename)
imageio.mimwrite(output_path, video_frames, fps=int(fps), quality=8, output_params=["-loglevel", "error"])
print(f"Video saved to {output_path}")
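    # Release Python- and CUDA-side memory so repeated generations don't accumulate allocations.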
gc.collect()
torch.cuda.empty_cache()
# --- MODIFIED ---
# Update the clips list and counter text
updated_clips_list = clips_list + [output_path]
counter_text = f"Clips created: {len(updated_clips_list)}"
# Return values to update all relevant UI components
return output_path, gr.update(visible=True), updated_clips_list, counter_text
# --- MODIFIED: Gradio UI ---
with gr.Blocks(css="footer {display: none !important}") as demo:
# --- NEW ---
# State to store the list of generated clip paths
clips_state = gr.State([])
gr.Markdown(
"""
# SkyReels-V2 Image-to-Video Clip Stitcher
### Model: Skywork/SkyReels-V2-I2V-14B-720P
Generate short video clips iteratively and stitch them together into a longer animation.
"""
)
with gr.Row():
with gr.Column():
input_image = gr.Image(type="numpy", label="Initial Image")
prompt = gr.Textbox(label="Prompt", placeholder="e.g., A cinematic shot of a car driving on a rainy street at night.")
with gr.Accordion("Advanced Settings", open=False):
guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, value=5.0, step=0.5, label="Guidance Scale")
inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps")
num_frames = gr.Slider(minimum=25, maximum=145, value=97, step=8, label="Number of Frames")
fps = gr.Slider(minimum=8, maximum=30, value=24, step=1, label="Frames Per Second (FPS)")
seed = gr.Number(value=-1, label="Seed (-1 for random)")
with gr.Column():
output_video = gr.Video(label="Last Generated Clip")
# --- NEW ---
# Button to use the last frame of the output as the new input
use_last_frame_button = gr.Button("Use Last Frame as Input Image", visible=False)
# --- NEW ---
# UI for stitching controls
with gr.Accordion("Stitching Controls", open=True):
clip_counter_display = gr.Markdown("Clips created: 0")
with gr.Row():
stitch_button = gr.Button("🎬 Stitch All Clips")
clear_button = gr.Button("🗑️ Clear All Clips")
final_video_output = gr.Video(label="Final Stitched Video", interactive=False)
run_button = gr.Button("Generate Video Clip", variant="primary")
# --- MODIFIED ---
# Update the click event to handle the new state and outputs
run_button.click(
fn=generate_video,
inputs=[input_image, prompt, guidance_scale, inference_steps, num_frames, fps, seed, clips_state],
outputs=[output_video, use_last_frame_button, clips_state, clip_counter_display]
)
# --- NEW ---
# Add click events for the new buttons
use_last_frame_button.click(
fn=use_last_frame_as_input,
inputs=[output_video],
outputs=[input_image]
)
stitch_button.click(
fn=stitch_videos,
inputs=[clips_state],
outputs=[final_video_output]
)
clear_button.click(
fn=clear_clips,
outputs=[clips_state, clip_counter_display, output_video, final_video_output]
)
if __name__ == "__main__":
demo.queue().launch(debug=True)