File size: 5,094 Bytes
a67a914
b52b52a
a67a914
 
 
 
b52b52a
20713d5
1fa6509
a67a914
1fa6509
 
 
 
 
 
b52b52a
1fa6509
 
758ffbc
1fa6509
 
 
b52b52a
1fa6509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a67a914
9bc975a
bcfeb34
20713d5
 
 
 
 
 
 
 
 
e881a5a
9bc975a
bcfeb34
20713d5
776a639
7448649
bcfeb34
7448649
776a639
1fa6509
776a639
64fedde
1fa6509
776a639
 
 
bcfeb34
776a639
64fedde
20713d5
 
 
 
e881a5a
9bc975a
bcfeb34
20713d5
1fa6509
20713d5
 
 
 
 
 
 
 
 
 
 
 
9bc975a
a67a914
bcfeb34
 
 
 
 
 
 
 
3e93456
bcfeb34
 
 
a67a914
bcfeb34
 
 
a67a914
bcfeb34
 
 
 
 
a67a914
bcfeb34
 
 
 
 
9bc975a
bcfeb34
 
 
 
 
 
 
 
 
 
9bc975a
bcfeb34
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import streamlit as st
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image
import io
import numpy as np
from kokoro import KPipeline
import soundfile as sf
import re

# Cache model loading to optimize resource use
@st.cache_resource
def load_caption_model():
    """Load the BLIP image-captioning processor and model.

    The function is wrapped in ``st.cache_resource`` so the loaded objects
    are cached instead of being re-created on every call.

    Returns:
        A ``(processor, model)`` tuple, both loaded from the
        ``Salesforce/blip-image-captioning-large`` checkpoint.
    """
    checkpoint = "Salesforce/blip-image-captioning-large"
    return (
        BlipProcessor.from_pretrained(checkpoint),
        BlipForConditionalGeneration.from_pretrained(checkpoint),
    )

@st.cache_resource
def load_story_model():
    """Build the text-generation pipeline used to write stories.

    The tokenizer and the causal language model are both loaded from the
    ``Qwen/Qwen2.5-1.5B-Instruct`` checkpoint and wrapped in a
    ``transformers`` ``"text-generation"`` pipeline. Wrapped in
    ``st.cache_resource`` so the pipeline is cached between calls.

    Returns:
        The ``text-generation`` pipeline object.
    """
    checkpoint = "Qwen/Qwen2.5-1.5B-Instruct"
    # Keyword arguments are evaluated left to right, so the tokenizer is
    # still loaded before the model.
    return pipeline(
        "text-generation",
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        model=AutoModelForCausalLM.from_pretrained(checkpoint),
    )

@st.cache_resource
def load_audio_pipeline():
    """Create the Kokoro ``KPipeline`` used for speech synthesis.

    The pipeline is built with ``lang_code='a'`` and cached through
    ``st.cache_resource``.

    Returns:
        The constructed ``KPipeline`` instance.
    """
    # NOTE(review): 'a' is presumably Kokoro's American English code - confirm.
    tts_pipeline = KPipeline(lang_code='a')
    return tts_pipeline

# Function to truncate story to last complete sentence within 100 words
def truncate_to_100_words(story, max_words=100):
    """Trim *story* to whole sentences totalling at most *max_words* words.

    The text is split after ``.``, ``!`` or ``?`` followed by one or more
    spaces. Sentences are kept, in order, while the running word count
    stays within the budget; the first sentence that would overflow it and
    everything after it are dropped.

    If even the first sentence is longer than the budget, the story is cut
    to its first *max_words* words instead of returning an empty string
    (an empty result would make the caller silently show no story at all).

    Args:
        story: The generated story text.
        max_words: Maximum number of whitespace-separated words to keep.
            Defaults to 100.

    Returns:
        The truncated story, with kept sentences joined by single spaces.
    """
    kept_sentences = []
    words_so_far = 0
    for sentence in re.split(r'(?<=[.!?]) +', story):
        sentence_length = len(sentence.split())
        if words_so_far + sentence_length > max_words:
            break
        kept_sentences.append(sentence)
        words_so_far += sentence_length

    if not kept_sentences:
        # No complete sentence fits: fall back to a hard cut on word count
        # so the caller still gets some text.
        return ' '.join(story.split()[:max_words])
    return ' '.join(kept_sentences)

# Function to generate a caption from an image
def generate_caption(image_bytes, processor, caption_model):
    """Produce a text caption for an image.

    Args:
        image_bytes: Raw bytes of the image file.
        processor: Object used to prepare the model inputs and to decode
            the generated token ids (the BLIP processor in this app).
        caption_model: Model whose ``generate`` method produces the
            caption token ids.

    Returns:
        The decoded caption string, or ``None`` if any step raised; in
        that case the error is reported through ``st.error``.
    """
    try:
        picture = Image.open(io.BytesIO(image_bytes))
        model_inputs = processor(images=picture, return_tensors="pt")
        token_ids = caption_model.generate(**model_inputs)
        return processor.decode(token_ids[0], skip_special_tokens=True)
    except Exception as e:
        st.error(f"Error generating caption: {e}")
    return None

