File size: 6,080 Bytes
ef98b47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# storygen_tts_final.py
import streamlit as st
from transformers import (
    BlipForConditionalGeneration,
    BlipProcessor,
    AutoProcessor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    pipeline
)
from datasets import load_dataset
import torch
import numpy as np
from PIL import Image

# Model initialisation: the decorated loader runs once and its result is reused
# by st.cache_resource instead of reloading weights on every rerun.
@st.cache_resource
def load_models():
    """Load every model/dataset the app needs.

    Returns:
        A 7-tuple ``(blip_processor, blip_model, story_generator,
        tts_processor, tts_model, vocoder, embeddings_dataset)`` in exactly
        that order.

    Raises:
        Whatever the underlying loader raised; the message is first shown to
        the user with ``st.error`` and the exception is then re-raised.
    """
    caption_repo = "Salesforce/blip-image-captioning-base"
    tts_repo = "microsoft/speecht5_tts"

    try:
        # Image captioning (BLIP): processor + generation model.
        blip_processor = BlipProcessor.from_pretrained(caption_repo)
        blip_model = BlipForConditionalGeneration.from_pretrained(caption_repo)

        # Story writer: GPT-2 text-generation pipeline; device_map="auto"
        # lets transformers pick the device.
        story_generator = pipeline(
            "text-generation",
            model="openai-community/gpt2",
            device_map="auto",
        )

        # Text-to-speech: SpeechT5 processor/model plus the HiFi-GAN vocoder.
        tts_processor = AutoProcessor.from_pretrained(tts_repo)
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(tts_repo)
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

        # Speaker x-vector embeddings (validation split).
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    except Exception as err:
        st.error(f"模型加载失败: {str(err)}")
        raise

    return (
        blip_processor,
        blip_model,
        story_generator,
        tts_processor,
        tts_model,
        vocoder,
        embeddings_dataset,
    )

def _trim_to_limit(text, max_chars):
    """Shorten *text* to at most *max_chars* characters, preferring a sentence or word boundary."""
    max_chars = max(max_chars, 0)
    if len(text) <= max_chars:
        return text
    cut = text[:max_chars]
    # Prefer ending on a full sentence if one finishes in the back half of the window.
    sentence_end = max(cut.rfind(mark) for mark in ".!?")
    if sentence_end >= max_chars // 2:
        return cut[: sentence_end + 1]
    # Otherwise drop the trailing partial word (if there is any space to cut at).
    head, sep, _ = cut.rpartition(" ")
    return head if sep else cut


def generate_story(image, blip_processor, blip_model, story_generator, *, max_chars=400):
    """Caption *image* with BLIP, then have a text-generation pipeline write a children's story.

    Args:
        image: Picture to describe; handed straight to ``blip_processor``
            (the app passes an RGB ``PIL.Image``).
        blip_processor: BLIP processor used to encode the image and decode the caption.
        blip_model: BLIP captioning model (anything with ``generate``).
        story_generator: A ``transformers`` text-generation pipeline.
        max_chars: Upper bound on the length of the returned story. The
            default keeps the text below SpeechT5's 450-position text limit
            (its tokenizer is character-based) so the story can be voiced.

    Returns:
        The generated continuation, with any verbatim copy of the caption
        removed, whitespace-stripped and trimmed to at most ``max_chars``
        characters (cut at a sentence or word boundary when possible).
    """
    inputs = blip_processor(image, return_tensors="pt")

    # Deterministic beam-search caption. No ``temperature`` here: it only
    # affects sampling, and passing it to plain beam search just warns.
    caption_ids = blip_model.generate(
        **inputs,
        max_new_tokens=100,
        num_beams=5,
        early_stopping=True,
    )
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)

    # Story prompt for the language model.
    prompt = f"""Based on this image: {caption}
Write a magical story for children with:
1. Talking animals
2. Happy ending
3. Sound effects (*whoosh*, *giggle*)
4. 50-100 words

Story:"""

    # GPT-2 has no pad token; reuse EOS so generation does not warn about it.
    tokenizer = getattr(story_generator, "tokenizer", None)
    pad_token_id = getattr(tokenizer, "eos_token_id", None)

    generated = story_generator(
        prompt,
        # *_new_tokens count only the continuation; max_length/min_length
        # would include the ~55-token prompt and starve the story.
        max_new_tokens=120,
        min_new_tokens=50,
        num_return_sequences=1,
        do_sample=True,  # temperature is ignored unless sampling is enabled
        temperature=0.85,
        repetition_penalty=2.0,
        pad_token_id=pad_token_id,
        return_full_text=False,  # return only the story, not the echoed prompt
    )

    story = generated[0]["generated_text"].strip()
    if caption:
        story = story.replace(caption, "")
    return _trim_to_limit(story.strip(), max_chars).strip()

