File size: 4,045 Bytes
38812af
 
ff3806f
 
38812af
ff3806f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38812af
 
 
 
ff3806f
38812af
 
 
 
 
 
 
 
ff3806f
7fd0db3
ff3806f
38812af
 
 
 
 
 
 
 
 
ff3806f
 
 
 
 
 
 
 
 
38812af
8507438
38812af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3806f
 
38812af
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from smolagents import Tool
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer
import numpy as np
import datasets
from typing import List


class SentenceTransformerRetriever:
    """Retriever that uses SentenceTransformer embeddings for semantic search.

    Every document is embedded once when the retriever is constructed. Each
    query is embedded on demand and ranked against the stored document
    embeddings by cosine similarity.
    """

    def __init__(self, docs: List[Document], model_name: str = "all-MiniLM-L6-v2"):
        """Initialize with documents and a SentenceTransformer model.

        Args:
            docs: List of Document objects
            model_name: Name of the SentenceTransformer model to use
        """
        self.docs = docs
        self.model = SentenceTransformer(model_name)

        # Embed the whole corpus up front so queries only pay for one encode.
        self.doc_texts = [doc.page_content for doc in self.docs]
        # Request numpy output so the similarity math below stays in numpy.
        self.doc_embeddings = self.model.encode(self.doc_texts, convert_to_numpy=True)

    def get_relevant_documents(self, query: str, k: int = 3) -> List[Document]:
        """Return the documents most similar to the query.

        Args:
            query: Query string
            k: Maximum number of documents to return

        Returns:
            Up to ``k`` Document objects ordered from most to least similar.
            An empty list is returned when ``k`` is not positive or when the
            retriever holds no documents. Documents with equal similarity
            keep their original relative order.
        """
        # A negative k would otherwise slice "all but the last |k|" results,
        # and an empty corpus has no embedding matrix to multiply against.
        if k <= 0 or not self.docs:
            return []

        query_embedding = np.asarray(self.model.encode(query, convert_to_numpy=True))
        doc_matrix = np.asarray(self.doc_embeddings)

        # One matrix-vector product replaces the per-document Python loop.
        dot_products = doc_matrix @ query_embedding
        denominators = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_embedding)

        # Cosine similarity is undefined for zero-length vectors; score those
        # as 0.0 instead of dividing by zero and producing NaN.
        similarities = np.zeros(dot_products.shape, dtype=np.float64)
        np.divide(dot_products, denominators, out=similarities, where=denominators > 0)

        # Stable sort on the negated scores: descending similarity with
        # deterministic tie-breaking by document position.
        top_k_indices = np.argsort(-similarities, kind="stable")[:k]

        return [self.docs[i] for i in top_k_indices]


class GuestInfoRetrieverTool(Tool):
    """Tool that answers guest queries through a SentenceTransformerRetriever."""

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation using semantic search."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs, model_name: str = "all-MiniLM-L6-v2"):
        """Build the underlying retriever over the given documents.

        Args:
            docs: Documents handed to SentenceTransformerRetriever
            model_name: Name of the SentenceTransformer model to use
        """
        self.is_initialized = False
        self.retriever = SentenceTransformerRetriever(docs, model_name)

    def forward(self, query: str):
        """Look up guest information for the query.

        Args:
            query: Query string passed to the retriever

        Returns:
            The page contents of at most three retrieved documents separated
            by blank lines, or a fixed fallback message when nothing is
            retrieved.
        """
        matches = self.retriever.get_relevant_documents(query)
        if not matches:
            return "No matching guest information found."

        page_texts = [match.page_content for match in matches[:3]]
        return "\n\n".join(page_texts)


def load_guest_dataset(model_name: str = "all-MiniLM-L6-v2"):
    """Build a guest retriever tool from the invitee dataset.

    Loads the "train" split of "agents-course/unit3-invitees", turns each
    entry into a Document whose text lists the guest's name, relation,
    description and email, and wraps the documents in a retriever tool.

    Args:
        model_name: Name of the SentenceTransformer model to use

    Returns:
        GuestInfoRetrieverTool: A tool for retrieving guest information
    """
    invitees = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    guest_docs = []
    for entry in invitees:
        # One labelled line per field; the name is also kept as metadata.
        profile_lines = (
            f"Name: {entry['name']}",
            f"Relation: {entry['relation']}",
            f"Description: {entry['description']}",
            f"Email: {entry['email']}",
        )
        guest_docs.append(
            Document(
                page_content="\n".join(profile_lines),
                metadata={"name": entry["name"]},
            )
        )

    return GuestInfoRetrieverTool(guest_docs, model_name=model_name)