# Source code for neo4j_graphrag.generation.types

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from __future__ import annotations

from typing import Any, Union

from pydantic import BaseModel, ConfigDict, field_validator

from neo4j_graphrag.generation.prompts import RagTemplate
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RetrieverResult


class RagInitModel(BaseModel):
    """Pydantic model validating a retriever, an LLM object and a prompt template.

    ``arbitrary_types_allowed`` lets pydantic accept field types that are not
    pydantic-aware (it falls back to ``isinstance`` checks); presumably this is
    needed for ``Retriever`` and ``RagTemplate`` — confirm against their
    definitions.
    """

    retriever: Retriever
    # Typed as ``Any``: the only requirement enforced on the LLM is the
    # duck-typed check in ``check_llm`` (a callable ``invoke`` attribute).
    llm: Any
    prompt_template: RagTemplate

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_validator("llm")
    @classmethod
    def check_llm(cls, value: Any) -> Any:
        """Ensure ``llm`` exposes a callable ``invoke`` attribute.

        Args:
            value: The candidate LLM object.

        Returns:
            ``value`` unchanged when it passes the check.

        Raises:
            ValueError: If ``value`` has no ``invoke`` attribute or it is not
                callable. Pydantic reports this as a ``ValidationError``.
        """
        # Test callability only: a truthiness test would wrongly reject a
        # callable ``invoke`` object that defines a falsy ``__bool__``/``__len__``.
        if callable(getattr(value, "invoke", None)):
            return value
        # Keep the historical "llm must be callable" prefix so existing message
        # matches keep working, but state what is actually required.
        raise ValueError(
            "llm must be callable: expected an object with a callable "
            "'invoke' method"
        )


class RagSearchModel(BaseModel):
    """Pydantic model holding search parameters.

    Fields: the query text, an optional examples string, a retriever
    configuration mapping and a ``return_context`` flag.
    """

    # Required: the text of the query.
    query_text: str
    # Optional examples text; empty string when not supplied.
    examples: str = ""
    # Extra options for the retriever. Pydantic copies a mutable default for
    # each instance, so this empty dict is not shared between models.
    retriever_config: dict[str, Any] = dict()
    # Boolean flag, False by default; presumably asks the caller to include
    # the retrieval context in its result — confirm at the call site.
    return_context: bool = False


class RagResultModel(BaseModel):
    """Pydantic model holding an answer string and an optional retriever result.

    ``arbitrary_types_allowed`` lets pydantic accept a field type that is not
    pydantic-aware; presumably needed for ``RetrieverResult`` — confirm.
    """

    # The answer text.
    answer: str
    # The retriever output, or None (the default) when it is not provided.
    retriever_result: Union[RetrieverResult, None] = None

    model_config = ConfigDict(arbitrary_types_allowed=True)