liuxh0319 commited on
Commit
ef98b47
·
verified ·
1 Parent(s): 8e32feb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # storygen_tts_final.py
2
+ import streamlit as st
3
+ from transformers import (
4
+ BlipForConditionalGeneration,
5
+ BlipProcessor,
6
+ AutoProcessor,
7
+ SpeechT5ForTextToSpeech,
8
+ SpeechT5HifiGan,
9
+ pipeline
10
+ )
11
+ from datasets import load_dataset
12
+ import torch
13
+ import numpy as np
14
+ from PIL import Image
15
+
16
# Initialise the models (CPU-oriented build)
@st.cache_resource
def load_models():
    """Load every model the app needs and return them together.

    The function is wrapped in ``st.cache_resource``, so Streamlit reuses the
    returned objects instead of reloading them on each script run.

    Returns:
        A 7-tuple ``(blip_processor, blip_model, story_generator,
        tts_processor, tts_model, vocoder, embeddings_dataset)``.

    Raises:
        Exception: Any loading error is shown with ``st.error`` and then
            re-raised unchanged.
    """
    caption_checkpoint = "Salesforce/blip-image-captioning-base"
    tts_checkpoint = "microsoft/speecht5_tts"

    try:
        # Image captioning (BLIP)
        caption_processor = BlipProcessor.from_pretrained(caption_checkpoint)
        caption_model = BlipForConditionalGeneration.from_pretrained(caption_checkpoint)

        # Story text generation (GPT-2 pipeline)
        text_pipeline = pipeline(
            "text-generation",
            model="openai-community/gpt2",
            device_map="auto",
        )

        # Speech synthesis (SpeechT5 + HiFi-GAN vocoder)
        speech_processor = AutoProcessor.from_pretrained(tts_checkpoint)
        speech_model = SpeechT5ForTextToSpeech.from_pretrained(tts_checkpoint)
        hifigan = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

        # Speaker-embedding (x-vector) dataset
        xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    except Exception as exc:
        st.error(f"模型加载失败: {str(exc)}")
        raise

    return (
        caption_processor,
        caption_model,
        text_pipeline,
        speech_processor,
        speech_model,
        hifigan,
        xvectors,
    )
44
+
45
def generate_story(image, blip_processor, blip_model, story_generator):
    """Caption *image* and expand the caption into a short children's story.

    Args:
        image: Image passed to ``blip_processor`` (the app supplies an RGB
            PIL image).
        blip_processor: Processor used to encode the image and to decode the
            generated caption ids.
        blip_model: Captioning model exposing ``generate``.
        story_generator: Text-generation pipeline; called with the prompt and
            expected to return ``[{"generated_text": ...}]``.

    Returns:
        The story text: the generated continuation after the prompt, cut to
        at most 600 characters, with verbatim copies of the caption removed.
    """
    inputs = blip_processor(image, return_tensors="pt")

    # Caption the image with beam search. No temperature is passed: it only
    # influences sampling, which is not enabled for this call.
    caption_ids = blip_model.generate(
        **inputs,
        max_new_tokens=100,
        num_beams=5,
        early_stopping=True,
    )
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)

    # Build the story prompt line by line so the text sent to the model does
    # not depend on how this source file is indented.
    prompt = "\n".join(
        [
            f"Based on this image: {caption}",
            "Write a magical story for children with:",
            "1. Talking animals",
            "2. Happy ending",
            "3. Sound effects (*whoosh*, *giggle*)",
            "4. 50-100 words",
            "",
            "Story:",
        ]
    )

    # Bound the number of *new* tokens. ``max_length``/``min_length`` count
    # the prompt too, and this prompt alone uses up most of a 100-token
    # budget, leaving almost nothing for the story itself.
    generated = story_generator(
        prompt,
        max_new_tokens=100,
        min_new_tokens=50,
        num_return_sequences=1,
        do_sample=True,  # temperature is only meaningful when sampling
        temperature=0.85,
        repetition_penalty=2.0,
    )

    # The pipeline returns prompt + continuation; keep only the continuation.
    full_text = generated[0]["generated_text"]
    if full_text.startswith(prompt):
        story = full_text[len(prompt):]
    else:
        # Fallback when the prompt was not echoed verbatim.
        story = full_text.split("Story:")[-1]
    story = story.strip()

    return story[:600].replace(caption, "").strip()
