Spaces:
Sleeping
Sleeping
from msal import ConfidentialClientApplication | |
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI | |
from langchain_groq import ChatGroq | |
from langchain.vectorstores.azuresearch import AzureSearch | |
import os | |
class LLM:
    """Thin wrapper around a LangChain chat model that shares one callback list.

    Every method formats a chat prompt with
    ``prompt.format_messages(**arguments)`` and sends the resulting messages
    to the wrapped model. Before each call, the wrapper's ``callbacks`` list
    is attached to the model.
    """

    def __init__(self, llm):
        """Wrap *llm*, a LangChain chat model (e.g. ``AzureChatOpenAI`` or ``ChatGroq``)."""
        self.llm = llm
        # Handlers attached to the wrapped model before every call. Callers may
        # append to or replace this list after construction.
        self.callbacks = []

    def stream(self, prompt, prompt_arguments):
        """Stream a reply, yielding the accumulated text after each chunk.

        Args:
            prompt: Chat prompt template exposing ``format_messages``.
            prompt_arguments: Keyword arguments used to fill *prompt*.

        Yields:
            str: The full reply text received so far (not just the newest
            chunk). The last value yielded is the complete reply.

        Note:
            Leaves ``self.llm.streaming`` set to True after the call.
        """
        # Attach callbacks here too, so streaming behaves like the other methods.
        self.llm.callbacks = self.callbacks
        self.llm.streaming = True
        output = ""
        for chunk in self.llm.stream(prompt.format_messages(**prompt_arguments)):
            output += chunk.content
            yield output

    def get_prediction(self, prompt, prompt_arguments):
        """Return the model's reply text for *prompt* filled with *prompt_arguments*."""
        self.llm.callbacks = self.callbacks
        # ``invoke`` supersedes the deprecated ``predict_messages``; both return a message.
        return self.llm.invoke(prompt.format_messages(**prompt_arguments)).content

    async def get_aprediction(self, prompt, prompt_arguments):
        """Asynchronously query the model for *prompt* filled with *prompt_arguments*.

        Returns:
            The reply *message object* (not its ``.content``), unlike
            ``get_prediction``. Kept as-is for compatibility with existing callers.
        """
        self.llm.callbacks = self.callbacks
        return await self.llm.ainvoke(prompt.format_messages(**prompt_arguments))

    async def get_apredictions(self, prompts, prompts_arguments):
        """Run several prompts one after another and return their reply texts.

        Args:
            prompts: Mapping of name -> prompt template. Only the values are
                used, in the mapping's iteration order.
            prompts_arguments: Iterable of keyword-argument dicts, paired
                positionally with the prompts.

        Returns:
            list[str]: Reply texts in prompt order. Pairing is done with
            ``zip``, so it stops at the shorter of the two inputs.
        """
        self.llm.callbacks = self.callbacks
        predictions = []
        # Requests are awaited one at a time, not gathered concurrently.
        for prompt, prompt_args in zip(prompts.values(), prompts_arguments):
            message = await self.llm.ainvoke(prompt.format_messages(**prompt_args))
            predictions.append(message.content)
        return predictions
def get_llm_api(groq_model_name):
    """Build the chat-model wrapper, preferring Azure OpenAI over Groq.

    Azure OpenAI is chosen when the ``EKI_OPENAI_LLM_DEPLOYMENT_NAME``
    environment variable is set to a non-empty value. Otherwise the Groq model
    named by *groq_model_name* is used.

    Args:
        groq_model_name: Groq model identifier. Only used on the Groq path.

    Returns:
        An ``LLM`` wrapping either an ``AzureChatOpenAI`` client (streaming
        enabled) or a ``ChatGroq`` client. Both use temperature 0 and
        ``max_tokens=2048``.
    """
    azure_deployment = os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME")

    if not azure_deployment:
        print("Using GROQ API")
        groq_client = ChatGroq(model=groq_model_name, temperature=0, max_tokens=2048)
        return LLM(groq_client)

    print("Using Azure OpenAI API")
    azure_client = AzureChatOpenAI(
        deployment_name=azure_deployment,
        openai_api_key=os.getenv("EKI_OPENAI_API_KEY"),
        azure_endpoint=os.getenv("EKI_OPENAI_LLM_API_ENDPOINT"),
        openai_api_version=os.getenv("EKI_OPENAI_API_VERSION"),
        streaming=True,
        temperature=0,
        max_tokens=2048,
        # NOTE(review): presumably stops on a ChatML end-of-turn marker; confirm still needed.
        stop=["<|im_end|>"],
    )
    return LLM(azure_client)
def get_vectorstore_api(index_name):
    """Open an Azure AI Search vector store that uses Azure OpenAI embeddings.

    Args:
        index_name: Name of the Azure Search index to open.

    Returns:
        An ``AzureSearch`` instance for *index_name*. Its embedding function is
        ``embed_query`` of an ``AzureOpenAIEmbeddings`` client configured with
        model ``text-embedding-ada-002``.

    Raises:
        KeyError: If ``EKI_OPENAI_EMB_API_ENDPOINT`` is not set. All other
            settings are read with ``os.getenv`` and become None when unset.
    """
    env = os.getenv
    embedder = AzureOpenAIEmbeddings(
        model="text-embedding-ada-002",
        azure_deployment=env("EKI_OPENAI_EMB_DEPLOYMENT_NAME"),
        api_key=env("EKI_OPENAI_API_KEY"),
        # Indexed lookup: a missing endpoint fails here with KeyError.
        azure_endpoint=os.environ["EKI_OPENAI_EMB_API_ENDPOINT"],
        openai_api_version=env("EKI_OPENAI_API_VERSION"),
    )
    return AzureSearch(
        azure_search_endpoint=env("EKI_VECTOR_STORE_ADDRESS"),
        azure_search_key=env("EKI_VECTOR_STORE_PASSWORD"),
        index_name=index_name,
        # Pass the bound ``embed_query`` method, not the embeddings object.
        embedding_function=embedder.embed_query,
    )