Riddhi Bhagwat committed on
Commit
4b9fc14
·
1 Parent(s): afc8109

auto detection of language input

Browse files
Files changed (3) hide show
  1. app/.DS_Store +0 -0
  2. app/app.py +24 -5
  3. app/lang_model_router.py +35 -0
app/.DS_Store ADDED
Binary file (6.15 kB). View file
 
app/app.py CHANGED
@@ -25,6 +25,7 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
25
  import threading
26
  from collections import defaultdict
27
  from datasets import load_dataset
 
28
 
29
 
30
  BASE_MODEL = os.getenv("MODEL", "google/gemma-3-12b-pt")
@@ -396,12 +397,29 @@ def respond(
396
  language: str,
397
  temperature: Optional[float] = None,
398
  seed: Optional[int] = None,
 
399
  ) -> list:
400
- """Respond to the user message with a system message
401
-
402
- Return the history with the new message"""
403
- messages = format_history_as_messages(history)
404
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  if ZERO_GPU:
406
  content = call_pipeline(messages)
407
  else:
@@ -416,6 +434,7 @@ def respond(
416
  )
417
  content = response.choices[0].message.content
418
 
 
419
  message = gr.ChatMessage(role="assistant", content=content)
420
  history.append(message)
421
  return history
 
25
  import threading
26
  from collections import defaultdict
27
  from datasets import load_dataset
28
+ from lang_model_router import detect_language_code, get_language_name_and_model
29
 
30
 
31
  BASE_MODEL = os.getenv("MODEL", "google/gemma-3-12b-pt")
 
397
  language: str,
398
  temperature: Optional[float] = None,
399
  seed: Optional[int] = None,
400
+ auto_detect: bool = True,
401
  ) -> list:
402
+ """Respond to the user message with system prompt in auto-detected or selected language."""
403
+ # Get last user message
404
+ user_input = ""
405
+ for msg in reversed(history):
406
+ if msg["role"] == "user":
407
+ user_input = msg["content"]
408
+ break
409
+
410
+ # Determine language
411
+ if auto_detect:
412
+ lang_code = detect_language_code(user_input)
413
+ language, _ = get_language_name_and_model(lang_code)
414
+
415
+ # Load system prompt
416
+ system_prompt = LANGUAGES.get(language, LANGUAGES["English"])
417
+
418
+ # Format message list with system prompt prepended
419
+ messages = [{"role": "system", "content": system_prompt}]
420
+ messages.extend(format_history_as_messages(history))
421
+
422
+ # Generate response
423
  if ZERO_GPU:
424
  content = call_pipeline(messages)
425
  else:
 
434
  )
435
  content = response.choices[0].message.content
436
 
437
+ # Add response to history
438
  message = gr.ChatMessage(role="assistant", content=content)
439
  history.append(message)
440
  return history
app/lang_model_router.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from langdetect import detect, DetectorFactory
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import os
5
+ DetectorFactory.seed = 0
6
# Maps a language code (as produced by ``detect``) to the language's display
# name and the Hugging Face model id associated with it.
# NOTE(review): some entries (e.g. "de", "hi", "mr") look like BERT-style
# encoder checkpoints rather than causal LMs -- confirm they are suitable
# before loading them for text generation.
LANGUAGE_MAP = {
    "en": {"name": "English", "model": "openai-community/gpt2"},
    "fr": {"name": "French", "model": "dbddv01/gpt2-french-small"},
    "es": {"name": "Spanish", "model": "datificate/gpt2-small-spanish"},
    "de": {"name": "German", "model": "deepset/gbert-base"},
    "hi": {"name": "Hindi", "model": "ai4bharat/indic-bert"},
    "mr": {"name": "Marathi", "model": "ai4bharat/indic-bert"},
    "ja": {"name": "Japanese", "model": "rinna/japanese-gpt2-medium"},
    "zh-cn": {"name": "Chinese", "model": "uer/gpt2-chinese-cluecorpusswwm"},
    "ru": {"name": "Russian", "model": "sberbank-ai/rugpt3small_based_on_gpt2"},
    "pt": {"name": "Portuguese", "model": "pierreguillou/gpt2-small-portuguese"},
    "it": {"name": "Italian", "model": "dbddv01/gpt2-italian"},
    "nl": {"name": "Dutch", "model": "GroNLP/gpt2-small-dutch"},
}

# Code used whenever detection fails or a language is not in LANGUAGE_MAP.
_FALLBACK_CODE = "en"


def detect_language_code(text: str) -> str:
    """Return the detected language code for *text*.

    Falls back to ``"en"`` when *text* is empty or whitespace-only, or when
    the detector raises for any reason.
    """
    # Nothing to analyse: skip the detector instead of relying on it to raise.
    if not text or not text.strip():
        return _FALLBACK_CODE
    try:
        return detect(text)
    except Exception:
        # Deliberate best-effort: a detection failure must never break the
        # chat flow, so any detector error degrades to the fallback language.
        return _FALLBACK_CODE


def get_language_name_and_model(lang_code: str) -> tuple[str, str]:
    """Return ``(language_name, model_id)`` for *lang_code*.

    Unknown codes resolve to the English entry. The result is a real tuple
    (not the underlying dict) so callers can unpack it as ``name, model``.
    """
    entry = LANGUAGE_MAP.get(lang_code, LANGUAGE_MAP[_FALLBACK_CODE])
    return entry["name"], entry["model"]


def get_model_by_name(language_name: str) -> str:
    """Return the model id for a language display name (case-insensitive).

    Falls back to the English model when no entry matches *language_name*.
    """
    wanted = language_name.lower()
    for entry in LANGUAGE_MAP.values():
        if entry["name"].lower() == wanted:
            return entry["model"]
    return LANGUAGE_MAP[_FALLBACK_CODE]["model"]