Spaces:
Sleeping
Sleeping
# storygen_tts_final.py
import html

import numpy as np
import streamlit as st
import torch
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoProcessor,
    BlipForConditionalGeneration,
    BlipProcessor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    pipeline,
)
# Model initialization (CPU-optimized version)
@st.cache_resource
def load_models():
    """Load every AI model the app needs and keep one shared copy.

    Streamlit re-executes the whole script on each user interaction (and
    ``main`` triggers an extra ``st.rerun()`` after every story), so without
    caching all checkpoints and the speaker dataset were re-instantiated on
    every run. ``st.cache_resource`` stores the returned objects once; if
    loading raises, nothing is cached and the next run retries.

    Returns:
        Tuple ``(blip_processor, blip_model, story_generator, tts_processor,
        tts_model, vocoder, embeddings_dataset)``.

    Raises:
        Exception: any loading error, re-raised after being shown with
        ``st.error``.
    """
    try:
        # Image captioning model (BLIP).
        blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

        # Story text-generation pipeline (GPT-2).
        # NOTE: device_map="auto" needs the `accelerate` package installed.
        story_generator = pipeline(
            "text-generation",
            model="openai-community/gpt2",
            device_map="auto",
        )

        # Text-to-speech: SpeechT5 acoustic model plus HiFi-GAN vocoder.
        tts_processor = AutoProcessor.from_pretrained("microsoft/speecht5_tts")
        tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

        # Speaker x-vector dataset used to choose the narrator's voice.
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

        return (
            blip_processor,
            blip_model,
            story_generator,
            tts_processor,
            tts_model,
            vocoder,
            embeddings_dataset,
        )
    except Exception as e:
        st.error(f"模型加载失败: {str(e)}")
        raise
def _trim_to_length(text, max_chars):
    """Return *text* shortened to at most *max_chars* characters.

    Cuts after the last sentence-ending mark (``.``, ``!``, ``?``) when that
    keeps at least half of the budget, otherwise at the last space, and only
    falls back to a hard character cut when neither is available.
    """
    if len(text) <= max_chars:
        return text
    head = text[:max_chars]
    sentence_end = max(head.rfind(mark) for mark in ".!?")
    if sentence_end >= 0 and sentence_end + 1 >= max_chars // 2:
        return head[: sentence_end + 1]
    word_end = head.rfind(" ")
    if word_end > 0:
        return head[:word_end]
    return head


def generate_story(image, blip_processor, blip_model, story_generator, max_chars=600):
    """Generate a short children's story inspired by an image.

    The image is captioned with BLIP, the caption is embedded in a prompt,
    and the text-generation pipeline writes the story continuation.

    Args:
        image: Image passed to ``blip_processor`` (the app supplies a PIL RGB image).
        blip_processor: BLIP processor; prepares the image and decodes the caption.
        blip_model: BLIP captioning model.
        story_generator: transformers ``text-generation`` pipeline.
        max_chars: Maximum length of the returned story in characters.

    Returns:
        The cleaned story: the caption text is removed, whitespace is collapsed,
        and the result is trimmed at a sentence/word boundary. If nothing usable
        remains, the caption itself is returned so downstream speech synthesis
        still has input.
    """
    inputs = blip_processor(image, return_tensors="pt")
    # Beam search is deterministic; `temperature` only affects sampling, so it
    # is not passed here (transformers would warn about and ignore it).
    caption_ids = blip_model.generate(
        **inputs,
        max_new_tokens=100,
        num_beams=5,
        early_stopping=True,
    )
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True).strip()

    # Build the story prompt.
    prompt = (
        f"Based on this image: {caption}\n"
        "Write a magical story for children with:\n"
        "1. Talking animals\n"
        "2. Happy ending\n"
        "3. Sound effects (*whoosh*, *giggle*)\n"
        "4. 50-100 words\n"
        "Story:"
    )

    generated = story_generator(
        prompt,
        # max_length/min_length include the prompt tokens, which left only a few
        # dozen tokens for the story; budget the newly generated tokens instead.
        max_new_tokens=160,
        min_new_tokens=60,
        num_return_sequences=1,
        do_sample=True,  # explicit, so `temperature` is always honoured
        temperature=0.85,
        repetition_penalty=2.0,
        # Return only the continuation; no need to split the prompt back off.
        return_full_text=False,
    )

    story = generated[0]["generated_text"]
    if caption:
        # The model tends to echo the caption verbatim; drop it.
        story = story.replace(caption, "")
    story = " ".join(story.split())
    story = _trim_to_length(story, max_chars).strip()
    return story or caption
