# Snapshot metadata: 2,093 bytes, revision 59e1b96.

import gradio as gr
from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers import AutoModelForSeq2SeqLM

# Shared tokenizer plus the two seq2seq checkpoints used by the callbacks
# below; all of them are loaded once, at import time.
# NOTE(review): both fine-tuned checkpoints are paired with the stock
# t5-small tokenizer — presumably they were fine-tuned from t5-small; confirm.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
headline, generate_long = (
    AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    for checkpoint in ("wetey/content-summarizer", "wetey/content-generator")
)

def generate_headline(text):
    """Generate a short headline for ``text`` with the ``headline`` model.

    The text is tokenized with the shared tokenizer and decoded with
    multinomial sampling (temperature 1.2). Any 4-gram that appears in the
    input is blocked from appearing in the output.

    Args:
        text: Free-form input text coming from the UI textbox.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Cap the encoder input at the tokenizer's model_max_length so that
    # arbitrarily long UI input cannot grow encoder time/memory unboundedly.
    inputs = tokenizer(text, return_tensors="pt", truncation=True).input_ids

    # do_sample is set inside the config (rather than as a separate generate()
    # kwarg) so the config is self-consistent: temperature only applies to
    # sampling, and GenerationConfig warns at construction when temperature
    # is set while do_sample is False.
    generation_config = GenerationConfig(
        do_sample=True,
        temperature=1.2,
        # Forbid copying any 4-gram verbatim from the input.
        encoder_no_repeat_ngram_size=4,
    )

    outputs = headline.generate(inputs, generation_config=generation_config)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def generate_content(text):
    """Generate longer-form content from ``text`` with the ``generate_long`` model.

    Decoding is beam-search multinomial sampling (``num_beams=4`` together
    with ``do_sample=True``). Output is bounded to 50-512 tokens, and
    repetition is suppressed in three ways:

    - ``repetition_penalty`` penalizes re-generated tokens.
    - ``no_repeat_ngram_size`` forbids repeated output 3-grams.
    - ``encoder_no_repeat_ngram_size`` forbids copying input 2-grams.

    Args:
        text: Free-form input text coming from the UI textbox.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Cap the encoder input at the tokenizer's model_max_length so that
    # arbitrarily long UI input cannot grow encoder time/memory unboundedly.
    inputs = tokenizer(text, return_tensors="pt", truncation=True).input_ids

    # do_sample is set inside the config so that temperature is valid at
    # construction time (GenerationConfig warns when temperature is set
    # while do_sample is False).
    generation_config = GenerationConfig(
        do_sample=True,
        temperature=1.2,
        encoder_no_repeat_ngram_size=2,
        min_length=50,
        max_length=512,
        # > 0 favours longer beams during beam-based decoding.
        length_penalty=1.5,
        num_beams=4,
        repetition_penalty=1.5,
        no_repeat_ngram_size=3,
    )

    outputs = generate_long.generate(inputs, generation_config=generation_config)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
      
# Labeled, two-line input box. It is created outside the Blocks context and
# placed into the layout below with .render().
textbox = gr.Textbox(label="Type your text here", lines=2)

demo = gr.Blocks()

with demo:
    # Render the prepared input box here so the UI input actually carries
    # the label/lines configured above.
    textbox.render()
    text_input = textbox
    text_output = gr.Textbox(label="Generated text")

    b1 = gr.Button("Generate headline")
    b2 = gr.Button("Generate long content")

    # Both actions read the same input box and write to the same output box.
    b1.click(generate_headline, inputs=text_input, outputs=text_output)
    b2.click(generate_content, inputs=text_input, outputs=text_output)

demo.launch()