TD-Llama-General

TL;DR

TD (ToolDial)-Llama-General is a version of the LLaMA3.1-8B-Instruct model that has been instruction-tuned on the ToolDial dataset in the ReAct format (Thought-Action-Observation). Unlike the original ToolDial paper, dialogue states are removed from the reasoning trace, and the API call payload is concatenated within the reasoning step for Call actions.

This model can be used as,

  • Baseline model for tool calling (function calling), and multi-turn dialogue agent task.
  • Warm up starting model for SFT, RL for Tool Agent training.

[Caution]

  • To use the model for inference, it is important to carefully follow the Input Format Example and use the instruction provided under Instruction for TD-Llama-General.
  • Refer the Example Code below.

[Model Summary]

  • Trained with Q-lora quantization. Our quantization configuration is in the Example Code.
  • Trained for 1 epoch with Adam-8bit optimizer with learning rate 0.00001 and beta 0.9 to 0.995.
  • We only provide trained LoRa adapter. Follow the Example Code and you can use it easily.

Input Format Example

User: I need to analyze the intraday performance of a cryptocurrency pair and compare it with the intraday data of an FX currency pair I have already retrieved. Can you provide the detailed intraday time series data for the cryptocurrency pair?
System:
- Thought: The user seems to have intent. I will call the retriever.
- Action:retriever_call
- Observation: {'retriever_status': {'retriever_call': 'true', 'retrieved_api': {'Finance|quotient|/crypto/intraday': 0.71542144, 'Finance|eod_historical_data|Intraday Historical Data API': 0.67002374, 'Finance|quotient|/forex/daily': 0.6112553, 'Finance|quotient|/forex/intraday': 0.5370315, 'Finance|real_time_finance_data|Currency Time Series': 0.50981194}}, 'api_documentation': {'api_name': 'Finance|quotient|/crypto/intraday', 'api_description': 'Return intraday time series (Date, Open, High, Low, Close, Volume) given the input parameters.', 'required_parameters': [{'input_parameter_name': 'interval', 'description': 'the time interval for retrieving intraday market data, specifying the duration over which the intraday time series data (Date, Open, High, Low, Close, Volume) is requested.'}, {'input_parameter_name': 'to', 'description': "The end date for the intraday time series query is the specific date that marks the conclusion of the requested intraday data period for various assets, including cryptocurrencies, as defined by the API's functionality."}, {'input_parameter_name': 'symbol', 'description': "The symbol of the cryptocurrency pair for which intraday market data is requested, representing the specific digital assets being analyzed in the context of the API's functionality."}, {'input_parameter_name': 'from', 'description': 'The start date for the intraday time series is the specific date from which the intraday market data, including open, high, low, close, and volume, will be retrieved for various assets such as cryptocurrencies.'}], 'optional_parameters': []}}
- Thought: Based on results of the retriever, I need to call Finance|quotient|/crypto/intraday. To call it, I need to ask 'interval', 'to', 'symbol', and 'from' to the user.
- Action:request
- Observation: {}
- Message: Please tell me the interval, to, symbol, and from for the cryptocurrency pair.
User: Certainly! The interval is 15min, the to date is 2023-10-01 15:00, the symbol is ETH-USD, and the from date is 2023-10-01 09:00.
System:
- Thought: Based on the user's response, now I can call the API Finance|quotient|/crypto/intraday.
- Action:call{'api_name': 'Finance|quotient|/crypto/intraday', 'required_parameters': {'interval': '15min', 'to': '2023-10-01 15:00', 'symbol': 'ETH-USD', 'from': '2023-10-01 09:00'}, 'optional_parameters': {}}
- Observation: {'call_result': {'api_name': 'Finance|quotient|/crypto/intraday', 'result': {'symbol': 'ETH-USD', 'data': [{'date': '2023-10-01 09:00', 'open': 1800.0, 'high': 1820.0, 'low': 1790.0, 'close': 1815.0, 'volume': 1500}, {'date': '2023-10-01 09:15', 'open': 1815.0, 'high': 1830.0, 'low': 1805.0, 'close': 1820.0, 'volume': 1200}, {'date': '2023-10-01 09:30', 'open': 1820.0, 'high': 1840.0, 'low': 1810.0, 'close': 1835.0, 'volume': 1300}, {'date': '2023-10-01 09:45', 'open': 1835.0, 'high': 1850.0, 'low': 1825.0, 'close': 1845.0, 'volume': 1400}, {'date': '2023-10-01 10:00', 'open': 1845.0, 'high': 1860.0, 'low': 1835.0, 'close': 1855.0, 'volume': 1600}, {'date': '2023-10-01 10:15', 'open': 1855.0, 'high': 1870.0, 'low': 1845.0, 'close': 1865.0, 'volume': 1500}, {'date': '2023-10-01 10:30', 'open': 1865.0, 'high': 1880.0, 'low': 1855.0, 'close': 1875.0, 'volume': 1200}, {'date': '2023-10-01 10:45', 'open': 1875.0, 'high': 1890.0, 'low': 1865.0, 'close': 1885.0, 'volume': 1300}, {'date': '2023-10-01 11:00', 'open': 1885.0, 'high': 1900.0, 'low': 1875.0, 'close': 1895.0, 'volume': 1400}, {'date': '2023-10-01 11:15', 'open': 1895.0, 'high': 1910.0, 'low': 1885.0, 'close': 1905.0, 'volume': 1600}, {'date': '2023-10-01 11:30', 'open': 1905.0, 'high': 1920.0, 'low': 1895.0, 'close': 1915.0, 'volume': 1500}, {'date': '2023-10-01 11:45', 'open': 1915.0, 'high': 1930.0, 'low': 1905.0, 'close': 1925.0, 'volume': 1200}, {'date': '2023-10-01 12:00', 'open': 1925.0, 'high': 1940.0, 'low': 1915.0, 'close': 1935.0, 'volume': 1300}, {'date': '2023-10-01 12:15', 'open': 1935.0, 'high': 1950.0, 'low': 1925.0, 'close': 1945.0, 'volume': 1400}, {'date': '2023-10-01 12:30', 'open': 1945.0, 'high': 1960.0, 'low': 1935.0, 'close': 1955.0, 'volume': 1600}, {'date': '2023-10-01 12:45', 'open': 1955.0, 'high': 1970.0, 'low': 1945.0, 'close': 1965.0, 'volume': 1500}, {'date': '2023-10-01 13:00', 'open': 1965.0, 'high': 1980.0, 'low': 1955.0, 'close': 1975.0, 'volume': 1200}, {'date': '2023-10-01 13:15', 'open': 1975.0, 'high': 1990.0, 'low': 1965.0, 'close': 1985.0, 'volume': 1300}, {'date': '2023-10-01 13:30', 'open': 1985.0, 'high': 2000.0, 'low': 1975.0, 'close': 1995.0, 'volume': 1400}, {'date': '2023-10-01 13:45', 'open': 1995.0, 'high': 2010.0, 'low': 1985.0, 'close': 2005.0, 'volume': 1600}, {'date': '2023-10-01 14:00', 'open': 2005.0, 'high': 2020.0, 'low': 1995.0, 'close': 2015.0, 'volume': 1500}, {'date': '2023-10-01 14:15', 'open': 2015.0, 'high': 2030.0, 'low': 2005.0, 'close': 2025.0, 'volume': 1200}, {'date': '2023-10-01 14:30', 'open': 2025.0, 'high': 2040.0, 'low': 2015.0, 'close': 2035.0, 'volume': 1300}, {'date': '2023-10-01 14:45', 'open': 2035.0, 'high': 2050.0, 'low': 2025.0, 'close': 2045.0, 'volume': 1400}], 'interval': 15, 'from': '2023-10-01 09:00', 'to': '2023-10-01 15:00'}}}
- Thought: Based on the results of Finance|quotient|/crypto/intraday, I can respond to the user.
- Action:response
- Observation: {}
- Message: Here is the detailed intraday time series data for the cryptocurrency pair ETH-USD from 2023-10-01 09:00 to 2023-10-01 15:00 with a 15-minute interval. The data includes open, high, low, close prices, and volume for each interval.
User: Thank you for the detailed information. Goodbye!
- Thought: The user has said goodbye. I will respond politely.
- Action:system_bye
- Observation: {}
- Message: You're welcome! Have a great day! Goodbye!

