assignment / app.py
liuxh0319's picture
Create app.py
ef98b47 verified
# storygen_tts_final.py
import streamlit as st
from transformers import (
BlipForConditionalGeneration,
BlipProcessor,
AutoProcessor,
SpeechT5ForTextToSpeech,
SpeechT5HifiGan,
pipeline
)
from datasets import load_dataset
import torch
import numpy as np
from PIL import Image
# Model initialisation (CPU-oriented build)
@st.cache_resource
def load_models():
    """Load every model and dataset the app depends on.

    Decorated with ``st.cache_resource``, so Streamlit reuses the returned
    objects instead of re-running this body on every script rerun.

    Returns:
        A 7-tuple, in this order: BLIP processor, BLIP captioning model,
        GPT-2 text-generation pipeline, SpeechT5 processor, SpeechT5
        text-to-speech model, HiFi-GAN vocoder, speaker x-vector dataset.

    Raises:
        Exception: Whatever the underlying loaders raise; the error is shown
            in the UI with ``st.error`` and then re-raised unchanged.
    """
    blip_checkpoint = "Salesforce/blip-image-captioning-base"
    tts_checkpoint = "microsoft/speecht5_tts"

    try:
        # Image captioning
        caption_processor = BlipProcessor.from_pretrained(blip_checkpoint)
        caption_model = BlipForConditionalGeneration.from_pretrained(blip_checkpoint)

        # Story text generation
        story_pipeline = pipeline(
            "text-generation",
            model="openai-community/gpt2",
            device_map="auto",
        )

        # Speech synthesis
        speech_processor = AutoProcessor.from_pretrained(tts_checkpoint)
        speech_model = SpeechT5ForTextToSpeech.from_pretrained(tts_checkpoint)
        hifigan = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

        # Speaker embedding (x-vector) dataset
        xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    except Exception as err:
        st.error(f"模型加载失败: {err!s}")
        raise

    return (
        caption_processor,
        caption_model,
        story_pipeline,
        speech_processor,
        speech_model,
        hifigan,
        xvectors,
    )
def generate_story(image, blip_processor, blip_model, story_generator):
    """Caption *image* with BLIP, then expand the caption into a short story.

    Args:
        image: Image accepted by ``blip_processor``.
        blip_processor: Callable that builds the captioning model inputs and
            exposes ``decode`` for turning generated ids back into text.
        blip_model: Captioning model exposing ``generate``.
        story_generator: Text-generation pipeline (or any callable with the
            same calling convention) returning
            ``[{"generated_text": str}, ...]``.

    Returns:
        The generated story with the prompt and any verbatim copy of the
        caption removed. The continuation is cut to 600 characters before the
        caption is removed. If nothing usable was generated, the caption
        itself is returned so callers never receive an empty story when a
        caption exists.
    """
    inputs = blip_processor(image, return_tensors="pt")

    # Beam search does not sample, so a `temperature` value would be unused
    # here; it is intentionally not passed.
    caption_ids = blip_model.generate(
        **inputs,
        max_new_tokens=100,
        num_beams=5,
        early_stopping=True,
    )
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)

    # The continuation lines are flush-left on purpose: they are part of the
    # prompt text and must not carry source-code indentation.
    prompt = f"""Based on this image: {caption}
Write a magical story for children with:
1. Talking animals
2. Happy ending
3. Sound effects (*whoosh*, *giggle*)
4. 50-100 words
Story:"""

    # `max_length`/`min_length` count the prompt tokens too, which would leave
    # only a fraction of the budget for the story. The `*_new_tokens` variants
    # bound the generated continuation only.
    generated = story_generator(
        prompt,
        max_new_tokens=100,
        min_new_tokens=50,
        num_return_sequences=1,
        do_sample=True,  # temperature only has an effect when sampling
        temperature=0.85,
        repetition_penalty=2.0,
    )

    full_text = generated[0]["generated_text"]
    if full_text.startswith(prompt):
        # Normal case: the pipeline echoes the prompt; drop exactly that
        # prefix so a later "Story:" inside the story does not truncate it.
        continuation = full_text[len(prompt):]
    else:
        # Prompt not echoed verbatim: keep what follows the first marker, or
        # the whole text when the marker is absent.
        head, marker, tail = full_text.partition("Story:")
        continuation = tail if marker else head

    story = continuation.strip()[:600].replace(caption, "").strip()
    return story or caption
