# Source code for neo4j_graphrag.message_history

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
import threading
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import neo4j
from pydantic import PositiveInt

from neo4j_graphrag.types import (
    LLMMessage,
    Neo4jDriverModel,
    Neo4jMessageHistoryModel,
)

# Cypher templates used by Neo4jMessageHistory.
#
# Each template is completed with ``str.format``: ``{node_label}`` (and
# ``{window}`` in GET_MESSAGES_QUERY) are substituted, while doubled braces
# ``{{ }}`` are format escapes that become literal Cypher map braces.
# ``$session_id``, ``$role`` and ``$content`` are Cypher query parameters.

# Ensure a session node with the given id exists (MERGE matches it or creates it).
CREATE_SESSION_NODE_QUERY = "MERGE (s:`{node_label}` {{id:$session_id}})"

# Delete the session node together with its messages. The optional path starts
# at the session node, follows LAST_MESSAGE and then walks backwards over NEXT
# relationships; when no such path exists only the session node is deleted.
DELETE_SESSION_AND_MESSAGES_QUERY = (
    "MATCH (s:`{node_label}`) "
    "WHERE s.id = $session_id "
    "OPTIONAL MATCH p=(s)-[:LAST_MESSAGE]->(:Message)<-[:NEXT*0..]-(:Message) "
    "WITH CASE WHEN p IS NULL THEN [s] ELSE nodes(p) END AS nodes "
    "UNWIND nodes AS node "
    "DETACH DELETE node;"
)

# Delete the messages of a session but keep the session node: the matched path
# starts at the last message, so the session node is never part of nodes(p).
DELETE_MESSAGES_QUERY = (
    "MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message:Message) "
    "WHERE s.id = $session_id "
    "MATCH p=(last_message)<-[:NEXT*0..]-(:Message) "
    "UNWIND nodes(p) as node "
    "DETACH DELETE node;"
)

# Fetch messages. Starting at the last message, walk backwards over at most
# ``{window}`` NEXT relationships (no upper bound when ``window`` is formatted
# as an empty string), keep the longest path found and return its nodes in
# reversed order, i.e. ending with the last message.
GET_MESSAGES_QUERY = (
    "MATCH (s:`{node_label}`)-[:LAST_MESSAGE]->(last_message) "
    "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
    "{window}]-() WITH p, length(p) AS length "
    "ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
    "RETURN {{data:{{content: node.content}}, role:node.role}} AS result"
)

# Append a message. A new Message node becomes the target of LAST_MESSAGE; if a
# previous last message existed it is linked to the new one with NEXT and the
# old LAST_MESSAGE relationship is removed. Nothing is created when no session
# node matches ``$session_id``, because the leading MATCH yields no rows.
ADD_MESSAGE_QUERY = (
    "MATCH (s:`{node_label}`) WHERE s.id = $session_id "
    "OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
    "CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
    "SET new += {{role:$role, content:$content}} "
    "WITH new, lm, last_message WHERE last_message IS NOT NULL "
    "CREATE (last_message)-[:NEXT]->(new) "
    "DELETE lm"
)


class MessageHistory(ABC):
    """Interface implemented by every message-history backend.

    Subclasses supply the ``messages`` property plus ``add_message`` and
    ``clear``. ``add_messages`` has a default implementation built on
    ``add_message`` and may be overridden.
    """

    @property
    @abstractmethod
    def messages(self) -> List[LLMMessage]:
        """Return the messages held by this history."""

    @abstractmethod
    def add_message(self, message: LLMMessage) -> None:
        """Store a single message in the history."""

    def add_messages(self, messages: List[LLMMessage]) -> None:
        """Store several messages by calling ``add_message`` on each, in order."""
        for entry in messages:
            self.add_message(entry)

    @abstractmethod
    def clear(self) -> None:
        """Remove the stored messages."""


class InMemoryMessageHistory(MessageHistory):
    """Message history stored in memory.

    Access to the stored list is serialized with a lock, and the ``messages``
    property hands out a copy rather than the internal list.

    Example:

    .. code-block:: python

        from neo4j_graphrag.message_history import InMemoryMessageHistory
        from neo4j_graphrag.types import LLMMessage

        history = InMemoryMessageHistory()

        message = LLMMessage(role="user", content="Hello!")
        history.add_message(message)

    Args:
        messages (Optional[List[LLMMessage]]): List of messages to initialize
            the history with. The list is copied, so later changes to the
            caller's list do not affect the history and vice versa.
            Defaults to None.
    """

    def __init__(self, messages: Optional[List[LLMMessage]] = None) -> None:
        # Guards every read and write of ``_messages``.
        self._lock = threading.Lock()
        # Copy instead of aliasing the argument: a shared list could be mutated
        # by the caller without holding the lock, and ``messages or []`` only
        # aliased non-empty lists, which made the behaviour input-dependent.
        self._messages = list(messages) if messages else []

    @property
    def messages(self) -> List[LLMMessage]:
        """Return a snapshot (shallow copy) of the stored messages."""
        with self._lock:
            return self._messages.copy()

    def add_message(self, message: LLMMessage) -> None:
        """Append a single message to the history."""
        with self._lock:
            self._messages.append(message)

    def add_messages(self, messages: List[LLMMessage]) -> None:
        """Append several messages while holding the lock once for the batch."""
        with self._lock:
            self._messages.extend(messages)

    def clear(self) -> None:
        """Drop all stored messages."""
        with self._lock:
            # Rebind rather than clear in place so previously returned
            # snapshots and the old list object are left untouched.
            self._messages = []
