# --- NEW ---
# Add necessary imports for new features
import tempfile
import cv2
from moviepy.editor import VideoFileClip, concatenate_videoclips
import spaces
import os
import torch
import subprocess

# Ensure flash-attn is uninstalled if it causes conflicts in the HF environment
try:
    package_to_uninstall = "flash-attn"
    print(f"Attempting to uninstall {package_to_uninstall}...")
    command = ["python", "-m", "pip", "uninstall", "-y", package_to_uninstall]
    result = subprocess.run(command, check=True, capture_output=True, text=True)
    print(f"Successfully uninstalled {package_to_uninstall}.")
except subprocess.CalledProcessError as e:
    print(f"Could not uninstall {package_to_uninstall} (may not be installed): {e}")
import gradio as gr
import imageio
import time
import random
import gc
from PIL import Image

# Import necessary components from the cloned repository
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import Image2VideoPipeline, resizecrop
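# (The skyreels_v2_infer package comes from the cloned SkyReels-V2 repository,
# https://github.com/SkyworkAI/SkyReels-V2, which this Space is assumed to include.)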
# --- Global Configuration & Model Loading ---
MODEL_ID = "Skywork/SkyReels-V2-I2V-14B-720P"
HEIGHT = 720
WIDTH = 720
OUTPUT_DIR = "video_out"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Downloading and loading model... This may take a while.")
cached_model_path = download_model(MODEL_ID)

# 1. Initialize the pipeline as you did originally
pipe = Image2VideoPipeline(
    model_path=cached_model_path,
    dit_path=cached_model_path,
    use_usp=False,
    offload=True
)
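# Note: offload=True is assumed to keep weights off the GPU between steps to lower
# peak VRAM for the 14B model; use_usp (multi-GPU sequence parallelism) stays off
# since this Space runs on a single GPU.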
# 2. Manually convert the key model components to bfloat16
print("Converting model components to bfloat16...")
pipe.transformer = pipe.transformer.to(dtype=torch.bfloat16)
# The VAE might also benefit from conversion
if hasattr(pipe, 'vae'):
    pipe.vae = pipe.vae.to(dtype=torch.bfloat16)
print("Model loaded and converted successfully.")
# --- NEW: Helper functions for new features ---
def use_last_frame_as_input(video_filepath):
    """
    Efficiently extracts the last frame of a video using OpenCV.
    """
    if not video_filepath or not os.path.exists(video_filepath):
        gr.Warning("No video clip available to get the last frame from.")
        return None
    cap = None
    try:
        print(f"Reading last frame from {video_filepath} using OpenCV...")
        cap = cv2.VideoCapture(video_filepath)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count < 1:
            raise ValueError("Video file could not be read or contains no frames.")
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        ret, frame = cap.read()
        if not ret or frame is None:
            raise ValueError("Failed to read the last frame from the video.")
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(frame_rgb)
        print("Successfully extracted last frame.")
        return pil_image
    except Exception as e:
        print(f"Error extracting last frame with OpenCV: {e}")
        # gr.Error is an exception class; it must be raised (not just instantiated)
        # for Gradio to surface the message to the user.
        raise gr.Error(f"Failed to extract the last frame: {e}")
    finally:
        if cap is not None and cap.isOpened():
            cap.release()
def stitch_videos(clips_list):
    """
    Stitches a list of video file paths into a single video.
    """
    if not clips_list or len(clips_list) < 2:
        raise gr.Error("You need at least two clips to stitch them together!")
    print(f"Stitching {len(clips_list)} clips...")
    try:
        video_clips = [VideoFileClip(clip_path) for clip_path in clips_list]
        final_clip = concatenate_videoclips(video_clips, method="compose")
        temp_dir = tempfile.mkdtemp()
        final_output_path = os.path.join(temp_dir, f"stitched_video_{random.randint(10000, 99999)}.mp4")
        final_clip.write_videofile(final_output_path, codec="libx264", audio=False, threads=4, preset='ultrafast')
        for clip in video_clips:
            clip.close()
        print(f"Final video saved to {final_output_path}")
        return final_output_path
    except Exception as e:
        print(f"Error during video stitching: {e}")
        raise gr.Error(f"Failed to stitch videos: {e}")
def clear_clips():
    """
    Resets the clip list and associated UI components.
    """
    return [], "Clips created: 0", None, None
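# (The tuple above maps, in order, to [clips_state, clip_counter_display,
# output_video, final_video_output], as wired up in clear_button.click below.)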
# --- MODIFIED: Main Inference Function ---
# The `spaces` import indicates a ZeroGPU Space; on ZeroGPU, the GPU-bound entry point
# must be decorated with @spaces.GPU so a GPU is attached for the duration of the call.
@spaces.GPU
def generate_video(input_image, prompt, guidance_scale, inference_steps, num_frames, fps, seed, clips_list):
    if input_image is None:
        raise gr.Error("You must upload an initial image.")
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    image = Image.fromarray(input_image).convert("RGB")
    processed_image = resizecrop(image, HEIGHT, WIDTH)
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, worst quality, low quality, JPEG compression residue, ugly, deformed."
    kwargs = {
        "image": processed_image,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_frames": int(num_frames),
        "num_inference_steps": int(inference_steps),
        "guidance_scale": guidance_scale,
        "shift": 5.0,  # Recommended value for I2V
        "generator": generator,
        "height": HEIGHT,
        "width": WIDTH,
    }
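    # Note: the frame-count slider (25-145, step 8, default 97) keeps num_frames on a
    # 4n + 1 grid; this is assumed to match what the underlying video DiT expects.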
print(f"Generating video with seed: {seed}") | |
start_time = time.time() | |
with torch.amp.autocast('cuda',dtype=torch.bfloat16), torch.no_grad(): | |
video_frames = pipe(**kwargs)[0] | |
end_time = time.time() | |
print(f"Inference took {end_time - start_time:.2f} seconds.") | |
safe_prompt = "".join(c for c in prompt if c.isalnum() or c in " _-").strip()[:50] | |
output_filename = f"{safe_prompt}_{seed}.mp4" | |
output_path = os.path.join(OUTPUT_DIR, output_filename) | |
imageio.mimwrite(output_path, video_frames, fps=int(fps), quality=8, output_params=["-loglevel", "error"]) | |
print(f"Video saved to {output_path}") | |
gc.collect() | |
torch.cuda.empty_cache() | |
    # --- MODIFIED ---
    # Update the clips list and counter text
    updated_clips_list = clips_list + [output_path]
    counter_text = f"Clips created: {len(updated_clips_list)}"
    # Return values to update all relevant UI components
    return output_path, gr.update(visible=True), updated_clips_list, counter_text
# --- MODIFIED: Gradio UI ---
with gr.Blocks(css="footer {display: none !important}") as demo:
    # --- NEW ---
    # State to store the list of generated clip paths
    clips_state = gr.State([])
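    # gr.State is per-session, so each visitor accumulates their own list of clip paths.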
    gr.Markdown(
        """
        # SkyReels-V2 Image-to-Video Clip Stitcher
        ### Model: Skywork/SkyReels-V2-I2V-14B-720P
        Generate short video clips iteratively and stitch them together into a longer animation.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Initial Image")
            prompt = gr.Textbox(label="Prompt", placeholder="e.g., A cinematic shot of a car driving on a rainy street at night.")
            with gr.Accordion("Advanced Settings", open=False):
                guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, value=5.0, step=0.5, label="Guidance Scale")
                inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps")
                num_frames = gr.Slider(minimum=25, maximum=145, value=97, step=8, label="Number of Frames")
                fps = gr.Slider(minimum=8, maximum=30, value=24, step=1, label="Frames Per Second (FPS)")
                seed = gr.Number(value=-1, label="Seed (-1 for random)")
        with gr.Column():
            output_video = gr.Video(label="Last Generated Clip")
            # --- NEW ---
            # Button to use the last frame of the output as the new input
            use_last_frame_button = gr.Button("Use Last Frame as Input Image", visible=False)
            # --- NEW ---
            # UI for stitching controls
            with gr.Accordion("Stitching Controls", open=True):
                clip_counter_display = gr.Markdown("Clips created: 0")
                with gr.Row():
                    stitch_button = gr.Button("🎬 Stitch All Clips")
                    clear_button = gr.Button("🗑️ Clear All Clips")
                final_video_output = gr.Video(label="Final Stitched Video", interactive=False)
    run_button = gr.Button("Generate Video Clip", variant="primary")

    # --- MODIFIED ---
    # Update the click event to handle the new state and outputs
    run_button.click(
        fn=generate_video,
        inputs=[input_image, prompt, guidance_scale, inference_steps, num_frames, fps, seed, clips_state],
        outputs=[output_video, use_last_frame_button, clips_state, clip_counter_display]
    )
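    # generate_video returns (clip path, button-visibility update, updated clip list,
    # counter text), matching the order of the outputs list above.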
    # --- NEW ---
    # Add click events for the new buttons
    use_last_frame_button.click(
        fn=use_last_frame_as_input,
        inputs=[output_video],
        outputs=[input_image]
    )
    stitch_button.click(
        fn=stitch_videos,
        inputs=[clips_state],
        outputs=[final_video_output]
    )
    clear_button.click(
        fn=clear_clips,
        outputs=[clips_state, clip_counter_display, output_video, final_video_output]
    )
if __name__ == "__main__":
    demo.queue().launch(debug=True)