liuxh0319 committed on
Commit
fbe023a
·
verified ·
1 Parent(s): d580b27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -23
app.py CHANGED
@@ -14,7 +14,7 @@ import numpy as np
14
  from PIL import Image
15
 
16
  # 初始化模型(CPU优化版)
17
- @st.cache_resource(show_spinner="🔮 Loading magic models...")
18
  def load_models():
19
  """加载所有需要的AI模型"""
20
  try:
@@ -26,8 +26,7 @@ def load_models():
26
  story_generator = pipeline(
27
  "text-generation",
28
  model="openai-community/gpt2",
29
- device_map="auto",
30
- torch_dtype=torch.float32
31
  )
32
 
33
  # 语音合成模型
@@ -70,8 +69,8 @@ Story:"""
70
  # 使用GPT-2生成故事
71
  generated = story_generator(
72
  prompt,
73
- max_length=300,
74
- min_length=150,
75
  num_return_sequences=1,
76
  temperature=0.85,
77
  repetition_penalty=2.0
@@ -80,25 +79,20 @@ Story:"""
80
  # 提取生成文本并清理
81
  full_text = generated[0]['generated_text']
82
  story = full_text.split("Story:")[-1].strip()
83
- return story[:580].replace(caption, "").strip()
84
 
85
  def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
86
- """文本转语音(修复版)"""
87
  try:
88
- # 输入预处理
89
  inputs = processor(
90
- text=text,
91
  return_tensors="pt",
92
- padding="max_length",
93
- max_length=600,
94
- truncation=True,
95
  voice_preset=None
96
  )
97
  input_ids = inputs["input_ids"].to(torch.int64)
98
 
99
- # 动态调整说话者嵌入维度
100
- speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"])
101
- speaker_embeddings = speaker_embeddings.unsqueeze(0).repeat(1, input_ids.shape[1], 1)
102
 
103
  with torch.no_grad():
104
  speech = model.generate_speech(
@@ -107,11 +101,8 @@ def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
107
  vocoder=vocoder
108
  )
109
 
110
- # 音频处理
111
- audio_array = speech.numpy().astype(np.float32)
112
- max_val = np.max(np.abs(audio_array)) + 1e-8
113
- audio_array = 0.9 * audio_array / max_val
114
-
115
  return audio_array, 16000
116
  except Exception as e:
117
  st.error(f"语音生成失败: {str(e)}")
@@ -133,11 +124,46 @@ def main():
133
  if 'generated' not in st.session_state:
134
  st.session_state.generated = False
135
 
136
- # 加载模型(保持不变...)
 
 
 
 
 
137
 
138
- # 文件上传组件(保持不变...)
 
 
 
 
 
 
139
 
140
- # 处理上传文件(保持不变...)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  # 显示结果
143
  if st.session_state.generated:
 
14
  from PIL import Image
15
 
16
  # 初始化模型(CPU优化版)
17
+ @st.cache_resource
18
  def load_models():
19
  """加载所有需要的AI模型"""
20
  try:
 
26
  story_generator = pipeline(
27
  "text-generation",
28
  model="openai-community/gpt2",
29
+ device_map="auto"
 
30
  )
31
 
32
  # 语音合成模型
 
69
  # 使用GPT-2生成故事
70
  generated = story_generator(
71
  prompt,
72
+ max_length=100,
73
+ min_length=50,
74
  num_return_sequences=1,
75
  temperature=0.85,
76
  repetition_penalty=2.0
 
79
  # 提取生成文本并清理
80
  full_text = generated[0]['generated_text']
81
  story = full_text.split("Story:")[-1].strip()
82
+ return story[:600].replace(caption, "").strip()
83
 
84
  def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
85
+ """文本转语音"""
86
  try:
 
87
  inputs = processor(
88
+ text=text,
89
  return_tensors="pt",
 
 
 
90
  voice_preset=None
91
  )
92
  input_ids = inputs["input_ids"].to(torch.int64)
93
 
94
+ # 随机选择一个说话者嵌入
95
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
 
96
 
97
  with torch.no_grad():
98
  speech = model.generate_speech(
 
101
  vocoder=vocoder
102
  )
103
 
104
+ audio_array = speech.numpy()
105
+ audio_array = audio_array / np.max(np.abs(audio_array))
 
 
 
106
  return audio_array, 16000
107
  except Exception as e:
108
  st.error(f"语音生成失败: {str(e)}")
 
124
  if 'generated' not in st.session_state:
125
  st.session_state.generated = False
126
 
127
+ # 加载模型
128
+ try:
129
+ (blip_proc, blip_model, story_gen,
130
+ tts_proc, tts_model, vocoder, embeddings) = load_models()
131
+ except:
132
+ return
133
 
134
+ # 文件上传组件
135
+ uploaded_file = st.file_uploader(
136
+ "Choose your magic image",
137
+ type=["jpg", "png", "jpeg"],
138
+ help="Upload photos of pets, toys or adventures!",
139
+ key="uploader"
140
+ )
141
 
142
+ # 处理上传文件
143
+ if uploaded_file and not st.session_state.generated:
144
+ try:
145
+ image = Image.open(uploaded_file).convert("RGB")
146
+ st.image(image, caption="Your Magic Picture ✨", use_container_width=True)
147
+
148
+ with st.status("Creating Magic...", expanded=True) as status:
149
+ # 生成故事
150
+ st.write("🔍 Reading the image...")
151
+ story = generate_story(image, blip_proc, blip_model, story_gen)
152
+
153
+ # 生成语音
154
+ st.write("🔊 Adding sounds...")
155
+ audio_array, sr = text_to_speech(story, tts_proc, tts_model, vocoder, embeddings)
156
+
157
+ # 保存结果
158
+ st.session_state.story = story
159
+ st.session_state.audio = (audio_array, sr)
160
+ status.update(label="Ready!", state="complete", expanded=False)
161
+
162
+ st.session_state.generated = True
163
+ st.rerun()
164
+
165
+ except Exception as e:
166
+ st.error(f"Magic failed: {str(e)}")
167
 
168
  # 显示结果
169
  if st.session_state.generated: