CasTalk V1.0

CasTalk, our avatar, has a distinct persona and communicates in a human-like way. They can learn new skills, such as English tutoring and storytelling. Beyond that, they can build relationships with you.


Model Details

  • Developed by: ToriLab (CasTalk)
  • Model type: Joint Attention and Mamba (Jamba)
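"Jamba" refers to a hybrid design that interleaves Transformer attention layers with Mamba state-space layers, and mixes dense MLP and mixture-of-experts (MoE) feed-forward layers. The sketch below derives such a layer schedule from a period and an offset. The values mirror the defaults we understand AI21's original Jamba configuration (`JambaConfig` in transformers) to use; whether CasTalk keeps them is an assumption, and `HybridSchedule` is a hypothetical helper for illustration.

```python
# Sketch of a Jamba-style layer schedule. The defaults below are assumed to
# match the original Jamba config; CasTalk's actual values are not documented.
from dataclasses import dataclass


@dataclass
class HybridSchedule:
    num_hidden_layers: int = 32
    attn_layer_period: int = 8    # one attention layer per block of 8 layers
    attn_layer_offset: int = 4    # position of that attention layer in the block
    expert_layer_period: int = 2  # an MoE feed-forward every second layer
    expert_layer_offset: int = 1

    def layer_types(self) -> list[str]:
        """Label every layer as '<mixer>+<feed-forward>', e.g. 'mamba+moe'."""
        labels = []
        for i in range(self.num_hidden_layers):
            is_attn = i % self.attn_layer_period == self.attn_layer_offset
            is_moe = i % self.expert_layer_period == self.expert_layer_offset
            mixer = "attention" if is_attn else "mamba"
            ffn = "moe" if is_moe else "mlp"
            labels.append(f"{mixer}+{ffn}")
        return labels


schedule = HybridSchedule()
types = schedule.layer_types()
print(types[:8])
print(sum(t.startswith("attention") for t in types), "attention layers of", len(types))
```

With these values, most layers are Mamba layers and attention appears once per block of eight, which is what keeps the memory cost of long contexts lower than in a pure Transformer.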

Usage

Prerequisites

pip install --upgrade pip
pip install "transformers>=4.39.0"
pip install mamba-ssm "causal-conv1d>=1.2.0"
pip install --upgrade castalk-llm transformers accelerate peft trl datasets
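To confirm that the installed packages meet the minimum versions listed above (transformers 4.39.0, causal-conv1d 1.2.0), a small standard-library check like the following can help. `MINIMUMS`, `parse` and `meets_minimum` are illustrative helpers, and the comparison is deliberately simplified: pre-release and local version tags are ignored. For full PEP 440 semantics, `packaging.version.Version` is the more robust choice.

```python
# Hedged sketch: check installed versions against the minimums listed above.
# Simplified numeric comparison; pre-release/post-release tags are ignored.
import re
from importlib.metadata import PackageNotFoundError, version

MINIMUMS = {"transformers": "4.39.0", "causal-conv1d": "1.2.0"}


def parse(v: str) -> tuple[int, ...]:
    """Turn '4.39.0' or '1.2.0.post1' into its leading integers, e.g. (4, 39, 0)."""
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break  # stop at the first non-numeric segment such as 'post1'
        parts.append(int(m.group()))
    return tuple(parts)


def meets_minimum(installed: str, required: str) -> bool:
    a, b = parse(installed), parse(required)
    width = max(len(a), len(b))
    # Pad with zeros so that '4.39' compares equal to '4.39.0'.
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


if __name__ == "__main__":
    for pkg, minimum in MINIMUMS.items():
        try:
            found = version(pkg)
        except PackageNotFoundError:
            print(f"{pkg}: not installed (need >= {minimum})")
            continue
        status = "ok" if meets_minimum(found, minimum) else "too old"
        print(f"{pkg}: {found} ({status}, need >= {minimum})")
```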

Run the model

import torch
from castalk import AvatarPipeline

# Load the base model with the AvatarPipeline class; adapters are attached next.
pipe = AvatarPipeline.from_pretrained(
    "torilab/castalk-1.0-base",
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# Load abilities and skills (adapters) for the model; each needs a unique adapter_name
pipe.load_adapter("torilab/eng_girlfriend", adapter_name="eng_girlfriend")
pipe.load_adapter("torilab/eng_lover", adapter_name="eng_lover")
pipe.load_adapter("torilab/eng_assistance", adapter_name="eng_assistance")
pipe.load_adapter("torilab/eng_psychology", adapter_name="eng_psychology")


# Fixed seed so that sampling is reproducible
generator = torch.manual_seed(0)

# POC: assess the relationship between the user and the avatar, then set a weight for each adapter
# eng_girlfriend: 0.7
# eng_lover: 0.9
# eng_psychology: 0.1, or disabled depending on the user prompt

pipe.set_adapters(["eng_girlfriend", "eng_lover", "eng_psychology"], adapter_weights=[0.7, 0.9, 0.1])

prompt = "Hi how are you today?"
response = pipe.generate(prompt, generator=generator)
print(response)
# Output: "Great my love, how are you doing today?"
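The POC comment above describes choosing adapter weights from the relationship between the user and the avatar, with the psychology adapter switched on only when the prompt calls for it. One way to express that policy is sketched below. Everything in it is hypothetical: the stage names, the weights, the keyword list and the `adapter_weights` helper are illustrations, not part of the CasTalk package. It only produces arguments in the same form as the `pipe.set_adapters` call shown earlier.

```python
# Hypothetical POC policy: pick adapter weights from the relationship stage
# and the user prompt. Stages, weights and keywords are illustrative only.
ADAPTERS = ["eng_girlfriend", "eng_lover", "eng_psychology"]

STAGE_WEIGHTS = {
    "new": {"eng_girlfriend": 0.3, "eng_lover": 0.0},
    "close": {"eng_girlfriend": 0.7, "eng_lover": 0.5},
    "partner": {"eng_girlfriend": 0.7, "eng_lover": 0.9},
}

SUPPORT_KEYWORDS = ("sad", "anxious", "stressed", "lonely")


def adapter_weights(stage: str, prompt: str) -> tuple[list[str], list[float]]:
    """Return (adapter_names, weights) shaped like the set_adapters arguments."""
    weights = dict(STAGE_WEIGHTS[stage])
    text = prompt.lower()
    # Enable the psychology adapter only when the prompt hints at distress.
    weights["eng_psychology"] = 0.1 if any(k in text for k in SUPPORT_KEYWORDS) else 0.0
    # Leave out adapters whose weight is zero instead of passing 0.0.
    names = [name for name in ADAPTERS if weights[name] > 0]
    return names, [weights[name] for name in names]


names, weights = adapter_weights("partner", "Hi how are you today?")
print(names, weights)
# pipe.set_adapters(names, adapter_weights=weights)
```

Keyword matching is the crudest possible signal; a production version would more likely use a classifier, but the shape of the output stays the same.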

Train new skills
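The training script reads its text from a `quote` column (`dataset_text_field="quote"`). To teach a skill from your own conversations instead of `torilab/eng_new_skills`, one option is to write them as JSON Lines with that field and load the file with the `datasets` JSON loader. The dialogue format (`User:` / `CasTalk:` turns), the example data and the file name `my_new_skill.jsonl` are assumptions for illustration.

```python
# Sketch: write training examples as JSON Lines with a "quote" text field,
# matching dataset_text_field="quote" in the training script.
import json
from pathlib import Path

# Hypothetical example exchanges for storytelling and English tutoring.
dialogues = [
    ("Tell me a short story about the sea.", "Once upon a time, a small boat set out at dawn."),
    ("Can you correct my sentence: 'She go to school'?", "Sure! It should be 'She goes to school.'"),
]


def to_record(user: str, avatar: str) -> dict:
    # Flatten one exchange into a single training string (format is an assumption).
    return {"quote": f"User: {user}\nCasTalk: {avatar}"}


out = Path("my_new_skill.jsonl")
with out.open("w", encoding="utf-8") as f:
    for user, avatar in dialogues:
        f.write(json.dumps(to_record(user, avatar), ensure_ascii=False) + "\n")

print(out.read_text(encoding="utf-8").count("\n"), "records written")
# In the training script, swap the Hub dataset for the local file:
# dataset = load_dataset("json", data_files="my_new_skill.jsonl", split="train")
```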


from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

# Load the tokenizer and our base model from the Hub
tokenizer = AutoTokenizer.from_pretrained("torilab/castalk-1.0-base")
model = AutoModelForCausalLM.from_pretrained("torilab/castalk-1.0-base", trust_remote_code=True, device_map='auto')

dataset = load_dataset("torilab/eng_new_skills", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
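# LoRA on the token embeddings and the Mamba projections (x_proj, in_proj, out_proj)
lora_config = LoraConfig(
    r=8,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
# Written for older trl releases; newer ones expect dataset_text_field in SFTConfig
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",  # column that holds the training text
)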

trainer.train()
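Each skill trained this way is a LoRA adapter: the base weight stays frozen and training learns a low-rank update, which is why several skills can later be loaded side by side and blended with per-adapter weights. The NumPy sketch below illustrates that idea for a single layer. It is conceptual only: the layer sizes are arbitrary, `alpha = 8` assumes peft's default `lora_alpha` (the config above does not set it), and the linear blend is one common way to combine adapters (peft offers a similar operation through `add_weighted_adapter`); how `AvatarPipeline.set_adapters` combines adapters internally is not documented here.

```python
# Conceptual NumPy sketch of LoRA: a frozen weight W0 plus a scaled low-rank
# update, and a weighted blend of several adapters. Not CasTalk or peft internals.
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in = 16, 32
r = 8       # rank, as in the LoraConfig above
alpha = 8   # assumption: peft's default lora_alpha, since the config does not set it

W0 = rng.normal(size=(d_out, d_in))  # frozen base weight, never updated


def lora_delta(rng: np.random.Generator) -> np.ndarray:
    """Return the scaled low-rank update (alpha / r) * B @ A for one adapter."""
    # In real LoRA training B starts at zero; random values stand in for trained weights.
    A = rng.normal(size=(r, d_in))
    B = rng.normal(size=(d_out, r))
    return (alpha / r) * (B @ A)


deltas = {name: lora_delta(rng) for name in ("eng_girlfriend", "eng_lover")}
weights = {"eng_girlfriend": 0.7, "eng_lover": 0.9}

# Linear blend of adapters: W = W0 + sum_i w_i * delta_i
W = W0 + sum(weights[n] * deltas[n] for n in deltas)

# Trainable parameters for this layer: LoRA factors vs. full fine-tuning.
print("LoRA params:", r * (d_in + d_out), "full params:", d_out * d_in)
print("rank of one update:", np.linalg.matrix_rank(deltas["eng_lover"]))
```

The saving grows with layer size: for a square layer of width d, LoRA trains 2rd parameters instead of d squared.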

About ToriLab

ToriLab builds reliable, practical, and scalable AI solutions for the CasTalk app.
