# Author: Gabriel Okiri
# Initial commit (4bb9d41)
# (file viewer residue: raw / history blame / 1.25 kB)
# 1. app/model/config.py
from dataclasses import dataclass, field
from typing import List, Optional

import torch
@dataclass
class ModelConfig:
    """Training/runtime configuration for the Nigerian-language causal LM.

    All fields have sensible defaults; override per-run via the constructor.
    """

    model_name: str = "gpt2"          # HuggingFace model identifier
    max_length: int = 128             # tokenizer truncation/padding length
    batch_size: int = 16
    learning_rate: float = 2e-5
    num_train_epochs: int = 3
    # default_factory gives every instance its own list; the original tuple
    # default contradicted the List[str] annotation and was shared read-only.
    languages: List[str] = field(default_factory=lambda: ["YORUBA", "IGBO", "HAUSA"])
    # NOTE: evaluated once at class-definition time; pass device explicitly
    # to a constructor call if the runtime environment changes.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    output_dir: str = "outputs"
# app/model/model.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from .config import ModelConfig
class NigerianLanguageModel:
    """Causal LM wrapper that registers one [LANG] marker token per
    configured language before moving the model to the target device."""

    def __init__(self, config: ModelConfig):
        self.config = config
        self.setup_model()

    def setup_model(self):
        """Load tokenizer and model from the hub, extend the vocabulary
        with language markers, then place the model on config.device."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.config.model_name)
        self._setup_special_tokens()
        self.model.to(self.config.device)

    def _setup_special_tokens(self):
        """Register a [LANG] special token for each language and resize the
        embedding matrix to cover the enlarged vocabulary."""
        markers = [f"[{lang}]" for lang in self.config.languages]
        self.tokenizer.add_special_tokens({"additional_special_tokens": markers})
        self.model.resize_token_embeddings(len(self.tokenizer))