# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from typing import Any, Optional
from neo4j_graphrag.exceptions import (
PromptMissingInputError,
PromptMissingPlaceholderError,
)
[docs]
class PromptTemplate:
    """This class is used to generate a parameterized prompt. It is defined
    from a string (the template) using the Python format syntax (parameters
    between curly braces `{}`) and a list of required inputs.
    Before sending the instructions to an LLM, call the `format` method that will
    replace parameters with the provided values. If any of the expected inputs is
    missing, a `PromptMissingInputError` is raised.

    Args:
        template: The template string. When omitted (or empty), the class-level
            ``DEFAULT_TEMPLATE`` is used.
        expected_inputs: Names that must be supplied when formatting. When
            omitted (or empty), the class-level ``EXPECTED_INPUTS`` is used.

    Raises:
        PromptMissingPlaceholderError: If one of the expected inputs has no
            ``{name}`` placeholder in the template.
    """

    DEFAULT_TEMPLATE: str = ""
    EXPECTED_INPUTS: list[str] = []

    def __init__(
        self,
        template: Optional[str] = None,
        expected_inputs: Optional[list[str]] = None,
    ) -> None:
        self.template = template or self.DEFAULT_TEMPLATE
        self.expected_inputs = expected_inputs or self.EXPECTED_INPUTS

        # Fail early: an expected input that the template never references
        # could not influence the rendered prompt.
        for e in self.expected_inputs:
            if f"{{{e}}}" not in self.template:
                raise PromptMissingPlaceholderError(
                    f"`template` is missing placeholder {e}"
                )

    def _format(self, **kwargs: Any) -> str:
        """Check that every expected input is present, then render the template.

        Raises:
            PromptMissingInputError: If an expected input is not in ``kwargs``.
        """
        # Validate against the per-instance list (which may have been
        # customised in ``__init__``), not the class-level default.
        for e in self.expected_inputs:
            if e not in kwargs:
                raise PromptMissingInputError(f"Missing input '{e}'")
        return self.template.format(**kwargs)

    def format(self, *args: Any, **kwargs: Any) -> str:
        """Replace the template parameters with the provided values.

        Values can be given as keyword arguments, or as positional arguments
        in the same order as ``expected_inputs``. When a name is supplied both
        ways, the positional value is used.

        Returns:
            The rendered prompt.

        Raises:
            PromptMissingInputError: If an expected input is missing.
        """
        data = dict(kwargs)
        # zip() stops at the shorter sequence, so surplus positional values
        # are ignored and unfilled names must come from kwargs.
        data.update(zip(self.expected_inputs, args))
        return self._format(**data)
[docs]
class RagTemplate(PromptTemplate):
    """Prompt template that asks for an answer to a question given a context.

    The default template has three placeholders, all of which are expected
    inputs: ``context``, ``examples`` and ``query_text``.
    """

    DEFAULT_TEMPLATE = """Answer the user question using the following context
Context:
{context}
Examples:
{examples}
Question:
{query_text}
Answer:
"""
    EXPECTED_INPUTS = ["context", "query_text", "examples"]

    def format(self, query_text: str, context: str, examples: str) -> str:
        """Render the prompt for one question.

        Args:
            query_text: Value for the ``{query_text}`` placeholder.
            context: Value for the ``{context}`` placeholder.
            examples: Value for the ``{examples}`` placeholder.

        Returns:
            Whatever the parent class ``format`` returns for these values.
        """
        values = {
            "query_text": query_text,
            "context": context,
            "examples": examples,
        }
        return super().format(**values)
[docs]
class Text2CypherTemplate(PromptTemplate):
    """Prompt template asking for a Cypher statement that answers a user input.

    The default template has ``schema``, ``examples`` and ``query_text``
    placeholders; only ``query_text`` is listed as an expected input.
    """

    DEFAULT_TEMPLATE = """
Task: Generate a Cypher statement for querying a Neo4j graph database from a user input.
Schema:
{schema}
Examples (optional):
{examples}
Input:
{query_text}
Do not use any properties or relationships not included in the schema.
Do not include triple backticks ``` or any additional text except the generated Cypher statement in your response.
Cypher query:
"""
    EXPECTED_INPUTS = ["query_text"]

    def format(
        self,
        schema: Optional[str] = None,
        examples: Optional[str] = None,
        query_text: str = "",
        query: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Render the prompt, accepting the deprecated ``query`` argument.

        Args:
            schema: Value for the ``{schema}`` placeholder; forwarded unchanged.
            examples: Value for the ``{examples}`` placeholder; forwarded unchanged.
            query_text: The user input.
            query: Deprecated alias for ``query_text``. It is only used when
                ``query_text`` is empty and ``query`` is a string.
            **kwargs: Extra values forwarded to the parent ``format``.

        Returns:
            Whatever the parent class ``format`` returns for these values.
        """
        both_given = query is not None and bool(query_text)
        if both_given:
            # Both spellings supplied: keep ``query_text`` and warn.
            warnings.warn(
                "Both 'query' and 'query_text' are provided, 'query_text' will be used.",
                DeprecationWarning,
                stacklevel=2,
            )
        elif isinstance(query, str):
            # Only the legacy spelling supplied: warn and use it as the input.
            warnings.warn(
                "'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            query_text = query

        return super().format(
            query_text=query_text,
            schema=schema,
            examples=examples,
            **kwargs,
        )