liuxh0319 commited on
Commit
52d4a8e
·
verified ·
1 Parent(s): bbff29a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -17
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # storygen_tts_final.py (修正版)
2
  import streamlit as st
3
  from transformers import (
4
  BlipForConditionalGeneration,
@@ -69,8 +69,8 @@ Story:"""
69
  # 使用GPT-2生成故事
70
  generated = story_generator(
71
  prompt,
72
- max_length=300,
73
- min_length=150,
74
  num_return_sequences=1,
75
  temperature=0.85,
76
  repetition_penalty=2.0
@@ -79,17 +79,15 @@ Story:"""
79
  # 提取生成文本并清理
80
  full_text = generated[0]['generated_text']
81
  story = full_text.split("Story:")[-1].strip()
82
- return story[:500].replace(caption, "").strip() # 缩短最大长度
83
 
84
  def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
85
  """文本转语音"""
86
  try:
87
- # 添加文本截断和长度限制
88
  inputs = processor(
89
  text=text,
90
  return_tensors="pt",
91
- truncation=True,
92
- max_length=600 # 确保不超过模型限制
93
  )
94
  input_ids = inputs["input_ids"].to(torch.int64)
95
 
@@ -166,7 +164,6 @@ def main():
166
 
167
  except Exception as e:
168
  st.error(f"Magic failed: {str(e)}")
169
- st.session_state.generated = False # 确保失败时重置状态
170
 
171
  # 显示结果
172
  if st.session_state.generated:
@@ -179,15 +176,6 @@ def main():
179
  st.subheader("Listen 🎧")
180
  audio_data, sr = st.session_state.audio
181
  st.audio(audio_data, sample_rate=sr)
182
-
183
- st.markdown("---")
184
- if st.button("Create New Story", use_container_width=True):
185
- # 彻底清除所有相关状态
186
- keys_to_clear = ['generated', 'story', 'audio']
187
- for key in keys_to_clear:
188
- if key in st.session_state:
189
- del st.session_state[key]
190
- st.rerun()
191
 
192
  if __name__ == "__main__":
193
  main()
 
1
+ # storygen_tts_final.py
2
  import streamlit as st
3
  from transformers import (
4
  BlipForConditionalGeneration,
 
69
  # 使用GPT-2生成故事
70
  generated = story_generator(
71
  prompt,
72
+ max_length=100,
73
+ min_length=50,
74
  num_return_sequences=1,
75
  temperature=0.85,
76
  repetition_penalty=2.0
 
79
  # 提取生成文本并清理
80
  full_text = generated[0]['generated_text']
81
  story = full_text.split("Story:")[-1].strip()
82
+ return story[:600].replace(caption, "").strip()
83
 
84
  def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
85
  """文本转语音"""
86
  try:
 
87
  inputs = processor(
88
  text=text,
89
  return_tensors="pt",
90
+ voice_preset=None
 
91
  )
92
  input_ids = inputs["input_ids"].to(torch.int64)
93
 
 
164
 
165
  except Exception as e:
166
  st.error(f"Magic failed: {str(e)}")
 
167
 
168
  # 显示结果
169
  if st.session_state.generated:
 
176
  st.subheader("Listen 🎧")
177
  audio_data, sr = st.session_state.audio
178
  st.audio(audio_data, sample_rate=sr)
 
 
 
 
 
 
 
 
 
179
 
180
  if __name__ == "__main__":
181
  main()