liuxh0319 commited on
Commit
5c26951
·
verified ·
1 Parent(s): bec8f52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -12
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # storygen_tts_final.py(修复会话状态管理)
2
  import streamlit as st
3
  from transformers import (
4
  BlipForConditionalGeneration,
@@ -13,30 +13,194 @@ import torch
13
  import numpy as np
14
  from PIL import Image
15
 
16
- # ...(其他代码保持不变)...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def main():
19
- # ...(界面配置保持不变)...
 
 
 
 
 
 
 
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # 文件上传组件
22
  uploaded_file = st.file_uploader(
23
  "Choose your magic image",
24
  type=["jpg", "png", "jpeg"],
25
  help="Upload photos of pets, toys or adventures!",
26
- key="uploader" # 关键点:保留key定义
27
  )
28
-
29
- # ...(处理上传文件部分保持不变)...
30
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # 显示结果
32
- if st.session_state.generated:
33
- # ...(结果展示部分保持不变)...
34
-
 
 
 
 
 
 
 
 
 
 
 
35
  st.markdown("---")
36
  if st.button("Create New Story", use_container_width=True):
37
- # 仅重置必要状态
38
  st.session_state.generated = False
39
- # 自动清除上传文件(通过重新运行实现)
 
40
  st.rerun()
41
 
42
  if __name__ == "__main__":
 
1
+ # storygen_tts_final.py
2
  import streamlit as st
3
  from transformers import (
4
  BlipForConditionalGeneration,
 
13
  import numpy as np
14
  from PIL import Image
15
 
16
+ # 初始化模型(CPU优化版)
17
+ @st.cache_resource(show_spinner="🔮 Loading magic models...")
18
+ def load_models():
19
+ """加载所有需要的AI模型"""
20
+ try:
21
+ # 图像描述模型
22
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
23
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
24
+
25
+ # 文本生成pipeline
26
+ story_generator = pipeline(
27
+ "text-generation",
28
+ model="openai-community/gpt2",
29
+ device_map="auto",
30
+ torch_dtype=torch.float32
31
+ )
32
+
33
+ # 语音合成模型
34
+ tts_processor = AutoProcessor.from_pretrained("microsoft/speecht5_tts")
35
+ tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
36
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
37
+
38
+ # 加载说话者嵌入数据集
39
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
40
+
41
+ return blip_processor, blip_model, story_generator, tts_processor, tts_model, vocoder, embeddings_dataset
42
+ except Exception as e:
43
+ st.error(f"模型加载失败: {str(e)}")
44
+ raise
45
+
46
+ def generate_story(image, blip_processor, blip_model, story_generator):
47
+ """生成高质量儿童故事"""
48
+ inputs = blip_processor(image, return_tensors="pt")
49
+
50
+ # 生成图像描述
51
+ caption_ids = blip_model.generate(
52
+ **inputs,
53
+ max_new_tokens=100,
54
+ num_beams=5,
55
+ early_stopping=True,
56
+ temperature=0.9
57
+ )
58
+ caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
59
+
60
+ # 构建故事生成提示词
61
+ prompt = f"""Based on this image: {caption}
62
+ Write a magical story for children with:
63
+ 1. Talking animals
64
+ 2. Happy ending
65
+ 3. Sound effects (*whoosh*, *giggle*)
66
+ 4. 50-100 words
67
+
68
+ Story:"""
69
+
70
+ # 使用GPT-2生成故事
71
+ generated = story_generator(
72
+ prompt,
73
+ max_length=300,
74
+ min_length=150,
75
+ num_return_sequences=1,
76
+ temperature=0.85,
77
+ repetition_penalty=2.0
78
+ )
79
+
80
+ # 提取生成文本并清理
81
+ full_text = generated[0]['generated_text']
82
+ story = full_text.split("Story:")[-1].strip()
83
+ return story[:580].replace(caption, "").strip()
84
+
85
+ def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
86
+ """文本转语音(修复版)"""
87
+ try:
88
+ # 输入预处理
89
+ inputs = processor(
90
+ text=text,
91
+ return_tensors="pt",
92
+ padding="max_length",
93
+ max_length=600,
94
+ truncation=True,
95
+ voice_preset=None
96
+ )
97
+ input_ids = inputs["input_ids"].to(torch.int64)
98
+
99
+ # 动态调整说话者嵌入维度
100
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"])
101
+ speaker_embeddings = speaker_embeddings.unsqueeze(0).repeat(1, input_ids.shape[1], 1)
102
+
103
+ with torch.no_grad():
104
+ speech = model.generate_speech(
105
+ input_ids=input_ids,
106
+ speaker_embeddings=speaker_embeddings,
107
+ vocoder=vocoder
108
+ )
109
+
110
+ # 音频处理
111
+ audio_array = speech.numpy().astype(np.float32)
112
+ max_val = np.max(np.abs(audio_array)) + 1e-8
113
+ audio_array = 0.9 * audio_array / max_val
114
+
115
+ return audio_array, 16000
116
+ except Exception as e:
117
+ st.error(f"语音生成失败: {str(e)}")
118
+ raise
119
 
120
  def main():
121
+ # 初始化会话状态
122
+ required_states = {
123
+ 'generated': False,
124
+ 'audio': None,
125
+ 'story': ""
126
+ }
127
+ for key, val in required_states.items():
128
+ if key not in st.session_state:
129
+ st.session_state[key] = val
130
 
131
+ # 界面配置
132
+ st.set_page_config(
133
+ page_title="Magic Story Box",
134
+ page_icon="🧙",
135
+ layout="centered"
136
+ )
137
+
138
+ st.title("🧚♀️ Magic Story Box")
139
+ st.markdown("---")
140
+ st.write("Upload an image to get your magical story!")
141
+
142
+ # 加载模型
143
+ try:
144
+ (blip_proc, blip_model, story_gen,
145
+ tts_proc, tts_model, vocoder, embeddings) = load_models()
146
+ except:
147
+ return
148
+
149
  # 文件上传组件
150
  uploaded_file = st.file_uploader(
151
  "Choose your magic image",
152
  type=["jpg", "png", "jpeg"],
153
  help="Upload photos of pets, toys or adventures!",
154
+ key="uploader"
155
  )
156
+
157
+ # 处理上传文件
158
+ if uploaded_file and not st.session_state.generated:
159
+ try:
160
+ image = Image.open(uploaded_file).convert("RGB")
161
+ st.image(image, caption="Your Magic Picture ✨", use_container_width=True)
162
+
163
+ with st.status("Creating Magic...", expanded=True) as status:
164
+ # 生成故事
165
+ st.write("🔍 Reading the image...")
166
+ story = generate_story(image, blip_proc, blip_model, story_gen)
167
+
168
+ # 生成语音
169
+ st.write("🔊 Adding sounds...")
170
+ audio_array, sr = text_to_speech(story, tts_proc, tts_model, vocoder, embeddings)
171
+
172
+ # 保存结果
173
+ st.session_state.story = story
174
+ st.session_state.audio = (audio_array, sr)
175
+ status.update(label="Ready!", state="complete", expanded=False)
176
+
177
+ st.session_state.generated = True
178
+ st.rerun()
179
+
180
+ except Exception as e:
181
+ st.error(f"Magic failed: {str(e)}")
182
+
183
  # 显示结果
184
+ if st.session_state.get('generated', False):
185
+ st.markdown("---")
186
+ st.subheader("Your Story 📖")
187
+ st.markdown(f'<div style="background:#fff3e6; padding:20px; border-radius:10px;">{st.session_state.story}</div>',
188
+ unsafe_allow_html=True)
189
+
190
+ st.markdown("---")
191
+ st.subheader("Listen 🎧")
192
+ if st.session_state.audio is not None:
193
+ audio_data, sr = st.session_state.audio
194
+ st.audio(audio_data, sample_rate=sr)
195
+ else:
196
+ st.warning("Audio not available")
197
+
198
  st.markdown("---")
199
  if st.button("Create New Story", use_container_width=True):
200
+ # 安全重置状态
201
  st.session_state.generated = False
202
+ st.session_state.audio = None
203
+ st.session_state.story = ""
204
  st.rerun()
205
 
206
  if __name__ == "__main__":