# storygen_tts_final.py
"""Streamlit app that turns an uploaded picture into a narrated children's story.

Pipeline: BLIP captions the image, GPT-2 continues a story prompt built from
that caption, and SpeechT5 (with its HiFi-GAN vocoder) reads the story aloud.
"""
import html

import numpy as np
import streamlit as st
import torch
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoProcessor,
    BlipForConditionalGeneration,
    BlipProcessor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    pipeline,
)

# Row of the x-vector dataset used as the (fixed) speaker voice.
SPEAKER_EMBEDDING_INDEX = 7306
# Sample rate reported to the audio player for the synthesized waveform.
TTS_SAMPLE_RATE = 16000
# Upper bound, in characters, on the story text handed to the TTS stage.
MAX_STORY_CHARS = 600


# Model initialisation (cached so it runs once per server process).
@st.cache_resource
def load_models():
    """Load every model and dataset the app needs.

    Returns:
        A 7-tuple ``(blip_processor, blip_model, story_generator,
        tts_processor, tts_model, vocoder, embeddings_dataset)``.

    Raises:
        Exception: whatever the underlying loaders raise; the error is shown
            in the UI via ``st.error`` before being re-raised.
    """
    try:
        # Image captioning model
        blip_processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        blip_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )

        # Text generation pipeline
        story_generator = pipeline(
            "text-generation",
            model="openai-community/gpt2",
            device_map="auto",
        )

        # Speech synthesis models
        tts_processor = AutoProcessor.from_pretrained("microsoft/speecht5_tts")
        tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

        # Speaker embedding (x-vector) dataset
        embeddings_dataset = load_dataset(
            "Matthijs/cmu-arctic-xvectors", split="validation"
        )

        return (
            blip_processor,
            blip_model,
            story_generator,
            tts_processor,
            tts_model,
            vocoder,
            embeddings_dataset,
        )
    except Exception as e:
        st.error(f"模型加载失败: {str(e)}")
        raise


def generate_story(image, blip_processor, blip_model, story_generator):
    """Caption *image* and generate a short children's story from the caption.

    Args:
        image: the picture to describe (``main`` passes an RGB PIL image).
        blip_processor: BLIP processor used to encode the image and decode ids.
        blip_model: BLIP captioning model.
        story_generator: text-generation pipeline that continues the prompt.

    Returns:
        The generated story text, at most ``MAX_STORY_CHARS`` characters long,
        with the prompt and any verbatim copy of the caption removed. Falls
        back to the caption itself if nothing else is left.
    """
    inputs = blip_processor(image, return_tensors="pt")

    # Generate the image caption. Beam search is deterministic, so no
    # sampling temperature is passed here (it would be ignored).
    caption_ids = blip_model.generate(
        **inputs,
        max_new_tokens=100,
        num_beams=5,
        early_stopping=True,
    )
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)

    # Build the story prompt; the text is kept exactly as in the original.
    prompt = (
        f"Based on this image: {caption} Write a magical story for children with: "
        "1. Talking animals 2. Happy ending 3. Sound effects (*whoosh*, *giggle*) 4. \n"
        "50-100 words Story:"
    )

    # Generate the story with GPT-2. The *_new_tokens limits count only the
    # continuation; max_length/min_length would include the prompt tokens and
    # leave little or no room for the story itself.
    generated = story_generator(
        prompt,
        max_new_tokens=100,
        min_new_tokens=50,
        num_return_sequences=1,
        do_sample=True,  # temperature only has an effect when sampling
        temperature=0.85,
        repetition_penalty=2.0,
    )

    # Keep only what follows the final "Story:" marker and tidy it up.
    full_text = generated[0]["generated_text"]
    story = full_text.split("Story:")[-1].strip()
    story = story[:MAX_STORY_CHARS].replace(caption, "").strip()

    # Never hand an empty string to the TTS stage if a caption is available.
    return story or caption


