import os

from msal import ConfidentialClientApplication
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI
from langchain_groq import ChatGroq
from langchain.vectorstores.azuresearch import AzureSearch


class LLM:
    """Thin wrapper around a LangChain chat model for streaming and (a)sync prediction."""

    def __init__(self, llm):
        self.llm = llm
        self.callbacks = []

    def stream(self, prompt, prompt_arguments):
        """Yield the answer as it grows, one chunk at a time."""
        self.llm.streaming = True
        streamed_content = self.llm.stream(prompt.format_messages(**prompt_arguments))
        output = ""
        for op in streamed_content:
            output += op.content
            yield output

    def get_prediction(self, prompt, prompt_arguments):
        """Return the model's answer text for a single formatted prompt."""
        self.llm.callbacks = self.callbacks
        return self.llm.predict_messages(
            prompt.format_messages(**prompt_arguments)
        ).content

    async def get_aprediction(self, prompt, prompt_arguments):
        """Async variant of get_prediction.

        Note: returns the full message object, not just its text content.
        """
        self.llm.callbacks = self.callbacks
        prediction = await self.llm.apredict_messages(
            prompt.format_messages(**prompt_arguments)
        )
        return prediction

    async def get_apredictions(self, prompts, prompts_arguments):
        """Run several prompts sequentially and collect their answer texts."""
        self.llm.callbacks = self.callbacks
        predictions = []
        for prompt_, prompt_args_ in zip(prompts.keys(), prompts_arguments):
            prediction = await self.llm.apredict_messages(
                prompts[prompt_].format_messages(**prompt_args_)
            )
            predictions.append(prediction.content)
        return predictions
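

# Example usage of the LLM wrapper (sketch): assumes a LangChain
# ChatPromptTemplate and a model obtained from get_llm_api below; the Groq
# model name and the question are placeholders, not values from this file.
#
#     from langchain_core.prompts import ChatPromptTemplate
#
#     llm = get_llm_api("llama-3.1-70b-versatile")
#     prompt = ChatPromptTemplate.from_messages(
#         [("system", "You are a helpful assistant."), ("human", "{question}")]
#     )
#     for partial_answer in llm.stream(prompt, {"question": "What is ESG?"}):
#         ...  # render the growing answer, e.g. in a Gradio interface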


def get_llm_api(groq_model_name):
    """Return an LLM wrapper backed by Azure OpenAI if configured, otherwise by Groq."""
    if os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"):
        print("Using Azure OpenAI API")
        return LLM(
            AzureChatOpenAI(
                deployment_name=os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"),
                openai_api_key=os.getenv("EKI_OPENAI_API_KEY"),
                azure_endpoint=os.getenv("EKI_OPENAI_LLM_API_ENDPOINT"),
                openai_api_version=os.getenv("EKI_OPENAI_API_VERSION"),
                streaming=True,
                temperature=0,
                max_tokens=2048,
                stop=["<|im_end|>"],
            )
        )
    else:
        print("Using GROQ API")
        return LLM(
            ChatGroq(
                model=groq_model_name,
                temperature=0,
                max_tokens=2048,
            )
        )


def get_vectorstore_api(index_name):
    """Return an Azure AI Search vector store using Azure OpenAI embeddings."""
    aoai_embeddings = AzureOpenAIEmbeddings(
        model="text-embedding-ada-002",
        azure_deployment=os.getenv("EKI_OPENAI_EMB_DEPLOYMENT_NAME"),
        api_key=os.getenv("EKI_OPENAI_API_KEY"),
        azure_endpoint=os.environ["EKI_OPENAI_EMB_API_ENDPOINT"],
        openai_api_version=os.getenv("EKI_OPENAI_API_VERSION"),
    )
    vector_store: AzureSearch = AzureSearch(
        azure_search_endpoint=os.getenv("EKI_VECTOR_STORE_ADDRESS"),
        azure_search_key=os.getenv("EKI_VECTOR_STORE_PASSWORD"),
        index_name=index_name,
        embedding_function=aoai_embeddings.embed_query,
    )
    return vector_store
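

# Minimal smoke test (sketch): only runs when the module is executed directly
# and the EKI_* credentials above (or Groq credentials for the fallback) are
# set. The Groq model name, index name, and queries below are placeholders,
# not values defined in this file.
if __name__ == "__main__":
    from langchain_core.prompts import ChatPromptTemplate

    llm = get_llm_api("llama-3.1-70b-versatile")  # placeholder Groq model name
    vector_store = get_vectorstore_api("demo-index")  # placeholder index name

    # Retrieve a few documents from the index, then ask the model a question.
    docs = vector_store.similarity_search("carbon footprint reporting", k=3)
    print(f"Retrieved {len(docs)} documents")

    prompt = ChatPromptTemplate.from_messages(
        [("system", "Answer briefly."), ("human", "{question}")]
    )
    print(llm.get_prediction(prompt, {"question": "What does ESG stand for?"}))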