83
+
84
def text_to_speech(text, processor, model, vocoder, embeddings_dataset,
                   speaker_index=7306):
    """Synthesise speech for *text* and return peak-normalised audio.

    Args:
        text: Text to speak.
        processor: TTS processor; called with ``text=`` and
            ``return_tensors="pt"``.
        model: TTS model exposing ``generate_speech``.
        vocoder: Vocoder handed to ``model.generate_speech``.
        embeddings_dataset: Indexable dataset whose rows carry an
            ``"xvector"`` speaker embedding.
        speaker_index: Row of ``embeddings_dataset`` used as the voice.
            Defaults to 7306, the fixed speaker the app has always used.

    Returns:
        ``(audio_array, 16000)``: a NumPy waveform scaled so its peak absolute
        value is 1 (left untouched when the waveform is empty or silent) and
        the sample rate reported to the caller.

    Raises:
        Exception: Any failure is shown with ``st.error`` and re-raised.
    """
    try:
        # ``voice_preset`` is not a SpeechT5 processor argument, so only the
        # text and tensor format are passed.
        inputs = processor(text=text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch.int64)

        # Fixed (not random) speaker embedding, with a batch dimension added.
        speaker_embeddings = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        with torch.no_grad():
            speech = model.generate_speech(
                input_ids=input_ids,
                speaker_embeddings=speaker_embeddings,
                vocoder=vocoder,
            )

        audio_array = speech.numpy()

        # Peak-normalise, but never divide by zero: an empty waveform has no
        # maximum and a silent one has a peak of 0 (which would yield NaNs).
        peak = np.max(np.abs(audio_array)) if audio_array.size else 0.0
        if peak > 0:
            audio_array = audio_array / peak

        return audio_array, 16000
    except Exception as e:
        st.error(f"语音生成失败: {str(e)}")
        raise
110
+
111
def main():
    """Render the Magic Story Box page.

    Flow: configure the page, load the models, let the user upload an image,
    generate a story and its narration once per upload, keep both in
    ``st.session_state`` and display them until "Create New Story" is pressed.
    """
    # Local import: only needed to escape the story before HTML rendering.
    import html

    # Page configuration
    st.set_page_config(
        page_title="Magic Story Box",
        page_icon="🧙",
        layout="centered"
    )

    st.title("🧚♀️ Magic Story Box")
    st.markdown("---")
    st.write("Upload an image to get your magical story!")

    # Initialise session state
    if 'generated' not in st.session_state:
        st.session_state.generated = False
    # Counter baked into the uploader's widget key. A widget's value cannot
    # be reset by assigning to its session-state key, so the uploader is
    # cleared by giving it a new key instead.
    if 'uploader_key' not in st.session_state:
        st.session_state.uploader_key = 0

    # Load models
    try:
        (blip_proc, blip_model, story_gen,
         tts_proc, tts_model, vocoder, embeddings) = load_models()
    except Exception:
        # load_models has already shown the error; just stop rendering.
        return

    # File upload widget
    uploaded_file = st.file_uploader(
        "Choose your magic image",
        type=["jpg", "png", "jpeg"],
        help="Upload photos of pets, toys or adventures!",
        key=f"uploader_{st.session_state.uploader_key}"
    )

    # Process the upload (only once per uploaded image)
    if uploaded_file and not st.session_state.generated:
        try:
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Your Magic Picture ✨", use_container_width=True)

            with st.status("Creating Magic...", expanded=True) as status:
                # Generate the story
                st.write("🔍 Reading the image...")
                story = generate_story(image, blip_proc, blip_model, story_gen)

                # Generate the narration
                st.write("🔊 Adding sounds...")
                audio_array, sr = text_to_speech(story, tts_proc, tts_model, vocoder, embeddings)

                # Store the results for the next script run
                st.session_state.story = story
                st.session_state.audio = (audio_array, sr)
                status.update(label="Ready!", state="complete", expanded=False)

            st.session_state.generated = True
            st.rerun()

        except Exception as e:
            st.error(f"Magic failed: {str(e)}")

    # Show results
    if st.session_state.generated:
        st.markdown("---")
        st.subheader("Your Story 📖")
        # The story is model output rendered inside raw HTML, so escape it to
        # keep stray "<", ">" or "&" from breaking or injecting markup.
        safe_story = html.escape(st.session_state.story)
        st.markdown(f'<div style="background:#fff3e6; padding:20px; border-radius:10px;">{safe_story}</div>',
                    unsafe_allow_html=True)

        st.markdown("---")
        st.subheader("Listen 🎧")
        audio_data, sr = st.session_state.audio
        st.audio(audio_data, sample_rate=sr)

        st.markdown("---")
        if st.button("Create New Story", use_container_width=True):
            st.session_state.generated = False
            # Rotate the uploader key so the next run shows an empty uploader.
            st.session_state.uploader_key += 1
            st.rerun()


if __name__ == "__main__":
    main()