# Author: chansung
# Commit 0cff7d6: "Update app.py"
import io
import base64
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import gradio as gr
from datasets import load_dataset
from datasets import DownloadMode, VerificationMode
# Custom CSS passed to gr.Blocks(css=STYLES) when the UI is built.
#   #container  -> the main gr.Column(elem_id="container"): centred, half width
#   #gallery    -> the gr.Gallery(elem_id="gallery"): fixed 500px height
#   .center     -> centred text (headings / title Markdown)
#   .small-big  -> 12pt text (description and story Markdown)
STYLES = """
#container {
margin: auto;
width: 50%;
}
#gallery {
height: 500px !important;
}
.center {
text-align: center;
}
.small-big {
font-size: 12pt !important;
}
"""
# Shared state: get_gallery() fills these in parallel with the gallery images,
# and gallery_select() reads them by index, so titles[i] / stories[i] belong
# to the i-th gallery image. Two distinct lists are created here.
titles, stories = [], []
def add_title(image, title, *, font_path='arial_bold.ttf', font_size=30, bottom_offset=80):
    """Draw *title* centred near the bottom of *image* on a white band.

    The image is modified in place and also returned for convenience.
    Titles wider than the image are not wrapped, so they get clipped.

    Args:
        image: PIL image to annotate (mutated in place).
        title: text to draw.
        font_path: TrueType font file to load (default keeps the original
            ``arial_bold.ttf``).
        font_size: size passed to ``ImageFont.truetype``.
        bottom_offset: distance in pixels from the bottom edge of the image
            to the top of the white band / text.

    Returns:
        The same ``image`` object.
    """
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(font_path, font_size)
    # textbbox() returns (left, top, right, bottom) relative to the drawing
    # origin, not (x, y, width, height).
    left, _top, right, bottom = draw.textbbox((0, 0), title, font=font)
    band_top = image.height - bottom_offset
    # Full-width white band reaching down to the bottom of the text's ink box.
    draw.rectangle(
        [(0, band_top), (image.width, band_top + bottom)],
        fill="white",
        outline="white",
    )
    # The ink box drawn at x spans [x + left, x + right]; centre that span.
    x = (image.width - left - right) / 2
    # A colour name (like "white" above) also works for single-band images
    # such as mode "L", where an RGB tuple fill raises TypeError.
    draw.text((x, band_top), title, font=font, fill="black")
    return image
def gallery_select(gallery, evt: gr.SelectData):
    """Show the title and story for the gallery image the user clicked.

    Args:
        gallery: current gallery value. It is unused, but it is received
            because the event is wired with ``inputs=[gallery]``.
        evt: Gradio select event. ``evt.index`` is used as the position of the
            clicked image, matching the order in which get_gallery() filled
            ``titles`` / ``stories``.

    Returns:
        ``[title_update, story_update]`` for the two Markdown outputs. If the
        index does not fit the currently loaded data (e.g. the dataset was
        refreshed and shrank after the page rendered), both outputs are hidden
        instead of raising ``IndexError``.
    """
    # Read the globals once so both lookups use the same snapshot even if
    # get_gallery() swaps in new lists concurrently.
    current_titles, current_stories = titles, stories
    index = evt.index
    limit = min(len(current_titles), len(current_stories))
    if not isinstance(index, int) or not 0 <= index < limit:
        return [gr.update(visible=False), gr.update(visible=False)]
    return [
        gr.update(value=f"## {current_titles[index]}", visible=True),
        gr.update(value=current_stories[index], visible=True),
    ]
def get_gallery():
    """Download the stories dataset and build the gallery images.

    Loads ``chansung/llama2-stories`` with a forced re-download and
    verification disabled. For each row of the ``train`` split it:

    - base64-decodes the ``image`` field;
    - stamps the row's ``title`` onto it with add_title().

    Rows whose image cannot be decoded fall back to ``placeholder.png``.

    Side effect: replaces the module-level ``titles`` and ``stories`` lists so
    that ``titles[i]`` / ``stories[i]`` describe the i-th returned image. The
    new lists are built privately and swapped in only after every row has
    been processed, for two reasons:

    - partially filled lists are never exposed to gallery_select();
    - a failed download leaves the previous data intact.

    Returns:
        A list of PIL images, one per ``train`` row, in dataset order.
    """
    global titles, stories
    new_images, new_titles, new_stories = [], [], []
    dataset = load_dataset(
        "chansung/llama2-stories",
        download_mode=DownloadMode.FORCE_REDOWNLOAD,
        verification_mode=VerificationMode.NO_CHECKS,
    )
    for row in dataset['train']:
        try:
            decoded = base64.b64decode(row['image'])
            image = Image.open(io.BytesIO(decoded))
            # Image.open() is lazy: force the decode here so corrupt or
            # truncated data falls back to the placeholder instead of
            # raising later inside add_title().
            image.load()
        except Exception:  # best-effort: any undecodable row gets the placeholder
            image = Image.open('placeholder.png')
        new_titles.append(row['title'])
        new_stories.append(row['story'])
        new_images.append(add_title(image, row['title']))
    titles, stories = new_titles, new_stories
    return new_images
# Page layout: a centred column holding the heading, description, image
# gallery and (initially hidden) title/story panel for the selected image.
with gr.Blocks(css=STYLES) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("## LLaMA2 Story Showcase", elem_classes=['center'])
        gr.Markdown(
            "This Space is where the community shares stories generated with the "
            "[chansung/co-write-with-llama2](https://huggingface.co/spaces/chansung/co-write-with-llama2) Space. "
            "Generated stories are archived in the "
            "[chansung/llama2-stories](https://huggingface.co/datasets/chansung/llama2-stories) dataset repository. "
            "The gallery will be updated on a daily basis.",
            elem_classes=['small-big', 'center'],
        )
        # get_gallery() runs on page load and then every `every` seconds
        # (3000 s = 50 min). NOTE(review): the description says "daily";
        # confirm which refresh interval is intended.
        gallery = gr.Gallery(get_gallery, every=3000, columns=5, container=False, elem_id="gallery")
        with gr.Column():
            title = gr.Markdown("title", visible=False, elem_classes=['center'])
            story = gr.Markdown("stories goes here...", visible=False, elem_classes=['small-big'])
    # Clicking a gallery image reveals its title and story.
    gallery.select(
        fn=gallery_select,
        inputs=[gallery],
        outputs=[title, story],
    )

# Periodic refresh via `every=` requires the queue to be enabled in Gradio 3.x.
demo.queue()
demo.launch()