liuxh0319 commited on
Commit
bbff29a
·
verified ·
1 Parent(s): b313f04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # storygen_tts_final.py
2
  import streamlit as st
3
  from transformers import (
4
  BlipForConditionalGeneration,
@@ -69,8 +69,8 @@ Story:"""
69
  # 使用GPT-2生成故事
70
  generated = story_generator(
71
  prompt,
72
- max_length=100,
73
- min_length=50,
74
  num_return_sequences=1,
75
  temperature=0.85,
76
  repetition_penalty=2.0
@@ -79,15 +79,17 @@ Story:"""
79
  # 提取生成文本并清理
80
  full_text = generated[0]['generated_text']
81
  story = full_text.split("Story:")[-1].strip()
82
- return story[:600].replace(caption, "").strip()
83
 
84
  def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
85
  """文本转语音"""
86
  try:
 
87
  inputs = processor(
88
  text=text,
89
  return_tensors="pt",
90
- voice_preset=None
 
91
  )
92
  input_ids = inputs["input_ids"].to(torch.int64)
93
 
@@ -164,6 +166,7 @@ def main():
164
 
165
  except Exception as e:
166
  st.error(f"Magic failed: {str(e)}")
 
167
 
168
  # 显示结果
169
  if st.session_state.generated:
@@ -176,6 +179,15 @@ def main():
176
  st.subheader("Listen 🎧")
177
  audio_data, sr = st.session_state.audio
178
  st.audio(audio_data, sample_rate=sr)
 
 
 
 
 
 
 
 
 
179
 
180
  if __name__ == "__main__":
181
  main()
 
1
+ # storygen_tts_final.py (修正版)
2
  import streamlit as st
3
  from transformers import (
4
  BlipForConditionalGeneration,
 
69
  # 使用GPT-2生成故事
70
  generated = story_generator(
71
  prompt,
72
+ max_length=300,
73
+ min_length=150,
74
  num_return_sequences=1,
75
  temperature=0.85,
76
  repetition_penalty=2.0
 
79
  # 提取生成文本并清理
80
  full_text = generated[0]['generated_text']
81
  story = full_text.split("Story:")[-1].strip()
82
+ return story[:500].replace(caption, "").strip() # 缩短最大长度
83
 
84
  def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
85
  """文本转语音"""
86
  try:
87
+ # 添加文本截断和长度限制
88
  inputs = processor(
89
  text=text,
90
  return_tensors="pt",
91
+ truncation=True,
92
+ max_length=600 # 确保不超过模型限制
93
  )
94
  input_ids = inputs["input_ids"].to(torch.int64)
95
 
 
166
 
167
  except Exception as e:
168
  st.error(f"Magic failed: {str(e)}")
169
+ st.session_state.generated = False # 确保失败时重置状态
170
 
171
  # 显示结果
172
  if st.session_state.generated:
 
179
  st.subheader("Listen 🎧")
180
  audio_data, sr = st.session_state.audio
181
  st.audio(audio_data, sample_rate=sr)
182
+
183
+ st.markdown("---")
184
+ if st.button("Create New Story", use_container_width=True):
185
+ # 彻底清除所有相关状态
186
+ keys_to_clear = ['generated', 'story', 'audio']
187
+ for key in keys_to_clear:
188
+ if key in st.session_state:
189
+ del st.session_state[key]
190
+ st.rerun()
191
 
192
  if __name__ == "__main__":
193
  main()