def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
    """Synthesize *text* into a peak-normalised waveform.

    Args:
        text: the text to speak; must contain non-whitespace characters.
        processor: SpeechT5 processor used to tokenize the text.
        model: SpeechT5 text-to-speech model.
        vocoder: HiFi-GAN vocoder turning spectrograms into audio.
        embeddings_dataset: indexable dataset whose rows carry an ``"xvector"``.

    Returns:
        ``(audio_array, sample_rate)`` where ``audio_array`` is a NumPy array
        and ``sample_rate`` is ``TTS_SAMPLE_RATE``.

    Raises:
        ValueError: if *text* is empty or whitespace only.
        Exception: any synthesis error, after it has been shown via ``st.error``.
    """
    if not text or not text.strip():
        raise ValueError("text_to_speech() needs non-empty text")

    try:
        inputs = processor(text=text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch.int64)

        # Clip to the number of text positions the model's config declares.
        # NOTE(review): fallback of 600 matches the 600-character story cap;
        # confirm against the checkpoint's config.
        max_positions = getattr(model.config, "max_text_positions", 600)
        input_ids = input_ids[:, :max_positions]

        # Use one fixed speaker embedding so the voice is the same every run.
        speaker_embeddings = torch.tensor(
            embeddings_dataset[SPEAKER_EMBEDDING_INDEX]["xvector"]
        ).unsqueeze(0)

        with torch.no_grad():
            speech = model.generate_speech(
                input_ids=input_ids,
                speaker_embeddings=speaker_embeddings,
                vocoder=vocoder,
            )

        audio_array = speech.numpy()
        # Peak-normalise, but leave silent or empty output untouched instead
        # of dividing by zero.
        peak = np.max(np.abs(audio_array)) if audio_array.size else 0.0
        if peak > 0:
            audio_array = audio_array / peak
        return audio_array, TTS_SAMPLE_RATE
    except Exception as e:
        st.error(f"语音生成失败: {str(e)}")
        raise


def main():
    """Render the page: upload widget, generation progress, story and audio."""
    # Page configuration
    st.set_page_config(
        page_title="Magic Story Box",
        page_icon="🧙",
        layout="centered",
    )
    st.title("🧚♀️ Magic Story Box")
    st.markdown("---")
    st.write("Upload an image to get your magical story!")

    # Initialise session state
    if "generated" not in st.session_state:
        st.session_state.generated = False

    # Load models. load_models() has already reported the failure in the UI,
    # so just stop rendering. Only Exception is caught so that interpreter
    # and framework control-flow exceptions keep propagating.
    try:
        (
            blip_proc,
            blip_model,
            story_gen,
            tts_proc,
            tts_model,
            vocoder,
            embeddings,
        ) = load_models()
    except Exception:
        return

    # File upload widget
    uploaded_file = st.file_uploader(
        "Choose your magic image",
        type=["jpg", "png", "jpeg"],
        help="Upload photos of pets, toys or adventures!",
        key="uploader",
    )

    # Process the upload. The (name, size) pair identifies the upload so that
    # choosing a different picture produces a new story instead of leaving the
    # previous result on screen.
    if uploaded_file is not None:
        upload_key = (uploaded_file.name, uploaded_file.size)
        needs_generation = (
            not st.session_state.generated
            or st.session_state.get("source_key") != upload_key
        )
        if needs_generation:
            try:
                image = Image.open(uploaded_file).convert("RGB")
                st.image(
                    image,
                    caption="Your Magic Picture ✨",
                    use_container_width=True,
                )

                with st.status("Creating Magic...", expanded=True) as status:
                    # Generate the story
                    st.write("🔍 Reading the image...")
                    story = generate_story(image, blip_proc, blip_model, story_gen)

                    # Generate the narration
                    st.write("🔊 Adding sounds...")
                    audio_array, sr = text_to_speech(
                        story, tts_proc, tts_model, vocoder, embeddings
                    )

                    # Store the results
                    st.session_state.story = story
                    st.session_state.audio = (audio_array, sr)
                    st.session_state.source_key = upload_key
                    status.update(label="Ready!", state="complete", expanded=False)

                st.session_state.generated = True
                st.rerun()
            except Exception as e:
                st.error(f"Magic failed: {str(e)}")

    # Show the results
    if st.session_state.generated:
        st.markdown("---")
        st.subheader("Your Story 📖")
        # The story is model output rendered with unsafe_allow_html=True, so
        # escape it to keep stray markup from being interpreted as HTML.
        # NOTE(review): the literal around the story is only newlines here;
        # any wrapper markup that may once have surrounded it is not visible.
        story_html = html.escape(st.session_state.story, quote=False)
        st.markdown(f"\n{story_html}\n", unsafe_allow_html=True)

        st.markdown("---")
        st.subheader("Listen 🎧")
        audio_data, sr = st.session_state.audio
        st.audio(audio_data, sample_rate=sr)


if __name__ == "__main__":
    main()