VSDatta commited on
Commit
15f0a29
·
verified ·
1 Parent(s): d7064e7

Update src/ai_classifier.py

Browse files
Files changed (1) hide show
  1. src/ai_classifier.py +10 -8
src/ai_classifier.py CHANGED
@@ -1,13 +1,16 @@
 
1
  from transformers import pipeline
2
 
3
- # Load the zero-shot-classification pipeline and specify a cache directory.
4
- classifier = pipeline(
5
- "zero-shot-classification",
6
- model="joeddav/xlm-roberta-large-xnli",
7
- cache_dir="./hf_cache"
8
- )
 
 
 
9
 
10
- # Mapping of English category labels to Telugu.
11
  CATEGORIES = {
12
  "Family": "కుటుంబం",
13
  "Friendship": "స్నేహం",
@@ -32,7 +35,6 @@ CATEGORIES = {
32
  }
33
 
34
  def classify_proverb(text):
35
- """Classifies the proverb and returns the Telugu label."""
36
  result = classifier(text, list(CATEGORIES.keys()))
37
  top_label = result["labels"][0]
38
  return CATEGORIES[top_label]
 
1
+ import streamlit as st
2
  from transformers import pipeline
3
 
4
+ @st.cache_resource
5
+ def load_classifier():
6
+ return pipeline(
7
+ "zero-shot-classification",
8
+ model="joeddav/xlm-roberta-large-xnli", # or use a smaller model
9
+ cache_dir="./hf_cache"
10
+ )
11
+
12
+ classifier = load_classifier()
13
 
 
14
  CATEGORIES = {
15
  "Family": "కుటుంబం",
16
  "Friendship": "స్నేహం",
 
35
  }
36
 
37
  def classify_proverb(text):
 
38
  result = classifier(text, list(CATEGORIES.keys()))
39
  top_label = result["labels"][0]
40
  return CATEGORIES[top_label]