def text_to_speech(text, processor, model, vocoder, embeddings_dataset, *, speaker_index=7306):
    """Synthesize *text* into a mono waveform with SpeechT5 and a vocoder.

    Args:
        text: Text to speak.
        processor: SpeechT5 processor used to tokenize *text*.
        model: Text-to-speech model exposing ``generate_speech``.
        vocoder: Vocoder passed through to ``generate_speech``.
        embeddings_dataset: Indexable collection whose rows carry an
            ``"xvector"`` speaker embedding.
        speaker_index: Row of *embeddings_dataset* used as the (fixed) voice.

    Returns:
        ``(audio, sample_rate)``: the waveform as a NumPy array, peak-normalized
        to [-1, 1] (left unchanged when silent or empty), and its sample rate
        in Hz (``vocoder.config.sampling_rate`` if available, else 16000).

    Raises:
        Any exception from tokenization or synthesis is reported with
        ``st.error`` and then re-raised.
    """
    try:
        inputs = processor(text=text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch.int64)

        # The text encoder only has ``max_text_positions`` position embeddings;
        # longer inputs fail, so clip them instead of crashing.
        max_positions = getattr(getattr(model, "config", None), "max_text_positions", None)
        if max_positions:
            input_ids = input_ids[:, :max_positions]

        # Fixed speaker voice: one x-vector row, given a batch dimension.
        speaker_embeddings = torch.tensor(embeddings_dataset[speaker_index]["xvector"]).unsqueeze(0)

        with torch.no_grad():
            speech = model.generate_speech(
                input_ids=input_ids,
                speaker_embeddings=speaker_embeddings,
                vocoder=vocoder,
            )

        audio_array = speech.cpu().numpy()
        # Peak-normalize, but skip silent/empty output: 0/0 would give NaNs and
        # np.max of an empty array raises.
        peak = np.max(np.abs(audio_array)) if audio_array.size else 0.0
        if peak > 0:
            audio_array = audio_array / peak

        sample_rate = getattr(getattr(vocoder, "config", None), "sampling_rate", 16000)
        return audio_array, sample_rate
    except Exception as e:
        st.error(f"语音生成失败: {str(e)}")
        raise

def main():
    """Render the Streamlit app: upload an image, generate a story, then voice it.

    Generation runs once per upload. The results are stored in
    ``st.session_state`` (``story``, ``audio``) and displayed until the user
    clicks "Create New Story", which clears them and resets the uploader.
    """
    import html  # escape model output before embedding it in raw HTML

    # Page configuration.
    st.set_page_config(
        page_title="Magic Story Box",
        page_icon="🧙",
        layout="centered"
    )

    st.title("🧚♀️ Magic Story Box")
    st.markdown("---")
    st.write("Upload an image to get your magical story!")

    # Session state initialisation.
    if 'generated' not in st.session_state:
        st.session_state.generated = False
    # A file_uploader's value cannot be assigned through session state, so the
    # uploader is reset by giving it a fresh key (bumped on "Create New Story").
    if 'uploader_key' not in st.session_state:
        st.session_state.uploader_key = 0

    # Load models; load_models already reports failures via st.error.
    try:
        (blip_proc, blip_model, story_gen,
         tts_proc, tts_model, vocoder, embeddings) = load_models()
    except Exception:
        return

    # File upload widget.
    uploaded_file = st.file_uploader(
        "Choose your magic image",
        type=["jpg", "png", "jpeg"],
        help="Upload photos of pets, toys or adventures!",
        key=f"uploader_{st.session_state.uploader_key}"
    )

    # Generate the story and audio once per upload.
    if uploaded_file is not None and not st.session_state.generated:
        try:
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Your Magic Picture ✨", use_container_width=True)

            with st.status("Creating Magic...", expanded=True) as status:
                st.write("🔍 Reading the image...")
                story = generate_story(image, blip_proc, blip_model, story_gen)

                st.write("🔊 Adding sounds...")
                audio_array, sr = text_to_speech(story, tts_proc, tts_model, vocoder, embeddings)

                # Keep results across reruns.
                st.session_state.story = story
                st.session_state.audio = (audio_array, sr)
                status.update(label="Ready!", state="complete", expanded=False)
        except Exception as e:
            st.error(f"Magic failed: {str(e)}")
        else:
            # Only mark success (and rerun) when every step above completed.
            st.session_state.generated = True
            st.rerun()

    # Show results.
    if st.session_state.generated:
        st.markdown("---")
        st.subheader("Your Story 📖")
        # The story is model output rendered with unsafe_allow_html, so escape
        # it; newlines become <br> so the text stays inside the styled div.
        safe_story = html.escape(st.session_state.story).replace("\n", "<br>")
        st.markdown(
            f'<div style="background:#fff3e6; padding:20px; border-radius:10px;">{safe_story}</div>',
            unsafe_allow_html=True,
        )

        st.markdown("---")
        st.subheader("Listen 🎧")
        audio_data, sr = st.session_state.audio
        st.audio(audio_data, sample_rate=sr)

        st.markdown("---")
        if st.button("Create New Story", use_container_width=True):
            st.session_state.generated = False
            st.session_state.uploader_key += 1  # new key => empty uploader
            for stale in ("story", "audio"):
                st.session_state.pop(stale, None)
            st.rerun()


if __name__ == "__main__":
    main()