Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,105 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
import
|
|
|
|
|
3 |
from transformers import pipeline
|
|
|
4 |
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
def question_answer(context, question):
|
8 |
-
result = question_answerer(question=question, context=context)
|
9 |
-
return result['answer'],result['score']
|
10 |
|
11 |
-
gr.Interface(fn=question_answer, inputs=["text", "text"], outputs=[gr.components.Textbox(label="Answer"), gr.components.Textbox(label="Probability")]).launch()
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
|
3 |
import gradio as gr
|
4 |
+
import faiss
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
from transformers import pipeline
|
8 |
+
from sentence_transformers import SentenceTransformer
|
9 |
|
10 |
+
# ---------------------------------------------------------------------------
# Model setup: everything below is loaded once, when the module is imported.
# ---------------------------------------------------------------------------

# Generative (seq2seq) QA model and its tokenizer.
model_name = "9wimu9/mt5-xl-sin-odqa-1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# The model placed on `device` is also kept under this second name.
question_answerer_seq_2seq = model.to(device)

# Sentence-embedding model used to encode both passages and questions.
retriever_model = SentenceTransformer("9wimu9/retriever-model-sinhala-v2")

# Question-answering pipeline used by the "extraction QA" path.
question_answerer = pipeline(
    "question-answering",
    model="9wimu9/xlm-roberta-large-en-si-only-finetuned-sinquad-v12",
)
|
18 |
+
|
19 |
+
def srq2seq_find_answer(query,context):
    """Generate an answer to `query` with the seq2seq model, given `context`.

    `context` is an iterable of passage strings.  Every passage is prefixed
    with a "<P>" marker and the joined passages are combined with the
    question into a single "question: ... context: ..." prompt.  The prompt
    is tokenized, passed to `model.generate`, and the first decoded sequence
    is returned as a string.

    Uses the module-level `tokenizer`, `model` and `device`.
    """
    passages_block = "<P> " + " <P> ".join(context)
    prompt = f"question: {query} context: {passages_block}"

    encoded = tokenizer(prompt, truncation=True, padding=True, return_tensors="pt")

    # Generation settings, collected in one place and passed through unchanged.
    generation_kwargs = dict(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        min_length=2,
        max_length=120,
        early_stopping=True,
        num_beams=9,
        temperature=0.9,
        do_sample=False,
        top_k=None,
        top_p=None,
        eos_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=8,
        num_return_sequences=1,
    )
    output_ids = model.generate(**generation_kwargs)

    decoded = tokenizer.batch_decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return decoded[0]
|
39 |
+
|
40 |
+
def encode_file(file_path):
    """Read a text file as one passage per line and embed every passage.

    Args:
        file_path: Path of a text file. Each non-blank line is one passage;
            surrounding whitespace (including the line ending) is removed.

    Returns:
        A ``(vectors, passages)`` tuple. ``vectors`` is the result of
        ``retriever_model.encode`` on the passage texts, and ``passages`` is
        a list of single-element lists (``[[text], ...]``) in the same order
        as the encoded texts, so ``passages[i][0]`` is the text of row ``i``.

    Raises:
        ValueError: If the file contains no non-blank lines.
    """
    # Decode explicitly instead of relying on the platform's locale encoding;
    # "utf-8-sig" also accepts files that start with a UTF-8 byte-order mark.
    with open(file_path, encoding="utf-8-sig") as file:
        stripped_lines = [line.strip() for line in file]

    # Blank lines carry no content and would only add useless index entries.
    texts = [text for text in stripped_lines if text]
    if not texts:
        raise ValueError(f"No passages found in {file_path!r}: the file is empty.")

    vectors = retriever_model.encode(texts)
    passages = [[text] for text in texts]
    return vectors, passages