# Function to generate a story from a caption
def generate_story(caption, story_generator):
    """Write a short children's story based on an image caption.

    The caption is embedded in a single user message, rendered with the
    tokenizer's chat template, and passed to the text-generation pipeline
    with sampling enabled. The generated text is then trimmed with
    ``truncate_to_100_words``.

    Args:
        caption: Caption text describing the image.
        story_generator: A ``transformers`` text-generation pipeline that
            exposes its tokenizer as ``.tokenizer``.

    Returns:
        The story text, or ``None`` if generation raised; in that case the
        error is reported through ``st.error``.
    """
    try:
        tokenizer = story_generator.tokenizer
        prompt = [{"role": "user", 
                   "content": f"Based on this text:'{caption}', elaborate a short story with no more than 100 words for young children."}]
        # add_generation_prompt=True appends the assistant-turn header, so
        # the model answers the request instead of continuing the prompt
        # from the end of the user turn.
        input_text = tokenizer.apply_chat_template(
            prompt,
            tokenize=False,
            add_generation_prompt=True,
        )

        story_output = story_generator(
            input_text,
            max_new_tokens=120,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            return_full_text=False
        )
        # Drop leading/trailing whitespace the model may emit around the text.
        story = story_output[0]["generated_text"].strip()
        return truncate_to_100_words(story)
    except Exception as e:
        st.error(f"Error generating story: {e}")
        return None

# Function to generate audio from a story
def generate_audio(story, audio_pipeline):
    """Synthesize speech for a story and return it as an in-memory WAV.

    Args:
        story: Text to read aloud.
        audio_pipeline: Callable invoked as
            ``audio_pipeline(story, voice='af_heart', speed=1)`` that
            yields 3-tuples whose last item is an audio chunk.

    Returns:
        An ``io.BytesIO`` positioned at the start and holding WAV data
        written at a 24000 Hz sample rate, or ``None`` when the pipeline
        yields nothing or an exception occurs (the error is then reported
        through ``st.error``).
    """
    try:
        # Only the third element of each yielded tuple (the audio) is needed.
        chunks = [
            audio
            for _gs, _ps, audio in audio_pipeline(story, voice='af_heart', speed=1)
        ]
        if len(chunks) == 0:
            return None
        samples = np.concatenate(chunks)
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, samples, 24000, format='WAV')
        # Rewind so consumers read the data from the beginning.
        wav_buffer.seek(0)
        return wav_buffer
    except Exception as e:
        st.error(f"Error generating audio: {e}")
        return None

# Main function to encapsulate Streamlit logic
def main():
    """Run the Streamlit app: upload an image, then caption, story, audio.

    Streamlit re-executes the script on every widget interaction, including
    a click on the download button. Without caching, each such rerun would
    repeat captioning, sampled story generation and speech synthesis, and
    the story on screen would change to a different one. Results are
    therefore stored in ``st.session_state`` per uploaded image and reused
    on reruns. Only successful stages are reused; a stage that failed is
    attempted again on the next rerun.
    """
    # Load models
    processor, caption_model = load_caption_model()
    story_generator = load_story_model()
    audio_pipeline = load_audio_pipeline()

    # Streamlit UI
    st.title("A Cute Story Teller for Children \n🧸💐🦄🩰👼🐇🍄🍨🍩🧚‍♀️")
    st.write("✨Upload an image to generate a short children’s story with audio.✨")

    uploaded_file = st.file_uploader("Choose an image📷...", type=["jpg", "jpeg", "png"])

    if uploaded_file is None:
        # Forget the previous result so that uploading the same picture
        # again produces a fresh story.
        if "story_result" in st.session_state:
            del st.session_state["story_result"]
        return

    # getvalue() returns the full content regardless of the read position.
    image_bytes = uploaded_file.getvalue()
    st.image(image_bytes, caption="Uploaded Image🐾", use_container_width=True)

    # Identify the upload so cached results are reused only for this image.
    image_key = (uploaded_file.name, len(image_bytes), hash(image_bytes))
    result = None
    if "story_result" in st.session_state:
        result = st.session_state["story_result"]
    if result is None or result.get("image_key") != image_key:
        result = {"image_key": image_key}
        st.session_state["story_result"] = result

    caption = result.get("caption")
    if not caption:
        with st.spinner("Generating caption✍..."):
            caption = generate_caption(image_bytes, processor, caption_model)
        result["caption"] = caption
    if not caption:
        return
    st.write("**Generated Caption🐈:**")
    st.write(caption)

    story = result.get("story")
    if not story:
        with st.spinner("Generating story📚..."):
            story = generate_story(caption, story_generator)
        result["story"] = story
    if not story:
        return
    st.write("**Generated Story💐:**")
    st.write(story)

    audio_bytes = result.get("audio")
    if not audio_bytes:
        with st.spinner("Generating audio💽..."):
            audio_buffer = generate_audio(story, audio_pipeline)
        # Keep plain bytes so the player and the download button share the
        # same immutable data across reruns.
        audio_bytes = audio_buffer.getvalue() if audio_buffer is not None else None
        result["audio"] = audio_bytes
    if not audio_bytes:
        return
    st.audio(audio_bytes, format="audio/wav")
    st.download_button(
        label="Download Story Audio🎀",
        data=audio_bytes,
        file_name="story_audio.wav",
        mime="audio/wav"
    )

# Start the app only when this file is run as the main module, not when it is imported.
if __name__ == "__main__":
    main()