RajatMalviya committed
Commit 3517ef3 · verified · 1 Parent(s): 44a6a1d

Update app.py

Files changed (1)
  1. app.py +340 -31
app.py CHANGED
@@ -1,48 +1,357 @@
- import streamlit as st
- import tempfile
  import os
- import librosa  # For audio resampling
  import torch
- from transformers import WhisperProcessor, WhisperForConditionalGeneration

- # Load the model and processor
- @st.cache_resource
- def load_model():
-     processor = WhisperProcessor.from_pretrained("ivrit-ai/whisper-large-v3-turbo")
-     model = WhisperForConditionalGeneration.from_pretrained("ivrit-ai/whisper-large-v3-turbo")
-     return processor, model

- processor, model = load_model()

- st.title("Hebrew Speech-to-Text Transcription")

- # Upload audio file
- uploaded_file = st.file_uploader("Upload an audio file (WAV, MP3, OGG)", type=["wav", "mp3", "ogg"])

- if uploaded_file is not None:
-     # Save the uploaded file to a temporary location
-     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
-         temp_audio.write(uploaded_file.read())
-         temp_audio_path = temp_audio.name

      try:
-         # Load and resample audio to 16kHz (required by Whisper)
-         speech_array, sampling_rate = librosa.load(temp_audio_path, sr=16000)

-         # Preprocess audio
-         inputs = processor(speech_array, sampling_rate=16000, return_tensors="pt")

-         # Generate transcription
          with torch.no_grad():
-             predicted_ids = model.generate(inputs.input_features)

-         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

-         st.subheader("Transcription:")
-         st.write(transcription)

      except Exception as e:
-         st.error(f"Error: {str(e)}")

-     # Clean up the temporary file
-     os.remove(temp_audio_path)
+ import gradio as gr
+ import ffmpeg
  import os
+ import uuid
+ import requests
+ import tempfile
+ import shutil
+ import re
+ import time
+ import concurrent.futures
  import torch
+ from pathlib import Path
+ from dotenv import load_dotenv
+ from transformers import AutoProcessor, AutoTokenizer, LlavaForConditionalGeneration
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+ from PIL import Image

+ # Add GPU decorator for Hugging Face Spaces
+ try:
+     from spaces import GPU
+     use_gpu = True
+     @GPU
+     def get_gpu():
+         return True
+     # Call the function to trigger GPU allocation
+     get_gpu()
+ except ImportError:
+     use_gpu = False
+     print("Running without GPU acceleration")

+ # Load environment variables from .env file if it exists
+ load_dotenv()

+ # Global variables to hold models (lazy loading)
+ llava_model = None
+ llava_processor = None
+ stable_diffusion_pipeline = None

+ def load_llava_model():
+     """Load LLaVA model for image captioning"""
+     global llava_model, llava_processor

+     if llava_model is None or llava_processor is None:
+         print("Loading LLaVA model for image analysis...")
+         model_id = "llava-hf/llava-1.5-7b-hf"

+         # Load the processor and the LLaVA-specific model class (reduced precision for memory efficiency)
+         llava_processor = AutoProcessor.from_pretrained(model_id)
+         llava_model = LlavaForConditionalGeneration.from_pretrained(
+             model_id,
+             torch_dtype=torch.float16,
+             device_map="auto"
+         )

+     return llava_model, llava_processor

+ def load_stable_diffusion_model():
+     """Load Stable Diffusion model for Ghibli-style image generation"""
+     global stable_diffusion_pipeline

+     if stable_diffusion_pipeline is None:
+         print("Loading Stable Diffusion model for image generation...")
+         model_id = "nitrosocke/Ghibli-Diffusion"

+         # Load the pipeline with reduced precision to balance memory usage and quality
+         stable_diffusion_pipeline = StableDiffusionPipeline.from_pretrained(
+             model_id,
+             torch_dtype=torch.float16,
+             safety_checker=None  # Disable safety checker for performance
+         )

+         # Move to GPU if available
+         if torch.cuda.is_available():
+             stable_diffusion_pipeline = stable_diffusion_pipeline.to("cuda")

+         # Use the DPM-Solver++ scheduler for better quality at lower steps
+         stable_diffusion_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+             stable_diffusion_pipeline.scheduler.config,
+             algorithm_type="dpmsolver++",
+             use_karras_sigmas=True
+         )

+     return stable_diffusion_pipeline

+ def analyze_image_with_llava(image_path):
+     """Process a single frame with LLaVA to generate a description"""
      try:
+         # Load the model if not already loaded
+         model, processor = load_llava_model()

+         # Load the image
+         image = Image.open(image_path)

+         # Prompt for a Ghibli-specific description, in the llava-1.5 chat format with the <image> placeholder
+         prompt = "USER: <image>\nDescribe this image in detail, focusing on elements that would be important to recreate it in Studio Ghibli animation style. ASSISTANT:"

+         # Process the image, casting floating-point inputs to match the fp16 weights
+         inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)

+         # Generate with appropriate parameters
          with torch.no_grad():
+             output = model.generate(
+                 **inputs,
+                 max_new_tokens=300,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+             )

+         # Decode the output
+         generated_text = processor.decode(output[0], skip_special_tokens=True)

+         # Extract just the response part (everything after the assistant marker)
+         response = generated_text.split("ASSISTANT:")[-1].strip()
+         print(f"LLaVA analysis for frame {os.path.basename(image_path)}: {response[:150]}...")

+         return response

      except Exception as e:
+         import traceback
+         print(f"Error analyzing image {os.path.basename(image_path)}: {str(e)}")
+         print(traceback.format_exc())
+         return f"Error analyzing image: {str(e)}"

+ def generate_ghibli_image(image_description, style_prompt, output_path):
+     """Generate a Ghibli-style image based on the description using Stable Diffusion"""
+     try:
+         # Load the model if not already loaded
+         pipeline = load_stable_diffusion_model()

+         # Combine the image description with the style prompt
+         full_prompt = f"{image_description}. {style_prompt}. Hand-drawn animation style, soft colors, attention to detail, Miyazaki aesthetic."

+         # Ensure prompt isn't too long
+         if len(full_prompt) > 500:
+             full_prompt = full_prompt[:497] + "..."

+         # Generate the image
+         with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
+             image = pipeline(
+                 prompt=full_prompt,
+                 negative_prompt="3d, cgi, low quality, blurry, distorted, deformed",
+                 num_inference_steps=30,
+                 guidance_scale=7.5,
+                 width=768,
+                 height=768,
+             ).images[0]

+         # Save the generated image
+         image.save(output_path)
+         print(f"Successfully saved stylized frame: {os.path.basename(output_path)}")
+         return True

+     except Exception as e:
+         import traceback
+         print(f"Error generating image: {str(e)}")
+         print(traceback.format_exc())
+         return False

+ def process_frame(frame_path, style_prompt):
+     """Process a single frame with LLaVA analysis and Stable Diffusion generation"""
+     try:
+         # First use LLaVA to analyze the image
+         image_description = analyze_image_with_llava(frame_path)

+         if image_description.startswith("Error"):
+             return False

+         # Now use Stable Diffusion to generate a stylized version
+         result = generate_ghibli_image(image_description, style_prompt, frame_path)

+         return result

+     except Exception as e:
+         import traceback
+         print(f"Error processing frame {os.path.basename(frame_path)}: {str(e)}")
+         print(traceback.format_exc())
+         return False

+ def stylize_video(video_path, style_prompt, num_frames=15):
+     try:
+         # Create temp directories
+         temp_dir = tempfile.mkdtemp()
+         input_filename = os.path.join(temp_dir, "input.mp4")
+         frames_dir = os.path.join(temp_dir, "frames")
+         os.makedirs(frames_dir, exist_ok=True)

+         # Save the input video to a temporary file
+         if isinstance(video_path, str):
+             if video_path.startswith('http'):
+                 # It's a URL, download it
+                 response = requests.get(video_path, stream=True)
+                 with open(input_filename, 'wb') as f:
+                     for chunk in response.iter_content(chunk_size=8192):
+                         f.write(chunk)
+             elif os.path.exists(video_path):
+                 # It's a file path, copy it
+                 shutil.copy(video_path, input_filename)
+             else:
+                 return None, f"Video file not found: {video_path}"
+         else:
+             # Assume it's binary data
+             with open(input_filename, "wb") as f:
+                 f.write(video_path)

+         # Make sure the video file exists
+         if not os.path.exists(input_filename):
+             return None, "Failed to save input video"

+         # Extract frames - using lower fps for longer videos (1 frame per second)
+         ffmpeg.input(input_filename).output(f"{frames_dir}/%04d.png", vf="fps=1").run(quiet=True)

+         # Check if frames were extracted
+         frames = sorted([os.path.join(frames_dir, f) for f in os.listdir(frames_dir) if f.endswith('.png')])
+         if not frames:
+             return None, "No frames were extracted from the video"

+         # Limit to a maximum number of frames for reasonable processing times
+         if len(frames) > num_frames:
+             # Take evenly distributed frames
+             indices = [int(i * (len(frames) - 1) / (num_frames - 1)) for i in range(num_frames)]
+             frames = [frames[i] for i in indices]

+         print(f"Processing {len(frames)} frames")

+         # Process frames sequentially if we're using a GPU (to avoid CUDA OOM errors)
+         # Otherwise, use a modest level of parallelism
+         if torch.cuda.is_available():
+             # Sequential processing to avoid CUDA OOM errors
+             processed_frames = []
+             for i, frame in enumerate(frames):
+                 success = process_frame(frame, style_prompt)
+                 if success:
+                     processed_frames.append(frame)
+                     print(f"Completed frame {os.path.basename(frame)} ({i+1}/{len(frames)})")
+                 else:
+                     print(f"Failed to process frame {os.path.basename(frame)}")

+                 # Free up CUDA cache between frames
+                 torch.cuda.empty_cache()
+         else:
+             # Process frames in parallel with limited workers if no GPU
+             with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+                 futures = {executor.submit(process_frame, frame, style_prompt): frame for frame in frames}

+                 # Collect results
+                 processed_frames = []
+                 for future in concurrent.futures.as_completed(futures):
+                     frame = futures[future]
+                     if future.result():
+                         processed_frames.append(frame)
+                         print(f"Completed frame {os.path.basename(frame)} ({len(processed_frames)}/{len(frames)})")

+         if not processed_frames:
+             return None, "Failed to process any frames. Please check the logs for more information."

+         # Even if not all frames were processed, try to create a video with what we have
+         print(f"Successfully processed {len(processed_frames)}/{len(frames)} frames")

+         # Ensure frames are in the correct order (important for video continuity)
+         processed_frames.sort()

+         # Reassemble frames into video
+         output_filename = os.path.join(temp_dir, "stylized.mp4")

+         # Use a higher bitrate and better codec for higher quality
+         ffmpeg.input(f"{frames_dir}/%04d.png", framerate=1) \
+             .output(output_filename, vcodec='libx264', pix_fmt='yuv420p', crf=18) \
+             .run(quiet=True)

+         # Check if the output file exists and has content
+         if not os.path.exists(output_filename) or os.path.getsize(output_filename) == 0:
+             return None, "Failed to create output video"

+         # Copy to a persistent location for Gradio to serve
+         os.makedirs("outputs", exist_ok=True)
+         persistent_output = os.path.join("outputs", f"stylized_{uuid.uuid4()}.mp4")
+         shutil.copy(output_filename, persistent_output)

+         # Return the relative path (Gradio can handle this)
+         print(f"Output video created at: {persistent_output}")

+         # Cleanup temp files
+         shutil.rmtree(temp_dir)

+         return persistent_output, f"Video stylized successfully with {len(processed_frames)} frames!"

+     except Exception as e:
+         import traceback
+         traceback_str = traceback.format_exc()
+         print(f"Error: {str(e)}\n{traceback_str}")
+         return None, f"Error: {str(e)}"

+ # Use Gradio examples feature with local files
+ example_videos = [
+     ["sample_video.mp4", "Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"]
+ ]

+ with gr.Blocks(title="Video-to-Ghibli Style Converter (Open Source)") as iface:
+     gr.Markdown("# Video-to-Ghibli Style Converter (Open Source)")
+     gr.Markdown("Upload a video and convert it to Studio Ghibli animation style using LLaVA and Stable Diffusion.")

+     with gr.Row():
+         with gr.Column(scale=2):
+             # Main input column
+             video_input = gr.Video(label="Upload Video (up to 15 seconds)")

+             style_prompt = gr.Textbox(
+                 label="Style Prompt",
+                 value="Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"
+             )

+             num_frames_slider = gr.Slider(
+                 minimum=5,
+                 maximum=15,
+                 value=10,
+                 step=1,
+                 label="Number of frames to process"
+             )

+             submit_btn = gr.Button("Stylize Video", variant="primary")

+         with gr.Column(scale=2):
+             # Output column
+             video_output = gr.Video(label="Stylized Video")
+             status_output = gr.Textbox(label="Status", value="Ready. Upload a video to start.")

+     submit_btn.click(
+         fn=stylize_video,
+         inputs=[video_input, style_prompt, num_frames_slider],
+         outputs=[video_output, status_output]
+     )

+     gr.Markdown("""
+ ## Instructions
+ 1. Upload a video up to 15 seconds long
+ 2. Customize the style prompt if desired
+ 3. Adjust the number of frames to process (fewer = faster)
+ 4. Click "Stylize Video" and wait for processing

+ ## Example Style Prompts
+ - "Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"
+ - "Studio Ghibli style with magical and dreamy atmosphere"
+ - "Nostalgic Studio Ghibli animation style with watercolor backgrounds and clean linework"
+ - "Ghibli-inspired animation with vibrant colors and fantasy elements"

+ Note: Each frame is analyzed by LLaVA-1.5-7B and then transformed by Stable Diffusion (Ghibli-Diffusion model).
+ Videos are processed at 1 frame per second to keep processing time reasonable.

+ ## Technical Details
+ - Image Analysis: Using LLaVA-1.5-7B for frame understanding and description
+ - Image Generation: Using Stable Diffusion (nitrosocke/Ghibli-Diffusion) for style transfer
+ - All processing happens locally - no API keys needed!
+ """)

+ if __name__ == "__main__":
+     iface.launch()
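
For reference, a minimal sketch of how the new stylize_video() entry point could be exercised outside the Gradio UI. The clip name sample_video.mp4 (the same file referenced by example_videos) and the num_frames value are assumptions; the sketch also assumes app.py is importable from the working directory, ffmpeg is installed, and the imported Python packages are available.

# Hypothetical smoke test; not part of this commit.
# stylize_video() extracts frames at 1 fps, captions each selected frame with
# LLaVA, regenerates it in place with Ghibli-Diffusion, and reassembles the
# frames into a video copied to outputs/stylized_<uuid>.mp4.
from app import stylize_video

output_path, status = stylize_video(
    "sample_video.mp4",  # assumed local clip
    "Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style",
    num_frames=5,  # fewer frames keeps the run short
)
print(status)
print("Stylized video:", output_path)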