def text_to_speech(text, processor, model, vocoder, embeddings_dataset,
                   speaker_index=7306):
    """Synthesise *text* into a peak-normalised waveform.

    Args:
        text: Text to speak.
        processor: Callable producing a mapping with an ``"input_ids"`` tensor
            when called as ``processor(text=..., return_tensors="pt")``.
        model: TTS model exposing ``generate_speech``.
        vocoder: Vocoder forwarded to ``model.generate_speech``.
        embeddings_dataset: Indexable collection whose items carry an
            ``"xvector"`` entry (the speaker embedding).
        speaker_index: Index of the speaker embedding to use. Defaults to
            7306, the value that was previously hard-coded, so every call
            uses the same voice unless overridden.

    Returns:
        ``(audio_array, 16000)``: a NumPy array scaled so its largest absolute
        sample is 1.0, and the sample rate reported to callers. Silent or
        empty output is returned unscaled rather than divided by zero.

    Raises:
        Exception: Any failure is shown in the UI with ``st.error`` and then
            re-raised unchanged.
    """
    try:
        # No `voice_preset` is passed: it is not a SpeechT5 processor option
        # (NOTE(review): it appears to be a Bark argument; confirm).
        inputs = processor(text=text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch.int64)

        # A fixed (not random) speaker embedding, with a batch dimension added.
        speaker_embeddings = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        with torch.no_grad():
            speech = model.generate_speech(
                input_ids=input_ids,
                speaker_embeddings=speaker_embeddings,
                vocoder=vocoder,
            )

        audio_array = speech.numpy()

        # Peak-normalise, guarding the two cases where the division would
        # fail: an empty waveform (np.max raises) and pure silence (0/0 -> NaN).
        peak = float(np.max(np.abs(audio_array))) if audio_array.size else 0.0
        if peak > 0.0:
            audio_array = audio_array / peak
        return audio_array, 16000
    except Exception as e:
        st.error(f"语音生成失败: {str(e)}")
        raise
def main():
    """Render the Magic Story Box page and run the upload -> story -> audio flow.

    Session state keys used:
        generated: ``True`` once a story and its audio are stored.
        story / audio: The last generated story text and ``(samples, rate)``.
        uploader_key: Counter embedded in the file-uploader widget key;
            bumping it is how the uploader is cleared for a new story.
    """
    import html  # local import: only needed to escape the story for HTML output

    # Page configuration
    st.set_page_config(
        page_title="Magic Story Box",
        page_icon="🧙",
        layout="centered",
    )
    st.title("🧚♀️ Magic Story Box")
    st.markdown("---")
    st.write("Upload an image to get your magical story!")

    # Initialise session state
    if 'generated' not in st.session_state:
        st.session_state.generated = False
    if 'uploader_key' not in st.session_state:
        st.session_state.uploader_key = 0

    # Load models; load_models() has already reported the failure in the UI,
    # so there is nothing more to do here than stop rendering.
    try:
        (blip_proc, blip_model, story_gen,
         tts_proc, tts_model, vocoder, embeddings) = load_models()
    except Exception:
        return

    # File upload widget. The key changes whenever a new story is requested:
    # a file_uploader's value cannot be assigned through st.session_state, so
    # a fresh key is the way to get an empty uploader back.
    uploaded_file = st.file_uploader(
        "Choose your magic image",
        type=["jpg", "png", "jpeg"],
        help="Upload photos of pets, toys or adventures!",
        key=f"uploader_{st.session_state.uploader_key}",
    )

    # Process the uploaded file
    if uploaded_file and not st.session_state.generated:
        try:
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Your Magic Picture ✨", use_container_width=True)

            with st.status("Creating Magic...", expanded=True) as status:
                # Generate the story
                st.write("🔍 Reading the image...")
                story = generate_story(image, blip_proc, blip_model, story_gen)

                # Generate the narration
                st.write("🔊 Adding sounds...")
                audio_array, sr = text_to_speech(
                    story, tts_proc, tts_model, vocoder, embeddings
                )

                # Store the results
                st.session_state.story = story
                st.session_state.audio = (audio_array, sr)
                status.update(label="Ready!", state="complete", expanded=False)

            st.session_state.generated = True
        except Exception as e:
            st.error(f"Magic failed: {str(e)}")
        else:
            # Rerun only on success, and outside the try body so the rerun
            # is never handled as a generation failure.
            st.rerun()

    # Show the results
    if st.session_state.generated:
        st.markdown("---")
        st.subheader("Your Story 📖")
        # The story is model output rendered inside raw HTML, so escape it to
        # keep stray '<', '>' or '&' from breaking (or injecting) markup.
        safe_story = html.escape(st.session_state.story)
        st.markdown(
            f'<div style="background:#fff3e6; padding:20px; border-radius:10px;">{safe_story}</div>',
            unsafe_allow_html=True,
        )

        st.markdown("---")
        st.subheader("Listen 🎧")
        audio_data, sr = st.session_state.audio
        st.audio(audio_data, sample_rate=sr)

        st.markdown("---")
        if st.button("Create New Story", use_container_width=True):
            st.session_state.generated = False
            st.session_state.uploader_key += 1  # yields a fresh, empty uploader
            st.rerun()


if __name__ == "__main__":
    main()