Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# storygen_tts_final.py
|
2 |
import streamlit as st
|
3 |
from transformers import (
|
4 |
BlipForConditionalGeneration,
|
@@ -69,8 +69,8 @@ Story:"""
|
|
69 |
# 使用GPT-2生成故事
|
70 |
generated = story_generator(
|
71 |
prompt,
|
72 |
-
max_length=
|
73 |
-
min_length=
|
74 |
num_return_sequences=1,
|
75 |
temperature=0.85,
|
76 |
repetition_penalty=2.0
|
@@ -79,17 +79,15 @@ Story:"""
|
|
79 |
# 提取生成文本并清理
|
80 |
full_text = generated[0]['generated_text']
|
81 |
story = full_text.split("Story:")[-1].strip()
|
82 |
-
return story[:
|
83 |
|
84 |
def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
|
85 |
"""文本转语音"""
|
86 |
try:
|
87 |
-
# 添加文本截断和长度限制
|
88 |
inputs = processor(
|
89 |
text=text,
|
90 |
return_tensors="pt",
|
91 |
-
|
92 |
-
max_length=600 # 确保不超过模型限制
|
93 |
)
|
94 |
input_ids = inputs["input_ids"].to(torch.int64)
|
95 |
|
@@ -166,7 +164,6 @@ def main():
|
|
166 |
|
167 |
except Exception as e:
|
168 |
st.error(f"Magic failed: {str(e)}")
|
169 |
-
st.session_state.generated = False # 确保失败时重置状态
|
170 |
|
171 |
# 显示结果
|
172 |
if st.session_state.generated:
|
@@ -179,15 +176,6 @@ def main():
|
|
179 |
st.subheader("Listen 🎧")
|
180 |
audio_data, sr = st.session_state.audio
|
181 |
st.audio(audio_data, sample_rate=sr)
|
182 |
-
|
183 |
-
st.markdown("---")
|
184 |
-
if st.button("Create New Story", use_container_width=True):
|
185 |
-
# 彻底清除所有相关状态
|
186 |
-
keys_to_clear = ['generated', 'story', 'audio']
|
187 |
-
for key in keys_to_clear:
|
188 |
-
if key in st.session_state:
|
189 |
-
del st.session_state[key]
|
190 |
-
st.rerun()
|
191 |
|
192 |
if __name__ == "__main__":
|
193 |
main()
|
|
|
1 |
+
# storygen_tts_final.py
|
2 |
import streamlit as st
|
3 |
from transformers import (
|
4 |
BlipForConditionalGeneration,
|
|
|
69 |
# 使用GPT-2生成故事
|
70 |
generated = story_generator(
|
71 |
prompt,
|
72 |
+
max_length=100,
|
73 |
+
min_length=50,
|
74 |
num_return_sequences=1,
|
75 |
temperature=0.85,
|
76 |
repetition_penalty=2.0
|
|
|
79 |
# 提取生成文本并清理
|
80 |
full_text = generated[0]['generated_text']
|
81 |
story = full_text.split("Story:")[-1].strip()
|
82 |
+
return story[:600].replace(caption, "").strip()
|
83 |
|
84 |
def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
|
85 |
"""文本转语音"""
|
86 |
try:
|
|
|
87 |
inputs = processor(
|
88 |
text=text,
|
89 |
return_tensors="pt",
|
90 |
+
voice_preset=None
|
|
|
91 |
)
|
92 |
input_ids = inputs["input_ids"].to(torch.int64)
|
93 |
|
|
|
164 |
|
165 |
except Exception as e:
|
166 |
st.error(f"Magic failed: {str(e)}")
|
|
|
167 |
|
168 |
# 显示结果
|
169 |
if st.session_state.generated:
|
|
|
176 |
st.subheader("Listen 🎧")
|
177 |
audio_data, sr = st.session_state.audio
|
178 |
st.audio(audio_data, sample_rate=sr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
if __name__ == "__main__":
|
181 |
main()
|