9wimu9 committed on
Commit
a182034
·
1 Parent(s): ee37e1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -6
app.py CHANGED
@@ -1,11 +1,105 @@
 
 
1
  import gradio as gr
2
- import constants
 
 
3
  from transformers import pipeline
 
4
 
5
# Extractive question-answering pipeline; the checkpoint name is read from the
# project's ``constants`` module.
question_answerer = pipeline(
    task="question-answering",
    model=constants.MODEL_NAME,
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
def question_answer(context, question):
    """Answer ``question`` from ``context`` with the extractive QA pipeline.

    Returns:
        A ``(answer, score)`` tuple taken from the ``'answer'`` and ``'score'``
        entries of the pipeline result.
    """
    prediction = question_answerer(context=context, question=question)
    answer = prediction['answer']
    score = prediction['score']
    return answer, score
10
 
11
# Two free-text inputs are passed to question_answer; its two return values are
# shown in the labelled text boxes below.
answer_box = gr.components.Textbox(label="Answer")
probability_box = gr.components.Textbox(label="Probability")
gr.Interface(
    fn=question_answer,
    inputs=["text", "text"],
    outputs=[answer_box, probability_box],
).launch()
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
3
  import gradio as gr
4
+ import faiss
5
+ import numpy as np
6
+ import pandas as pd
7
  from transformers import pipeline
8
+ from sentence_transformers import SentenceTransformer
9
 
10
# Checkpoint used for the generative (seq2seq) reader and its tokenizer.
model_name = "9wimu9/mt5-xl-sin-odqa-1"

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained(model_name)
# ``.to(device)`` returns the module itself, so ``model`` and
# ``question_answerer_seq_2seq`` are two names for the same object.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
question_answerer_seq_2seq = model

# Sentence encoder used to embed both the passages and the questions.
retriever_model = SentenceTransformer('9wimu9/retriever-model-sinhala-v2')

# Extractive question-answering pipeline.
question_answerer = pipeline(
    "question-answering",
    model='9wimu9/xlm-roberta-large-en-si-only-finetuned-sinquad-v12',
)
18
+
19
def srq2seq_find_answer(query, context):
    """Generate an answer to ``query`` with the seq2seq model.

    Args:
        query: The question text.
        context: An iterable of passage strings. Each passage is introduced by
            a ``<P>`` marker in the prompt.

    Returns:
        The decoded text of the single sequence returned by beam search.
    """
    passages_block = "<P> " + " <P> ".join(context)
    prompt = "question: {} context: {}".format(query, passages_block)

    encoded = tokenizer(prompt, truncation=True, padding=True, return_tensors="pt")

    # Generation settings are collected in one place and forwarded unchanged.
    generation_kwargs = dict(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        min_length=2,
        max_length=120,
        early_stopping=True,
        num_beams=9,
        temperature=0.9,
        do_sample=False,
        top_k=None,
        top_p=None,
        eos_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=8,
        num_return_sequences=1,
    )
    output_ids = model.generate(**generation_kwargs)

    decoded = tokenizer.batch_decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return decoded[0]
39
+
40
def encode_file(file_path):
    """Read a text file of passages (one per line) and embed them.

    Args:
        file_path: Path to a text file; every non-blank line is one passage.

    Returns:
        A ``(vectors, passages)`` tuple. ``vectors`` is the result of
        ``retriever_model.encode`` on the passage texts. ``passages`` is a list
        of single-element lists ``[[text], ...]`` in file order, so
        ``passages[i][0]`` is the text that produced row ``i`` of ``vectors``.
    """
    # Decode explicitly instead of relying on the platform default encoding:
    # the retriever is a Sinhala model, so passages are expected to be
    # non-ASCII. "utf-8-sig" also drops a leading byte-order mark if present.
    with open(file_path, encoding="utf-8-sig") as file:
        stripped_lines = [line.strip() for line in file]

    # Blank lines carry no content; embedding them would only add junk rows
    # that can be returned as the "best" passage.
    texts = [line for line in stripped_lines if line]

    # Callers index the result as passages[i][0], so keep the nested shape.
    passages = [[text] for text in texts]
    vectors = retriever_model.encode(texts)
    return vectors, passages
49
+
50
def upload_file(files):
    """Build the retrieval index from the first uploaded file.

    Args:
        files: The upload payload: either one uploaded file or a list of them.
            Each item may be a path string or an object whose ``name``
            attribute holds the path.

    Returns:
        The list of uploaded file paths.

    Raises:
        gr.Error: If nothing was uploaded or the first file has no passages.

    Side effects:
        Rebinds the module-level ``index`` and ``passages`` used by the
        question-answering handlers.
    """
    global index
    global passages

    # Accept a single file as well as a list so the handler does not depend on
    # how the upload component is configured.
    if files is None:
        files = []
    elif not isinstance(files, (list, tuple)):
        files = [files]
    file_paths = [f if isinstance(f, str) else f.name for f in files]
    if not file_paths:
        raise gr.Error("No file was uploaded.")

    # Only the first file is indexed.
    vectors, new_passages = encode_file(file_paths[0])
    # FAISS operates on C-contiguous float32 matrices.
    vectors = np.ascontiguousarray(vectors, dtype=np.float32)
    if vectors.ndim != 2 or vectors.shape[0] == 0:
        raise gr.Error("The uploaded file does not contain any passages.")

    new_index = faiss.IndexFlatL2(vectors.shape[1])
    # Rows are L2-normalised in place before being added to the index.
    faiss.normalize_L2(vectors)
    new_index.add(vectors)

    # Publish both globals together, only after the index was built, so a
    # failed upload never leaves ``index`` and ``passages`` out of sync.
    index, passages = new_index, new_passages
    return file_paths
60
+
61
def _retrieve_best_passage(search_text):
    """Return the text of the indexed passage closest to ``search_text``.

    Raises:
        gr.Error: If the question is empty or no file has been indexed yet.
    """
    if not search_text or not search_text.strip():
        raise gr.Error("Please enter a question.")

    # ``index`` and ``passages`` only exist after a successful upload; look
    # them up defensively so the user sees a clear message, not a NameError.
    module_state = globals()
    faiss_index = module_state.get("index")
    indexed_passages = module_state.get("passages")
    if faiss_index is None or not indexed_passages or faiss_index.ntotal == 0:
        raise gr.Error("Please upload a text file before asking a question.")

    # FAISS expects a 2-D float32 matrix of queries: one row for one question.
    query_vector = np.array([retriever_model.encode(search_text)], dtype=np.float32)
    faiss.normalize_L2(query_vector)

    # Only the nearest passage is used, so there is no need to rank them all.
    _, neighbours = faiss_index.search(query_vector, k=1)
    return indexed_passages[neighbours[0][0]][0]


def question_answer(search_text):
    """Answer ``search_text`` with the extractive QA pipeline.

    The best-matching uploaded passage is used as the context.

    Returns:
        The ``'answer'`` entry of the pipeline result.

    Raises:
        gr.Error: If the question is empty or no file has been indexed yet.
    """
    context = _retrieve_best_passage(search_text)
    result = question_answerer(question=search_text, context=context)
    print(result)
    return result['answer']


def question_answer_generated(search_text):
    """Answer ``search_text`` with the generative (seq2seq) reader.

    The best-matching uploaded passage is used as the only context document.

    Returns:
        The text returned by ``srq2seq_find_answer``.

    Raises:
        gr.Error: If the question is empty or no file has been indexed yet.
    """
    context = _retrieve_best_passage(search_text)
    return srq2seq_find_answer(search_text, [context])
83
+
84
with gr.Blocks() as demo:
    # Row 1: upload the passage file that will be indexed for retrieval.
    with gr.Row():
        with gr.Column():
            file_output = gr.File()
            # File-type filters are either a category such as "text" or a
            # dotted extension; a bare "txt" is neither, so use ".txt".
            # NOTE(review): "1" is not one of the documented file_count values
            # ("single", "multiple", "directory"). It is left unchanged because
            # upload_file expects the payload shape this setting produces.
            upload_button = gr.UploadButton(
                "Click to Upload a File",
                file_types=[".txt"],
                file_count="1",
            )
            upload_button.upload(upload_file, upload_button, file_output)

    # Row 2: extractive QA over the best-matching passage.
    with gr.Row():
        with gr.Column():
            extractive_question = gr.Textbox(label="question")
            extractive_answer = gr.Textbox(label="answer")
            extractive_button = gr.Button("get Answer - extraction QA")
            extractive_button.click(
                fn=question_answer,
                inputs=extractive_question,
                outputs=extractive_answer,
                api_name="greet",
            )

    # Row 3: generative QA over the best-matching passage.
    with gr.Row():
        with gr.Column():
            generated_question = gr.Textbox(label="question")
            generated_answer = gr.Textbox(label="answer")
            generated_button = gr.Button("get Answer - Generated QA")
            # NOTE(review): this handler requests the same api_name as the one
            # above. Confirm how the installed Gradio version disambiguates
            # duplicates before renaming, since API clients may rely on the
            # current endpoint names.
            generated_button.click(
                fn=question_answer_generated,
                inputs=generated_question,
                outputs=generated_answer,
                api_name="greet",
            )

demo.launch(debug=True)
104
 
 
 
 
105