Source code for neo4j_genai.generation.graphrag

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from __future__ import annotations

import logging
import warnings
from typing import Any, Optional

from pydantic import ValidationError

from neo4j_genai.exceptions import (
    RagInitializationError,
    SearchValidationError,
)
from neo4j_genai.generation.prompts import RagTemplate
from neo4j_genai.generation.types import RagInitModel, RagResultModel, RagSearchModel
from neo4j_genai.llm import LLMInterface
from neo4j_genai.retrievers.base import Retriever
from neo4j_genai.types import RetrieverResult

logger = logging.getLogger(__name__)


class GraphRAG:
    """Performs a GraphRAG search using a specific retriever and LLM.

    Example:

    .. code-block:: python

        import neo4j
        from neo4j_genai.retrievers import VectorRetriever
        from neo4j_genai.llm.openai_llm import OpenAILLM
        from neo4j_genai.generation import GraphRAG

        driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)

        retriever = VectorRetriever(driver, "vector-index-name", custom_embedder)
        llm = OpenAILLM()
        graph_rag = GraphRAG(retriever, llm)
        graph_rag.search(query_text="Find me a book about Fremen")

    Args:
        retriever (Retriever): The retriever used to find relevant context
            to pass to the LLM.
        llm (LLMInterface): The LLM used to generate the answer.
        prompt_template (RagTemplate): The prompt template that will be
            formatted with the context and the user question and then
            passed to the LLM.

    Raises:
        RagInitializationError: If validation of the input arguments fails.
    """

    def __init__(
        self,
        retriever: Retriever,
        llm: LLMInterface,
        prompt_template: RagTemplate = RagTemplate(),
    ):
        try:
            validated_data = RagInitModel(
                retriever=retriever,
                llm=llm,
                prompt_template=prompt_template,
            )
        except ValidationError as e:
            raise RagInitializationError(e.errors())
        self.retriever = validated_data.retriever
        self.llm = validated_data.llm
        self.prompt_template = validated_data.prompt_template
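
    # A minimal sketch (not from the library docs) of supplying a custom prompt:
    # it assumes RagTemplate accepts a `template` string using the same
    # `{context}`, `{examples}` and `{query_text}` placeholders as the default
    # template.
    #
    #   template = RagTemplate(
    #       template=(
    #           "Context:\n{context}\n\nExamples:\n{examples}\n\n"
    #           "Question: {query_text}\n\nAnswer:"
    #       )
    #   )
    #   rag = GraphRAG(retriever, llm, prompt_template=template)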

    def search(
        self,
        query_text: str = "",
        examples: str = "",
        retriever_config: Optional[dict[str, Any]] = None,
        return_context: bool = False,
        query: Optional[str] = None,
    ) -> RagResultModel:
        """This method performs a full RAG search:

        1. Retrieval: context retrieval
        2. Augmentation: prompt formatting
        3. Generation: answer generation with the LLM

        Args:
            query_text (str): The user question.
            examples (str): Examples added to the LLM prompt.
            retriever_config (Optional[dict]): Parameters passed to the
                retriever's search method, e.g. top_k.
            return_context (bool): Whether to append the retriever result
                to the final result (default: False).
            query (Optional[str]): The user question. Deprecated in favor
                of query_text.

        Returns:
            RagResultModel: The LLM-generated answer.
        """
        try:
            # Backward compatibility: accept the deprecated 'query' argument
            # and map it onto 'query_text'.
            if query is not None:
                if query_text:
                    warnings.warn(
                        "Both 'query' and 'query_text' are provided, 'query_text' will be used.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                elif isinstance(query, str):
                    warnings.warn(
                        "'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    query_text = query
            validated_data = RagSearchModel(
                query_text=query_text,
                examples=examples,
                retriever_config=retriever_config or {},
                return_context=return_context,
            )
        except ValidationError as e:
            raise SearchValidationError(e.errors())
        query_text = validated_data.query_text
        # Retrieval: fetch relevant context with the configured retriever
        retriever_result: RetrieverResult = self.retriever.search(
            query_text=query_text, **validated_data.retriever_config
        )
        # Augmentation: build the prompt from context, examples and question
        context = "\n".join(item.content for item in retriever_result.items)
        prompt = self.prompt_template.format(
            query_text=query_text, context=context, examples=validated_data.examples
        )
        logger.debug(f"RAG: retriever_result={retriever_result}")
        logger.debug(f"RAG: prompt={prompt}")
        # Generation: ask the LLM to answer from the augmented prompt
        answer = self.llm.invoke(prompt)
        result: dict[str, Any] = {"answer": answer.content}
        if return_context:
            result["retriever_result"] = retriever_result
        return RagResultModel(**result)
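
# A minimal end-to-end usage sketch (assumes URI, AUTH and custom_embedder are
# defined, and that a vector index named "vector-index-name" exists):
#
#   import neo4j
#   from neo4j_genai.retrievers import VectorRetriever
#   from neo4j_genai.llm.openai_llm import OpenAILLM
#
#   driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
#   retriever = VectorRetriever(driver, "vector-index-name", custom_embedder)
#   rag = GraphRAG(retriever=retriever, llm=OpenAILLM())
#   result = rag.search(
#       query_text="Find me a book about Fremen",
#       retriever_config={"top_k": 5},  # forwarded to retriever.search()
#       return_context=True,  # attach the RetrieverResult to the returned model
#   )
#   print(result.answer)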