class Neo4jMessageHistory(MessageHistory):
    """Message history stored in a Neo4j database.

    Example:

    .. code-block:: python

        import neo4j
        from neo4j_graphrag.message_history import Neo4jMessageHistory
        from neo4j_graphrag.types import LLMMessage

        driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)

        history = Neo4jMessageHistory(
            session_id="123", driver=driver, window=10
        )

        message = LLMMessage(role="user", content="Hello!")
        history.add_message(message)

    Args:
        session_id (Union[str, int]): Unique identifier for the chat session.
        driver (neo4j.Driver): Neo4j driver instance.
        window (Optional[PositiveInt], optional): Number of previous messages
            to return when retrieving messages. Defaults to None, in which
            case the retrieval query has no upper bound.
        node_label (str, optional): Label used for session nodes in Neo4j.
            Defaults to "Session".

    Raises:
        ValueError: If ``node_label`` is not a non-empty string.
    """

    def __init__(
        self,
        session_id: Union[str, int],
        driver: neo4j.Driver,
        window: Optional[PositiveInt] = None,
        node_label: str = "Session",
    ) -> None:
        if not isinstance(node_label, str) or not node_label:
            raise ValueError("node_label must be a non-empty string.")
        validated_data = Neo4jMessageHistoryModel(
            session_id=session_id,
            driver_model=Neo4jDriverModel(driver=driver),
            window=window,
        )
        self._driver = validated_data.driver_model.driver
        self._session_id = validated_data.session_id
        # The window is spliced into a ``[:NEXT*0..{window}]`` pattern: an empty
        # string leaves it unbounded, and N messages need N - 1 hops.
        self._window = (
            "" if validated_data.window is None else validated_data.window - 1
        )
        # Labels cannot be sent as query parameters, so the label is formatted
        # into the query text inside backticks and must be escaped first.
        self._node_label = self._escape_label(node_label)
        # Create session node
        self._driver.execute_query(
            query_=CREATE_SESSION_NODE_QUERY.format(node_label=self._node_label),
            parameters_={"session_id": self._session_id},
        )

    @staticmethod
    def _escape_label(label: str) -> str:
        """Escape a label for use inside a backtick-quoted Cypher identifier."""
        # Normalise the unicode escape for a backtick first, then double every
        # backtick so the quoted identifier cannot be terminated early.
        return label.replace("\\u0060", "`").replace("`", "``")

    @property
    def messages(self) -> List[LLMMessage]:
        """Return the stored messages, limited by ``window`` when one was given."""
        result = self._driver.execute_query(
            query_=GET_MESSAGES_QUERY.format(
                node_label=self._node_label, window=self._window
            ),
            parameters_={"session_id": self._session_id},
        )
        return [
            LLMMessage(
                content=record["result"]["data"]["content"],
                role=record["result"]["role"],
            )
            for record in result.records
        ]

    @messages.setter
    def messages(self, messages: List[LLMMessage]) -> None:
        raise NotImplementedError(
            "Direct assignment to 'messages' is not allowed."
            " Use the 'add_messages' instead."
        )

    def add_message(self, message: LLMMessage) -> None:
        """Add a message to the message history.

        Args:
            message (LLMMessage): The message to add.
        """
        self._driver.execute_query(
            query_=ADD_MESSAGE_QUERY.format(node_label=self._node_label),
            parameters_={
                "role": message["role"],
                "content": message["content"],
                "session_id": self._session_id,
            },
        )

    def clear(self, delete_session_node: bool = False) -> None:
        """Clear the message history.

        Args:
            delete_session_node (bool): Whether to delete the session node.
                Defaults to False. The session node is only created in
                ``__init__``, and ``ADD_MESSAGE_QUERY`` starts by matching it,
                so after deleting it later ``add_message`` calls on this
                instance have no session node to attach messages to.
        """
        if delete_session_node:
            query_template = DELETE_SESSION_AND_MESSAGES_QUERY
        else:
            query_template = DELETE_MESSAGES_QUERY
        self._driver.execute_query(
            query_=query_template.format(node_label=self._node_label),
            parameters_={"session_id": self._session_id},
        )