|
49 |
+
|
50 |
+
def upload_file(files):
    """Build the retrieval index from the first uploaded file.

    Args:
        files: The value delivered by the upload component: a single upload
            or a list of uploads. Each upload may be a path string or an
            object exposing its path as ``.name``. Only the first upload is
            indexed.

    Returns:
        The list of uploaded file paths.

    Raises:
        ValueError: If nothing was uploaded, or the embeddings of the file
            are not a 2-D matrix.

    Side effects:
        Rebinds the module-level ``index`` (a ``faiss.IndexFlatL2`` over the
        L2-normalised passage vectors) and ``passages``.
    """
    global index
    global passages

    # Accept both a single upload and a list of uploads.
    if files is None:
        uploads = []
    elif isinstance(files, (list, tuple)):
        uploads = list(files)
    else:
        uploads = [files]
    if not uploads:
        raise ValueError("No file was uploaded.")

    file_paths = [
        upload if isinstance(upload, str) else upload.name for upload in uploads
    ]

    vectors, new_passages = encode_file(file_paths[0])

    # FAISS operates on contiguous float32 matrices.
    vectors = np.ascontiguousarray(vectors, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError("Passage embeddings must form a 2-D matrix.")

    new_index = faiss.IndexFlatL2(vectors.shape[1])
    faiss.normalize_L2(vectors)
    new_index.add(vectors)

    # Publish the index only once it is fully populated, together with the
    # passages it refers to.
    index, passages = new_index, new_passages
    return file_paths
|
60 |
+
|
61 |
+
def _retrieve_best_passage(search_text):
    """Return the text of the indexed passage closest to `search_text`.

    Raises:
        gr.Error: If the question is blank or no passages have been indexed.
    """
    if not search_text or not search_text.strip():
        raise gr.Error("Please enter a question.")

    # `index` only exists as a module global after a file has been uploaded.
    current_index = globals().get("index")
    if current_index is None or current_index.ntotal == 0:
        raise gr.Error("Please upload a text file before asking a question.")

    query = np.array([retriever_model.encode(search_text)], dtype=np.float32)
    faiss.normalize_L2(query)

    # Only the best match is used, so request a single neighbour instead of
    # ranking every indexed passage.
    _distances, neighbours = current_index.search(query, k=1)
    return passages[neighbours[0][0]][0]


def question_answer(search_text):
    """Answer `search_text` with the question-answering pipeline.

    The closest indexed passage is retrieved and passed as the context to
    `question_answerer`; the pipeline result is printed and its ``'answer'``
    field is returned.

    Raises:
        gr.Error: If the question is blank or no file has been uploaded yet.
    """
    context = _retrieve_best_passage(search_text)
    result = question_answerer(question=search_text, context=context)
    print(result)
    return result['answer']


def question_answer_generated(search_text):
    """Answer `search_text` with the seq2seq model.

    The closest indexed passage is retrieved and handed, as a one-element
    list, to `srq2seq_find_answer`, whose return value is returned.

    Raises:
        gr.Error: If the question is blank or no file has been uploaded yet.
    """
    context = _retrieve_best_passage(search_text)
    return srq2seq_find_answer(search_text, [context])
|
83 |
+
|
84 |
+
# UI: one row for the file upload, then one row per question-answering mode.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            file_output = gr.File()
            # NOTE(review): the documented `file_count` values appear to be
            # "single", "multiple" and "directory"; confirm "1" is intended.
            upload_button = gr.UploadButton(
                "Click to Upload a File",
                file_types=["txt"],
                file_count="1",
            )
            upload_button.upload(upload_file, upload_button, file_output)

    with gr.Row():
        with gr.Column():
            extractive_question = gr.Textbox(label="question")
            extractive_answer = gr.Textbox(label="answer")
            extractive_button = gr.Button("get Answer - extraction QA")
            extractive_button.click(
                fn=question_answer,
                inputs=extractive_question,
                outputs=extractive_answer,
                api_name="greet",
            )

    with gr.Row():
        with gr.Column():
            generated_question = gr.Textbox(label="question")
            generated_answer = gr.Textbox(label="answer")
            generated_button = gr.Button("get Answer - Generated QA")
            # NOTE(review): this handler requests the same api_name as the
            # extraction handler; confirm how the duplicate is resolved.
            generated_button.click(
                fn=question_answer_generated,
                inputs=generated_question,
                outputs=generated_answer,
                api_name="greet",
            )

demo.launch(debug=True)
|
104 |
|
|
|
|
|
|
|
105 |
|
|