Use with LangChain

How to integrate neo4j-agent-memory with LangChain to build memory-enabled chains and agents with a persistent context graph.

Overview

LangChain is a popular framework for building LLM applications. Integrating with neo4j-agent-memory provides persistent memory that survives across sessions, enabling agents to build and query a context graph over time.

LangChain + Context Graph Architecture
┌─────────────────────────────────────────────────────┐
│                   LangChain Agent                   │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐ │
│  │    Tools    │  │   Memory    │  │    Chain    │ │
│  │             │  │  Interface  │  │             │ │
│  └──────┬──────┘  └──────┬──────┘  └─────────────┘ │
│         │                │                          │
│         ▼                ▼                          │
│  ┌─────────────────────────────────────────────┐   │
│  │        Neo4jAgentMemory Adapter             │   │
│  │  • Implements BaseChatMessageHistory        │   │
│  │  • Provides entity/preference tools         │   │
│  │  • Callbacks for reasoning traces           │   │
│  └──────────────────────┬──────────────────────┘   │
└─────────────────────────┼───────────────────────────┘
                          │
                          ▼
┌─────────────────────────────────────────────────────┐
│               Neo4j Context Graph                   │
│  Messages ──── Entities ──── Preferences            │
│      │            │              │                  │
│      └────────────┴──────────────┘                  │
│            Reasoning Traces                         │
└─────────────────────────────────────────────────────┘

Prerequisites

  • Python 3.10+

  • neo4j-agent-memory, langchain, and langchain-openai installed

  • Neo4j database running

  • OpenAI API key (or other LLM provider)

pip install neo4j-agent-memory langchain langchain-openai

Message History Integration

Create Custom Message History

Implement LangChain’s BaseChatMessageHistory interface:

from typing import List
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from neo4j_agent_memory import MemoryClient


class Neo4jChatMessageHistory(BaseChatMessageHistory):
    """LangChain message history backed by neo4j-agent-memory.

    Synchronous adapter over the async ``MemoryClient`` short-term API: each
    operation is run to completion on the calling thread's event loop. It
    therefore cannot be used from code that is already executing inside a
    running loop; use ``AsyncNeo4jChatMessageHistory`` and its async methods
    there.

    Args:
        memory_client: Client whose ``short_term`` store holds the messages.
        session_id: Identifier of the conversation to read and write.
        user_id: Optional user identifier; when set, it is attached to every
            stored message as ``metadata={"user_id": ...}``.
    """

    def __init__(
        self,
        memory_client: MemoryClient,
        session_id: str,
        user_id: str | None = None,
    ):
        self.memory_client = memory_client
        self.session_id = session_id
        self.user_id = user_id

    @staticmethod
    def _run_sync(coro):
        """Run *coro* to completion on this thread's event loop.

        Returns:
            Whatever the coroutine returns.

        Raises:
            RuntimeError: If the thread's event loop is already running.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (worker threads, and Python
            # 3.14+ until one is set): create one and keep it current so
            # later calls reuse the same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if loop.is_running():
            # Close the coroutine so it does not trigger a
            # "coroutine was never awaited" warning.
            coro.close()
            raise RuntimeError(
                "Neo4jChatMessageHistory was called from a running event "
                "loop; use the async message history instead."
            )
        return loop.run_until_complete(coro)

    @property
    def messages(self) -> List[BaseMessage]:
        """Retrieve the messages for this session (at most 100).

        Only ``"user"`` and ``"assistant"`` roles are converted; messages
        with any other role are skipped.
        """
        raw_messages = self._run_sync(
            self.memory_client.short_term.get_session_messages(
                session_id=self.session_id,
                limit=100,
            )
        )

        result: List[BaseMessage] = []
        for msg in raw_messages:
            if msg.role == "user":
                result.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                result.append(AIMessage(content=msg.content))
        return result

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the session.

        ``HumanMessage`` instances are stored with role ``"user"``; every
        other message type is stored with role ``"assistant"``.
        """
        role = "user" if isinstance(message, HumanMessage) else "assistant"

        self._run_sync(
            self.memory_client.short_term.add_message(
                role=role,
                content=message.content,
                session_id=self.session_id,
                metadata={"user_id": self.user_id} if self.user_id else None,
            )
        )

    def clear(self) -> None:
        """Clear all messages in the session."""
        self._run_sync(
            self.memory_client.short_term.delete_session(
                session_id=self.session_id,
            )
        )


# Async version for async chains
class AsyncNeo4jChatMessageHistory(BaseChatMessageHistory):
    """Async LangChain message history backed by neo4j-agent-memory.

    The ``a``-prefixed coroutines talk to the ``MemoryClient`` short-term
    store directly. The synchronous ``messages`` / ``add_message`` / ``clear``
    members drive those coroutines on the calling thread's event loop and so
    must not be used while that loop is already running.

    Args:
        memory_client: Client whose ``short_term`` store holds the messages.
        session_id: Identifier of the conversation to read and write.
        user_id: Optional user identifier; when set, it is attached to every
            stored message as ``metadata={"user_id": ...}``.
    """

    def __init__(
        self,
        memory_client: MemoryClient,
        session_id: str,
        user_id: str | None = None,
    ):
        self.memory_client = memory_client
        self.session_id = session_id
        self.user_id = user_id
        # Not read anywhere in this class; kept so existing attribute access
        # keeps working.
        self._messages: List[BaseMessage] = []

    async def aget_messages(self) -> List[BaseMessage]:
        """Async retrieve the messages for this session (at most 100).

        Only ``"user"`` and ``"assistant"`` roles are converted; messages
        with any other role are skipped.
        """
        raw_messages = await self.memory_client.short_term.get_session_messages(
            session_id=self.session_id,
            limit=100,
        )

        result: List[BaseMessage] = []
        for msg in raw_messages:
            if msg.role == "user":
                result.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                result.append(AIMessage(content=msg.content))
        return result

    async def aadd_message(self, message: BaseMessage) -> None:
        """Async add a single message.

        ``HumanMessage`` instances are stored with role ``"user"``; every
        other message type is stored with role ``"assistant"``.
        """
        role = "user" if isinstance(message, HumanMessage) else "assistant"

        await self.memory_client.short_term.add_message(
            role=role,
            content=message.content,
            session_id=self.session_id,
            metadata={"user_id": self.user_id} if self.user_id else None,
        )

    async def aadd_messages(self, messages: List[BaseMessage]) -> None:
        """Async add several messages, preserving their order.

        LangChain's async code paths call this batch hook rather than
        ``aadd_message``; without the override the base class falls back to
        the synchronous ``add_message`` path.
        """
        for message in messages:
            await self.aadd_message(message)

    async def aclear(self) -> None:
        """Async clear messages."""
        await self.memory_client.short_term.delete_session(
            session_id=self.session_id,
        )

    @staticmethod
    def _run_sync(coro):
        """Run *coro* to completion on this thread's event loop.

        Raises:
            RuntimeError: If the thread's event loop is already running.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (worker threads, and Python
            # 3.14+ until one is set): create one and keep it current so
            # later calls reuse the same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if loop.is_running():
            # Close the coroutine so it does not trigger a
            # "coroutine was never awaited" warning.
            coro.close()
            raise RuntimeError(
                "Synchronous history access was called from a running event "
                "loop; await the async methods instead."
            )
        return loop.run_until_complete(coro)

    # Sync methods that wrap async
    @property
    def messages(self) -> List[BaseMessage]:
        """Synchronous wrapper around ``aget_messages``."""
        return self._run_sync(self.aget_messages())

    def add_message(self, message: BaseMessage) -> None:
        """Synchronous wrapper around ``aadd_message``."""
        self._run_sync(self.aadd_message(message))

    def clear(self) -> None:
        """Synchronous wrapper around ``aclear``."""
        self._run_sync(self.aclear())

Use with Conversation Chain

from langchain_openai import ChatOpenAI
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory

# Initialize the memory client with the Neo4j connection settings
# (replace the example credentials with your own).
memory_client = MemoryClient(
    neo4j_uri="bolt://localhost:7687",
    neo4j_user="neo4j",
    neo4j_password="password",
)

# Create the message history for one conversation; session_id selects which
# stored messages are read and written, user_id is attached as metadata.
message_history = Neo4jChatMessageHistory(
    memory_client=memory_client,
    session_id="langchain-session-001",
    user_id="CUST-12345",
)

# Create LangChain memory that uses our history as its chat_memory backend;
# return_messages=True hands the chain message objects instead of one string.
memory = ConversationBufferMemory(
    chat_memory=message_history,
    return_messages=True,
)

# Create the conversation chain (verbose=True prints the prompts it sends).
llm = ChatOpenAI(model="gpt-4o")
conversation = ConversationChain(
    llm=llm,
    memory=memory,
    verbose=True,
)

# Use the chain - messages are written through message_history
response = conversation.predict(input="I'm looking for running shoes")
print(response)

# Follow-up turn on the same chain. To resume in a later process, build a new
# Neo4jChatMessageHistory with the same session_id before creating the chain.
response = conversation.predict(input="What did I ask about earlier?")
print(response)

Custom Tools for Context Graph

Create Context Graph Tools

from langchain.tools import BaseTool, tool
from pydantic import BaseModel, Field
from typing import Optional


class SearchEntitiesInput(BaseModel):
    """Input for entity search."""
    # Argument schema for the entity-search tool (wired in via args_schema).
    # The Field descriptions are runtime strings that end up in the tool schema.

    # Required free-text query.
    query: str = Field(description="Search query for entities")
    # Optional entity-type filter; None means no type filter is passed.
    entity_type: Optional[str] = Field(
        default=None,
        description="Entity type filter (e.g., PRODUCT, PERSON, ORGANIZATION)"
    )
    # Upper bound on the number of results; defaults to 10.
    limit: int = Field(default=10, description="Maximum results to return")


class SearchEntitiesTool(BaseTool):
    """Tool for searching the context graph.

    Queries ``memory_client.long_term.search_entities`` and returns the hits
    as a JSON array of ``{"name", "type", "description"}`` objects.
    """

    name: str = "search_entities"
    description: str = """
    Search the knowledge graph for entities like products, people, or organizations.
    Use this to find information about things mentioned in past conversations.
    """
    args_schema: type[BaseModel] = SearchEntitiesInput

    memory_client: MemoryClient

    def _run(
        self,
        query: str,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> str:
        """Synchronous entry point: run ``_arun`` to completion.

        Must not be called while this thread's event loop is already running.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (worker threads, and Python
            # 3.14+ until one is set): create one and keep it current so
            # later calls reuse the same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop.run_until_complete(
            self._arun(query=query, entity_type=entity_type, limit=limit)
        )

    async def _arun(
        self,
        query: str,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> str:
        """Search entities and return them serialized as a JSON string."""
        # Imported locally: this snippet's import list does not include json.
        import json

        entities = await self.memory_client.long_term.search_entities(
            query=query,
            entity_type=entity_type,
            limit=limit,
        )

        return json.dumps([
            {
                "name": e.name,
                "type": e.type,
                "description": e.description,
            }
            for e in entities
        ])


# Backward-compatible alias for the original (misspelled) class name.
SearchEntitesTool = SearchEntitiesTool


class GetPreferencesInput(BaseModel):
    """Input for preference retrieval."""
    # Argument schema for the preference tool (wired in via args_schema).
    # The Field descriptions are runtime strings that end up in the tool schema.

    # Required: whose preferences to fetch.
    user_id: str = Field(description="User ID to get preferences for")
    # Optional category filter; None means no category filter is passed.
    category: Optional[str] = Field(
        default=None,
        description="Category filter (e.g., brand, style, shipping)"
    )


class GetPreferencesTool(BaseTool):
    """Tool for retrieving user preferences.

    Queries ``memory_client.long_term.get_preferences`` and returns the
    result as a JSON array of ``{"category", "preference"}`` objects.
    """

    name: str = "get_user_preferences"
    description: str = """
    Get stored preferences for a user. Use this to personalize recommendations
    based on what the user has previously stated they like or dislike.
    """
    args_schema: type[BaseModel] = GetPreferencesInput

    memory_client: MemoryClient

    def _run(self, user_id: str, category: Optional[str] = None) -> str:
        """Synchronous entry point: run ``_arun`` to completion.

        Must not be called while this thread's event loop is already running.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (worker threads, and Python
            # 3.14+ until one is set): create one and keep it current so
            # later calls reuse the same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop.run_until_complete(
            self._arun(user_id=user_id, category=category)
        )

    async def _arun(self, user_id: str, category: Optional[str] = None) -> str:
        """Fetch preferences and return them serialized as a JSON string."""
        # Imported locally: this snippet's import list does not include json.
        import json

        preferences = await self.memory_client.long_term.get_preferences(
            user_id=user_id,
            category=category,
        )

        return json.dumps([
            {"category": p.category, "preference": p.preference}
            for p in preferences
        ])


# Function-based tools using @tool decorator
@tool
def search_conversation_history(
    query: str,
    session_id: str,
    limit: int = 5,
) -> str:
    """Search past conversation messages for relevant context."""
    # The docstring above is kept short on purpose: @tool uses it as the
    # tool description.
    #
    # Relies on a module-level `memory_client` being defined before the tool
    # is invoked. Returns a JSON array of {"role", "content"} objects with
    # each content truncated to 200 characters.
    import asyncio
    import json  # not part of this snippet's import list

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current loop in this thread (worker threads, and Python 3.14+
        # until one is set): create one and keep it current for reuse.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    messages = loop.run_until_complete(
        memory_client.short_term.search_messages(
            query=query,
            session_id=session_id,
            limit=limit,
        )
    )

    return json.dumps([
        {"role": m.role, "content": m.content[:200]}
        for m in messages
    ])

LangChain Agent with Memory

Create Memory-Enabled Agent

from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferMemory


def create_memory_agent(
    memory_client: MemoryClient,
    user_id: str,
    session_id: str,
) -> AgentExecutor:
    """Create a LangChain agent with neo4j-agent-memory.

    Args:
        memory_client: Client shared by the tools and the message history.
        user_id: User the conversation belongs to; stored with each message.
        session_id: Conversation identifier used by the message history.

    Returns:
        An ``AgentExecutor`` whose prompt expects ``input`` and ``user_id``
        and whose chat history is persisted through ``memory_client``.
    """

    # Create tools
    tools = [
        SearchEntitesTool(memory_client=memory_client),
        GetPreferencesTool(memory_client=memory_client),
    ]

    # Create message history
    message_history = Neo4jChatMessageHistory(
        memory_client=memory_client,
        session_id=session_id,
        user_id=user_id,
    )

    # Create memory. The executor is invoked with two inputs ("input" and
    # "user_id"), so the keys to record must be named explicitly; otherwise
    # the memory cannot infer which input is the user's message.
    memory = ConversationBufferMemory(
        chat_memory=message_history,
        memory_key="chat_history",
        input_key="input",
        output_key="output",
        return_messages=True,
    )

    # Create prompt with memory placeholder
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a helpful assistant with access to a knowledge graph.

Use the available tools to:
1. Search for entities (products, people, organizations) in the knowledge graph
2. Retrieve user preferences to personalize responses

Current user ID: {user_id}

Always use the user's preferences when making recommendations."""),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])

    # Create LLM
    llm = ChatOpenAI(model="gpt-4o", temperature=0)

    # Create agent
    agent = create_openai_functions_agent(llm, tools, prompt)

    # Create executor
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        handle_parsing_errors=True,
    )


# Usage
# Connect to Neo4j (replace the example credentials with your own).
memory_client = MemoryClient(
    neo4j_uri="bolt://localhost:7687",
    neo4j_user="neo4j",
    neo4j_password="password",
)

# Build the agent for one user and one conversation session.
agent = create_memory_agent(
    memory_client=memory_client,
    user_id="CUST-12345",
    session_id="agent-session-001",
)

# Run the agent. "user_id" fills the {user_id} slot of the system prompt,
# so it must be supplied on every call alongside "input".
result = agent.invoke({
    "input": "Find me some running shoes based on my preferences",
    "user_id": "CUST-12345",
})

# The agent's final answer is under the "output" key.
print(result["output"])

Ecommerce Example

Complete example for an ecommerce shopping assistant:

import json
from datetime import datetime
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools import BaseTool
from langchain.memory import ConversationBufferMemory
from pydantic import BaseModel, Field
from neo4j_agent_memory import MemoryClient


# --- Tools ---

class ProductSearchInput(BaseModel):
    # Argument schema for the product-search tool (wired in via args_schema).
    # Only `query` is required; every filter defaults to None.
    query: str = Field(description="Product search query")
    category: str | None = Field(default=None, description="Product category")
    brand: str | None = Field(default=None, description="Brand filter")
    max_price: float | None = Field(default=None, description="Maximum price")


class ProductSearchTool(BaseTool):
    """Search PRODUCT entities in the context graph with optional filters.

    Returns a JSON array of at most 10 products, each with ``name``,
    ``brand``, ``price``, ``rating`` and a description truncated to 100
    characters (``None`` when the product has no description).
    """

    name: str = "search_products"
    description: str = """
    Search for products in the catalog. Can filter by category, brand, and price.
    Use customer preferences to choose appropriate filters.
    """
    args_schema: type[BaseModel] = ProductSearchInput
    memory_client: MemoryClient

    def _run(
        self,
        query: str,
        category: str | None = None,
        brand: str | None = None,
        max_price: float | None = None,
    ) -> str:
        """Run the search synchronously and return the results as JSON.

        Must not be called while this thread's event loop is already running.
        """
        import asyncio

        filters = {}
        if category:
            # The tool advertises category filtering, so apply it.
            # NOTE(review): assumes products carry a "category" property,
            # mirroring "brand" - confirm against the graph schema.
            filters["category"] = category
        if brand:
            filters["brand"] = brand
        if max_price is not None:
            # Compare against None so an explicit limit of 0 is honored.
            filters["price"] = {"$lte": max_price}

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (worker threads, and Python
            # 3.14+ until one is set): create one and keep it current so
            # later calls reuse the same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        products = loop.run_until_complete(
            self.memory_client.long_term.search_entities(
                query=query,
                entity_type="PRODUCT",
                property_filter=filters or None,
                limit=10,
            )
        )

        return json.dumps([
            {
                "name": p.name,
                "brand": p.properties.get("brand"),
                "price": p.properties.get("price"),
                "rating": p.properties.get("rating"),
                "description": p.description[:100] if p.description else None,
            }
            for p in products
        ])


class CustomerPreferencesTool(BaseTool):
    """Return the bound customer's stored preferences grouped by category.

    The customer is fixed at construction time through ``user_id``; the tool
    itself takes no arguments. The result is a JSON object mapping each
    category to the list of preferences recorded under it.
    """

    name: str = "get_customer_preferences"
    description: str = """
    Get stored preferences for the current customer.
    Returns brand preferences, size, style, and shipping preferences.
    """
    memory_client: MemoryClient
    user_id: str

    def _run(self) -> str:
        """Fetch the preferences synchronously and return them as JSON.

        Must not be called while this thread's event loop is already running.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (worker threads, and Python
            # 3.14+ until one is set): create one and keep it current so
            # later calls reuse the same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        prefs = loop.run_until_complete(
            self.memory_client.long_term.get_preferences(
                user_id=self.user_id,
            )
        )

        # Group preferences by category, keeping first-seen category order.
        by_category = {}
        for p in prefs:
            by_category.setdefault(p.category, []).append(p.preference)

        return json.dumps(by_category)


class SavePreferenceInput(BaseModel):
    # Argument schema for the save-preference tool (wired in via args_schema).
    # Both fields are required.
    preference: str = Field(description="The preference to save")
    category: str = Field(description="Category: brand, style, size, shipping, etc.")


class SavePreferenceTool(BaseTool):
    """Persist a preference for the bound customer.

    The customer is fixed at construction time through ``user_id``. Each
    preference is stored with a fixed confidence of 0.85.
    """

    name: str = "save_preference"
    description: str = """
    Save a new customer preference learned during the conversation.
    Use when customer expresses a like, dislike, or preference.
    """
    args_schema: type[BaseModel] = SavePreferenceInput
    memory_client: MemoryClient
    user_id: str

    def _run(self, preference: str, category: str) -> str:
        """Store the preference synchronously and return a confirmation.

        Must not be called while this thread's event loop is already running.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (worker threads, and Python
            # 3.14+ until one is set): create one and keep it current so
            # later calls reuse the same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        loop.run_until_complete(
            self.memory_client.long_term.add_preference(
                user_id=self.user_id,
                preference=preference,
                category=category,
                confidence=0.85,
            )
        )

        return f"Saved preference: {category} - {preference}"


class PurchaseHistoryTool(BaseTool):
    """Return the bound customer's 10 most recent purchases as JSON.

    Runs a Cypher query through ``memory_client.long_term.execute_query``
    that follows ``PURCHASED`` relationships from the ``Customer`` node whose
    ``id`` equals ``user_id``. The customer is fixed at construction time.
    """

    name: str = "get_purchase_history"
    description: str = """
    Get customer's recent purchase history from the context graph.
    Useful for understanding what they've bought before.
    """
    memory_client: MemoryClient
    user_id: str

    def _run(self) -> str:
        """Run the purchase query synchronously and return the rows as JSON.

        Must not be called while this thread's event loop is already running.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (worker threads, and Python
            # 3.14+ until one is set): create one and keep it current so
            # later calls reuse the same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        purchases = loop.run_until_complete(
            self.memory_client.long_term.execute_query(
                """
                MATCH (c:Customer {id: $user_id})-[p:PURCHASED]->(product:Product)
                RETURN product.name as name, product.brand as brand,
                       p.purchase_date as date, p.price as price
                ORDER BY p.purchase_date DESC
                LIMIT 10
                """,
                parameters={"user_id": self.user_id},
            )
        )

        # Deliberate fallback: values json cannot encode natively (for
        # example a database date type in the `date` column) are rendered
        # with str() instead of raising TypeError.
        # NOTE(review): assumes str() of such values is acceptable output.
        return json.dumps(purchases, default=str)


# --- Agent Factory ---

def create_shopping_agent(
    memory_client: MemoryClient,
    user_id: str,
    session_id: str,
) -> AgentExecutor:
    """Create shopping assistant agent with context graph memory.

    Args:
        memory_client: Client shared by all tools and the message history.
        user_id: Customer the tools and the stored messages are bound to.
        session_id: Conversation identifier used by the message history.

    Returns:
        An ``AgentExecutor`` whose prompt expects ``input`` and ``user_id``
        and whose chat history is persisted through ``memory_client``.
    """

    # Create tools; the customer-specific ones are bound to user_id here so
    # the model never has to supply it.
    tools = [
        ProductSearchTool(memory_client=memory_client),
        CustomerPreferencesTool(memory_client=memory_client, user_id=user_id),
        SavePreferenceTool(memory_client=memory_client, user_id=user_id),
        PurchaseHistoryTool(memory_client=memory_client, user_id=user_id),
    ]

    # Message history
    message_history = Neo4jChatMessageHistory(
        memory_client=memory_client,
        session_id=session_id,
        user_id=user_id,
    )

    # The executor is invoked with two inputs ("input" and "user_id"), so the
    # keys to record must be named explicitly; otherwise the memory cannot
    # infer which input is the user's message.
    memory = ConversationBufferMemory(
        chat_memory=message_history,
        memory_key="chat_history",
        input_key="input",
        output_key="output",
        return_messages=True,
    )

    # Prompt
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a helpful shopping assistant for an online retailer.

Your capabilities:
1. Search for products using customer preferences
2. Access and use customer preferences for personalization
3. Save new preferences when customers express them
4. Review purchase history for context

Guidelines:
- Always check customer preferences before making recommendations
- When a customer expresses a preference, save it for future use
- Reference past purchases when relevant
- Be helpful and personalized in your recommendations

Current customer: {user_id}"""),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])

    llm = ChatOpenAI(model="gpt-4o", temperature=0.7)
    agent = create_openai_functions_agent(llm, tools, prompt)

    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
    )


# --- Main ---

def main():
    """Run an interactive shopping-assistant session on the console.

    This is a plain (non-async) function on purpose: the agent's tools and
    message history drive their coroutines with ``run_until_complete``, which
    raises ``RuntimeError`` when called from inside a running event loop.
    Wrapping this function in ``asyncio.run`` would therefore break the first
    ``agent.invoke`` call.
    """
    memory_client = MemoryClient(
        neo4j_uri="bolt://localhost:7687",
        neo4j_user="neo4j",
        neo4j_password="password",
    )

    user_id = "CUST-12345"
    # One session per program start, named after the local start time.
    session_id = f"shop-{datetime.now().strftime('%Y%m%d%H%M%S')}"

    agent = create_shopping_agent(memory_client, user_id, session_id)

    # Conversation loop
    print("Shopping Assistant (type 'quit' to exit)\n")

    while True:
        user_input = input("You: ").strip()
        if user_input.lower() == "quit":
            break

        result = agent.invoke({
            "input": user_input,
            "user_id": user_id,
        })

        print(f"\nAssistant: {result['output']}\n")


if __name__ == "__main__":
    main()

Callbacks for Reasoning Traces

Track LangChain execution in reasoning traces:

from langchain.callbacks.base import BaseCallbackHandler
from neo4j_agent_memory import MemoryClient


class ReasoningTraceCallback(BaseCallbackHandler):
    """Record LangChain execution as reasoning traces.

    One trace is opened per top-level chain run and completed when that same
    run ends or fails. Callbacks from nested chains are ignored, so they can
    neither restart the trace nor complete it early. Each tool invocation
    becomes a step, and its result a tool call on that step.

    The handler keeps the state of a single run at a time; use one instance
    per concurrent run.

    Args:
        memory_client: Client whose ``reasoning`` store receives the trace.
        user_id: User the trace is recorded for.
    """

    def __init__(self, memory_client: MemoryClient, user_id: str):
        self.memory_client = memory_client
        self.user_id = user_id
        self.trace_id = None
        self.current_step_id = None
        # run_id of the top-level chain run that owns the current trace.
        self._root_run_id = None
        # Tool name from the latest on_tool_start, used as a fallback label.
        self._current_tool_name = None

    @staticmethod
    def _run_sync(coro):
        """Run *coro* to completion on this thread's event loop.

        Raises:
            RuntimeError: If the thread's event loop is already running.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (worker threads, and Python
            # 3.14+ until one is set): create one and keep it current so
            # later calls reuse the same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if loop.is_running():
            # Close the coroutine so it does not trigger a
            # "coroutine was never awaited" warning.
            coro.close()
            raise RuntimeError(
                "ReasoningTraceCallback is synchronous and cannot record "
                "from inside a running event loop."
            )
        return loop.run_until_complete(coro)

    def _is_root_run(self, kwargs) -> bool:
        """Return True when the callback belongs to the run that owns the trace."""
        return kwargs.get("run_id") == self._root_run_id

    def on_chain_start(self, serialized, inputs, **kwargs):
        """Start a trace when the top-level chain begins."""
        if kwargs.get("parent_run_id") is not None:
            # Nested chain: it belongs to the trace that is already open.
            return

        # Inputs are normally a dict, but some chains pass other types.
        task = inputs.get("input", "") if isinstance(inputs, dict) else inputs

        trace = self._run_sync(
            self.memory_client.reasoning.start_trace(
                task=str(task)[:200],
                user_id=self.user_id,
                # `serialized` may be None for some chain types.
                metadata={"chain_type": (serialized or {}).get("name", "unknown")},
            )
        )
        self.trace_id = trace.id
        self._root_run_id = kwargs.get("run_id")
        self.current_step_id = None
        self._current_tool_name = None

    def on_tool_start(self, serialized, input_str, **kwargs):
        """Record tool invocation as a new step of the current trace."""
        if not self.trace_id:
            return

        tool_name = (serialized or {}).get("name", "unknown")
        self._current_tool_name = tool_name

        step = self._run_sync(
            self.memory_client.reasoning.add_step(
                trace_id=self.trace_id,
                description=f"Using tool: {tool_name}",
                reasoning=f"Input: {str(input_str)[:100]}",
            )
        )
        self.current_step_id = step.id

    def on_tool_end(self, output, **kwargs):
        """Record tool result on the step opened by ``on_tool_start``."""
        if not self.current_step_id:
            return

        # Prefer the name passed with the callback; fall back to the one
        # remembered from on_tool_start.
        tool_name = kwargs.get("name") or self._current_tool_name or "unknown"

        self._run_sync(
            self.memory_client.reasoning.add_tool_call(
                step_id=self.current_step_id,
                tool_name=tool_name,
                arguments={},
                result={"output": str(output)[:500]},
                success=True,
            )
        )

    def on_chain_end(self, outputs, **kwargs):
        """Complete the trace when the top-level chain finishes."""
        if not self.trace_id or not self._is_root_run(kwargs):
            return

        self._run_sync(
            self.memory_client.reasoning.complete_trace(
                trace_id=self.trace_id,
                outcome="success",
                result={"output": str(outputs)[:500]},
            )
        )

    def on_chain_error(self, error, **kwargs):
        """Record failure of the top-level chain."""
        if not self.trace_id or not self._is_root_run(kwargs):
            return

        self._run_sync(
            self.memory_client.reasoning.complete_trace(
                trace_id=self.trace_id,
                outcome="failure",
                error=str(error),
            )
        )


# Usage
# `memory_client`, `user_id` and `agent` must already be defined, e.g. as in
# the agent examples earlier on this page.
callback = ReasoningTraceCallback(memory_client, user_id)

# Attach the callback for this single invocation through the run config.
result = agent.invoke(
    {"input": "Find running shoes", "user_id": user_id},
    config={"callbacks": [callback]},
)

Best Practices

1. Use Session-Based History

Create unique sessions for conversation isolation:

# Good: unique session per conversation (user ID plus a timestamp with
# one-second resolution)
session_id = f"user-{user_id}-{datetime.now().strftime('%Y%m%d%H%M%S')}"

# Avoid: one shared session for all users; the history is looked up by
# session_id alone, so everyone's messages would share a single history
session_id = "global"  # Don't do this

2. Handle Async Properly

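Use the async versions when running in an async context; the sync classes call run_until_complete, which fails while an event loop is already running: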

# In async context
# In async context (inside a coroutine): await the a-prefixed methods
history = AsyncNeo4jChatMessageHistory(memory_client, session_id)
messages = await history.aget_messages()

# In sync context (no event loop running): use the plain property
history = Neo4jChatMessageHistory(memory_client, session_id)
messages = history.messages

3. Limit Memory Size

Prevent context from growing too large:

from langchain.memory import ConversationBufferWindowMemory

# Keep only the most recent part of the conversation in the prompt context;
# `message_history` is the Neo4j-backed history created earlier.
memory = ConversationBufferWindowMemory(
    chat_memory=message_history,
    k=10,  # Window size
    return_messages=True,
)

4. Add Error Handling

Gracefully handle memory failures:

class RobustNeo4jChatMessageHistory(Neo4jChatMessageHistory):
    """Message history whose writes are best-effort.

    Extends ``Neo4jChatMessageHistory`` so reading and clearing behave as
    usual, while a failure to store a message is logged instead of being
    propagated to the chain.
    """

    def add_message(self, message: BaseMessage) -> None:
        """Try to store *message*; log a warning if the write fails."""
        import logging

        try:
            # Attempt to save
            super().add_message(message)
        except Exception as e:
            # Deliberately broad: any storage error is logged, but must not
            # fail the chain.
            logging.getLogger(__name__).warning(
                "Failed to save message to memory: %s", e
            )