def text_to_speech(text, processor, model, vocoder, embeddings_dataset, speaker_index=7306):
    """Synthesize narration for *text* with SpeechT5.

    Args:
        text: Text to speak.
        processor: SpeechT5 processor used to tokenize the text.
        model: Model exposing ``generate_speech`` (``SpeechT5ForTextToSpeech``).
        vocoder: Vocoder forwarded to ``generate_speech`` (``SpeechT5HifiGan``).
        embeddings_dataset: Indexable collection whose row ``speaker_index``
            has an ``"xvector"`` field (the speaker embedding).
        speaker_index: Row of ``embeddings_dataset`` selecting the voice.

    Returns:
        ``(audio, 16000)`` where ``audio`` is a NumPy array scaled so its peak
        absolute amplitude is 1.0; entirely silent output is returned unscaled
        instead of turning into NaNs.

    Raises:
        ValueError: if *text* is empty or whitespace only.
        Exception: any error is shown with ``st.error`` and re-raised.
    """
    try:
        if not text or not text.strip():
            raise ValueError("text must be non-empty")

        # `voice_preset` belongs to the Bark processor; SpeechT5's processor
        # forwards unknown kwargs to its tokenizer, so it is not passed.
        inputs = processor(text=text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch.int64)

        # A fixed speaker x-vector keeps the narrator's voice consistent.
        speaker_embeddings = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        with torch.no_grad():
            speech = model.generate_speech(
                input_ids=input_ids,
                speaker_embeddings=speaker_embeddings,
                vocoder=vocoder,
            )

        audio_array = speech.cpu().numpy()
        peak = np.max(np.abs(audio_array)) if audio_array.size else 0.0
        # Guard against division by zero for silent output.
        if peak > 0:
            audio_array = audio_array / peak
        return audio_array, 16000
    except Exception as e:
        st.error(f"语音生成失败: {str(e)}")
        raise
def main():
    """Render the Magic Story Box Streamlit page.

    Loads the models, accepts an image upload, generates a story plus
    narration once per uploaded file, and displays the picture, the story
    and an audio player. Uploading a different image produces a new story;
    removing the upload clears the results.
    """
    # Page configuration.
    st.set_page_config(
        page_title="Magic Story Box",
        page_icon="🧙",
        layout="centered",
    )
    st.title("🧚♀️ Magic Story Box")
    st.markdown("---")
    st.write("Upload an image to get your magical story!")

    # Session state initialisation.
    if "generated" not in st.session_state:
        st.session_state.generated = False

    # Load models; load_models already reports failures via st.error.
    # `except Exception` (not a bare except) so Streamlit's own control-flow
    # exceptions are never swallowed here.
    try:
        (blip_proc, blip_model, story_gen,
         tts_proc, tts_model, vocoder, embeddings) = load_models()
    except Exception:
        return

    uploaded_file = st.file_uploader(
        "Choose your magic image",
        type=["jpg", "png", "jpeg"],
        help="Upload photos of pets, toys or adventures!",
        key="uploader",
    )

    if uploaded_file is None:
        # Don't keep showing results for an image that was removed.
        st.session_state.generated = False
        return

    # A different upload must trigger a new story instead of reusing the
    # result generated for the previous image.
    upload_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("upload_key") != upload_key:
        st.session_state.generated = False
        st.session_state.upload_key = upload_key

    try:
        image = Image.open(uploaded_file).convert("RGB")
    except Exception as e:
        st.error(f"Magic failed: {str(e)}")
        return

    # Show the picture on every run, not only while the story is generated.
    st.image(image, caption="Your Magic Picture ✨", use_container_width=True)

    if not st.session_state.generated:
        try:
            with st.status("Creating Magic...", expanded=True) as status:
                st.write("🔍 Reading the image...")
                story = generate_story(image, blip_proc, blip_model, story_gen)

                st.write("🔊 Adding sounds...")
                audio_array, sr = text_to_speech(story, tts_proc, tts_model, vocoder, embeddings)

                st.session_state.story = story
                st.session_state.audio = (audio_array, sr)
                status.update(label="Ready!", state="complete", expanded=False)
        except Exception as e:
            st.error(f"Magic failed: {str(e)}")
            return
        st.session_state.generated = True
        # Rerun outside the try so the rerun request is never treated as a failure.
        st.rerun()

    # Results.
    st.markdown("---")
    st.subheader("Your Story 📖")
    # The story is model output; escape it before embedding it in raw HTML.
    safe_story = html.escape(st.session_state.story)
    st.markdown(
        f'<div style="background:#fff3e6; padding:20px; border-radius:10px;">{safe_story}</div>',
        unsafe_allow_html=True,
    )
    st.markdown("---")
    st.subheader("Listen 🎧")
    audio_data, sr = st.session_state.audio
    st.audio(audio_data, sample_rate=sr)


if __name__ == "__main__":
    main()