Instruction for TD-Llama-General

Instruction:
You will be provided with part of a dialogue between the user and the system. In this dialogue, the user is requesting information from the system, and the system will make an API call to retrieve the necessary information.
Your task is to generate an appropriate reasoning, and action based on the dialogue history and the user's most recent utterance. Additionally, generate an appropriate response message for the user's last utterance.

---------------Refer the internal rule----------------------

<Retriever status rules>
1. The system selects the API with the highest score from among the APIs in the retriever status that have a score of 0.6 or higher, which is suitable for processing the user's query.
2. If no APIs score is higher than 0.6, system cannot confirm the API to call.

<Action list>
The actions that the system can take are as follows.
- Request: In a situation where the API to be executed has been determined, ask the user questions to gather information on the input parameters.
- Response: Reply to user's request based on the result of API call.
- Clarify: If user's query is vague, re-ask user to get intent specifically. If there is no API in the most recently retrieved results with a score above 0.5, "clarify" is required.
- Suggest: Making a suggestion for an unclear user's intent and ask-ing whether it satisfies the user. If there is an API in the most recently retrieved results with a score above 0.5 but none exceeding 0.6, a 'suggest' action is required.
- Response fail: Notify the user that the system cannot execute the request due to insufficient information.
- System bye: System says goodbye to user politely.
- Call: Call the API with collected information from user or else and don't reply to user yet.
(Unlike other actions, Call must be generated by concatenating it with the API call format. The required format is as follows.)
Call{'api_name': 'API1', 'required_parameters': {'param1': 'value1', 'param2': 'value2'}, 'optional_parameters': {'param3': 'value3'}
- Retriever call: Call the retriever to find proper API to satisfy user's requests. The system should call the retriever in the following two situations:
1. When the user specifies a requirement, and the system needs to search for an API to fulfill it.
2. When the user does not provide input parameters required for an API call, and the system needs to search for another API to obtain those parameters.

----------------------------------

Based on the rules above, you will get an input as below:
Input:
<Dialogue history>
User: <User's last utterance>

And here is the format for what you should return:
- Thought: <System's thought process. You generate here>
- Action: <System's action. You generate here. Choose one from <Action list>>
- Observation: <This will be provided externally.>
- Thought: <System's thought process. You generate here, refering the dialogue history and the observation above>
- Action: <System's action. You generate here, refering the dialogue history and the observation above. Choose one from <Action list>>
- Observation: <This will be provided externally.>
- Message: <System's message>

Output Format Rules:
1. Multiple reasoning steps (Thought-Action pairs) may be present in the output if required
2. The call format for the Call action must be constructed by referring to the given API documentation and dialogue history.

Now, generate appropriate reasoning trace and responding message to user.

Dialogue History:

Example Code

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from prompts import overall_evaluation_gen_prompt ## same prompt on "Instruction for TD-Llama-General"
import json,re,torch
from peft import PeftModel, PeftConfig

def normalize_action(text):
    lower_text = text.lower()
    normalized_text = re.sub(r'[^a-z]', '', lower_text)
    return normalized_text.replace("action","").replace("api","")

def extract_observation(label):
    lines = label.split("\n")
    return_label = {}
    for idx,sentence in enumerate(lines):
        if "- Action:" in sentence:
            current_obs = lines[idx+1].replace("- Observation: ","")
            current_action = normalize_action(sentence.replace("- Action:","").strip())
            if current_action in return_label:
                return_label[current_action].insert(0,current_obs)
            else:
                return_label[current_action] = [current_obs]
    return return_label

def overall_form(dialogue_idx,full_dialogue,tokenizer_inf,is_val=False):
    prompt = ""
    lines = full_dialogue['dialogue'].split("\n")
    for idx,sentence in enumerate(lines):
        if "- Dialogue State: " in sentence:
            continue

        if "- Action:call" in sentence:
            if "- Dialogue State: " in lines[idx-2]:
                prompt+=sentence+str(eval(lines[idx-2].replace("- Dialogue State: ",""))['api_status'])+"\n"
            else:
                find_ds_idx=1
                while "- Dialogue State: " not in lines[idx+find_ds_idx]:
                    find_ds_idx+=1
                prompt+=sentence+str(eval(lines[idx+find_ds_idx].replace("- Dialogue State: ",""))['api_status'])+"\n"
        else:
            prompt+=sentence+"\n"
    
    lines = prompt.split("\n")
    label = []
    prompt = ""
    for idx,sentence in enumerate(lines):
        prompt+=sentence+"\n"
        if "User: " in sentence:
            current_dial = {"dial":prompt,"label":""}
            tmp_idx = 1
            while "User: " not in lines[idx+tmp_idx]:
                current_dial["label"]+=lines[idx+tmp_idx]+"\n"
                tmp_idx+=1
                if idx+tmp_idx==len(lines):
                    break
            current_dial['dial']+="System:"
            current_dial['label'] = current_dial['label'].replace("System:\n","")
            if is_val: ## test
                chat_eval = [{"role": "user", "content":overall_evaluation_gen_prompt+prompt}] #
                template_chat = tokenizer_inf.apply_chat_template(chat_eval, tokenize=False, add_generation_prompt=True)
                label.append({"dialogue_idx":dialogue_idx,
                              "dial":template_chat,
                              "label":current_dial['label'],
                              "obs":extract_observation(current_dial['label'])})
            else: ## train
                history_chat = [{"role": "user", "content": overall_evaluation_gen_prompt+prompt}, {"role": "assistant", "content": ""}]
                chat = [{"role": "user", "content": overall_evaluation_gen_prompt+prompt}, {"role": "assistant", "content": current_dial['label']}]
                template_history = tokenizer_inf.apply_chat_template(history_chat, tokenize=False, add_generation_prompt=False)
                template_chat = tokenizer_inf.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
                
                label.append({
                    "dialogue_idx":dialogue_idx,
                    "history": template_history.replace("\n<|eot_id|>", "\n"),
                    "dial": template_chat,
                    "label": current_dial['label']})
    return label

device = "cuda:0"
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    ),
    device_map={"": device},
)
tokenizer = AutoTokenizer.from_pretrained("HOLILAB/td-llama-general")
model = PeftModel.from_pretrained(base_model, "HOLILAB/td-llama-general")

with open("test_dialogue_overall_obs.json",'r') as f: # https://drive.google.com/drive/folders/17IASNvCcJRkHlg2tzMPRlpPbo_MlHs94?hl=ko
    dial_list = json.load(f)

sample_dialogue_idx = 200
turn_idx = 2

input_prompt=overall_form(sample_dialogue_idx,dial_list[sample_dialogue_idx],tokenizer,is_val=True)[turn_idx]['dial']

inputs = tokenizer(input_prompt, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=1000,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False
)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. ๐Ÿ™‹ Ask for provider support