Use with LangChain
How to integrate neo4j-agent-memory with LangChain to build memory-enabled chains and agents with a persistent context graph.
Overview
LangChain is a popular framework for building LLM applications. Integrating with neo4j-agent-memory provides persistent memory that survives across sessions, enabling agents to build and query a context graph over time.
| LangChain + Context Graph Architecture |
|---|
| LangChain chains and agents persist messages, entities, and preferences to a Neo4j context graph via neo4j-agent-memory, so context survives across sessions. |
Prerequisites
- Python 3.10+
- `neo4j-agent-memory` and `langchain` installed
- Neo4j database running
- OpenAI API key (or other LLM provider)
pip install neo4j-agent-memory langchain langchain-openai
Message History Integration
Create Custom Message History
Implement LangChain’s BaseChatMessageHistory interface:
from typing import List
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from neo4j_agent_memory import MemoryClient
class Neo4jChatMessageHistory(BaseChatMessageHistory):
    """LangChain message history backed by neo4j-agent-memory.

    Every message is persisted to the Neo4j short-term store, so a
    conversation survives process restarts and can be resumed later by
    ``session_id``.

    Args:
        memory_client: Connected neo4j-agent-memory client.
        session_id: Unique identifier for this conversation.
        user_id: Optional user identifier stored in message metadata.
    """

    def __init__(
        self,
        memory_client: MemoryClient,
        session_id: str,
        user_id: str | None = None,
    ):
        self.memory_client = memory_client
        self.session_id = session_id
        self.user_id = user_id

    @staticmethod
    def _run_sync(coro):
        """Run *coro* to completion from synchronous code.

        ``asyncio.get_event_loop()`` is deprecated outside a running loop
        (Python 3.10+) and ``run_until_complete()`` raises if a loop is
        already running, so use ``asyncio.run()`` and fail loudly when
        called from inside an event loop.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        raise RuntimeError(
            "Neo4jChatMessageHistory cannot block inside a running event "
            "loop; use AsyncNeo4jChatMessageHistory in async contexts."
        )

    @property
    def messages(self) -> List[BaseMessage]:
        """Retrieve all messages for this session (up to 100)."""
        raw_messages = self._run_sync(
            self.memory_client.short_term.get_session_messages(
                session_id=self.session_id,
                limit=100,
            )
        )
        result: List[BaseMessage] = []
        for msg in raw_messages:
            # Roles other than user/assistant (e.g. system) are skipped,
            # matching the original behavior.
            if msg.role == "user":
                result.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                result.append(AIMessage(content=msg.content))
        return result

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the session."""
        # Any non-human message type is stored as "assistant".
        role = "user" if isinstance(message, HumanMessage) else "assistant"
        self._run_sync(
            self.memory_client.short_term.add_message(
                role=role,
                content=message.content,
                session_id=self.session_id,
                metadata={"user_id": self.user_id} if self.user_id else None,
            )
        )

    def clear(self) -> None:
        """Delete all messages in the session."""
        self._run_sync(
            self.memory_client.short_term.delete_session(
                session_id=self.session_id,
            )
        )
# Async version for async chains
class AsyncNeo4jChatMessageHistory(BaseChatMessageHistory):
    """Async LangChain message history backed by neo4j-agent-memory.

    Prefer this class inside async applications. The sync accessors are
    thin wrappers that only work when no event loop is already running.
    """

    def __init__(
        self,
        memory_client: MemoryClient,
        session_id: str,
        user_id: str | None = None,
    ):
        self.memory_client = memory_client
        self.session_id = session_id
        self.user_id = user_id

    async def aget_messages(self) -> List[BaseMessage]:
        """Async retrieve all messages for this session (up to 100)."""
        raw_messages = await self.memory_client.short_term.get_session_messages(
            session_id=self.session_id,
            limit=100,
        )
        converted: List[BaseMessage] = []
        for msg in raw_messages:
            # Roles other than user/assistant are skipped.
            if msg.role == "user":
                converted.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                converted.append(AIMessage(content=msg.content))
        return converted

    async def aadd_message(self, message: BaseMessage) -> None:
        """Async add a message; non-human messages are stored as assistant."""
        role = "user" if isinstance(message, HumanMessage) else "assistant"
        await self.memory_client.short_term.add_message(
            role=role,
            content=message.content,
            session_id=self.session_id,
            metadata={"user_id": self.user_id} if self.user_id else None,
        )

    async def aclear(self) -> None:
        """Async delete all messages in the session."""
        await self.memory_client.short_term.delete_session(
            session_id=self.session_id,
        )

    # --- Sync wrappers -------------------------------------------------
    # asyncio.run() replaces the deprecated get_event_loop()/
    # run_until_complete() pair; it raises a clear error if called while
    # an event loop is already running (use the async methods there).

    @property
    def messages(self) -> List[BaseMessage]:
        import asyncio

        return asyncio.run(self.aget_messages())

    def add_message(self, message: BaseMessage) -> None:
        import asyncio

        asyncio.run(self.aadd_message(message))

    def clear(self) -> None:
        import asyncio

        asyncio.run(self.aclear())
Use with Conversation Chain
from langchain_openai import ChatOpenAI
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory

# Initialize memory client (one client can be shared by many sessions)
memory_client = MemoryClient(
    neo4j_uri="bolt://localhost:7687",
    neo4j_user="neo4j",
    neo4j_password="password",
)

# Create message history bound to a single conversation session
message_history = Neo4jChatMessageHistory(
    memory_client=memory_client,
    session_id="langchain-session-001",
    user_id="CUST-12345",
)

# Create LangChain memory with our history
memory = ConversationBufferMemory(
    chat_memory=message_history,
    return_messages=True,  # return BaseMessage objects, not a flat string
)

# Create conversation chain
llm = ChatOpenAI(model="gpt-4o")
conversation = ConversationChain(
    llm=llm,
    memory=memory,
    verbose=True,
)

# Use the chain - messages persist to Neo4j
response = conversation.predict(input="I'm looking for running shoes")
print(response)

# Later session can retrieve history
response = conversation.predict(input="What did I ask about earlier?")
print(response)
Custom Tools for Context Graph
Create Context Graph Tools
from langchain.tools import BaseTool, tool
from pydantic import BaseModel, Field
from typing import Optional
class SearchEntitiesInput(BaseModel):
    """Input schema for the entity-search tool."""

    # Free-text query matched against entities in the context graph.
    query: str = Field(description="Search query for entities")
    entity_type: Optional[str] = Field(
        default=None,
        description="Entity type filter (e.g., PRODUCT, PERSON, ORGANIZATION)"
    )
    limit: int = Field(default=10, description="Maximum results to return")
class SearchEntitesTool(BaseTool):
    """Tool for searching entities in the context graph."""

    name: str = "search_entities"
    description: str = """
    Search the knowledge graph for entities like products, people, or organizations.
    Use this to find information about things mentioned in past conversations.
    """
    args_schema: type[BaseModel] = SearchEntitiesInput
    memory_client: MemoryClient

    @staticmethod
    def _serialize(entities) -> str:
        """Render search hits as a JSON string for the LLM."""
        # Local import: json was never imported in the original snippet,
        # which raised NameError at call time.
        import json

        return json.dumps([
            {
                "name": e.name,
                "type": e.type,
                "description": e.description,
            }
            for e in entities
        ])

    def _run(
        self,
        query: str,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> str:
        """Sync entry point; blocks on the async search."""
        import asyncio

        # asyncio.run() replaces the deprecated get_event_loop()/
        # run_until_complete() pair; only valid outside a running loop.
        entities = asyncio.run(
            self.memory_client.long_term.search_entities(
                query=query,
                entity_type=entity_type,
                limit=limit,
            )
        )
        return self._serialize(entities)

    async def _arun(
        self,
        query: str,
        entity_type: Optional[str] = None,
        limit: int = 10,
    ) -> str:
        """Async entry point used by async agents/chains."""
        entities = await self.memory_client.long_term.search_entities(
            query=query,
            entity_type=entity_type,
            limit=limit,
        )
        return self._serialize(entities)


# Correctly-spelled alias; the original class name contains a typo
# ("Entites") and is kept for backward compatibility.
SearchEntitiesTool = SearchEntitesTool
class GetPreferencesInput(BaseModel):
    """Input schema for the preference-retrieval tool."""

    user_id: str = Field(description="User ID to get preferences for")
    category: Optional[str] = Field(
        default=None,
        description="Category filter (e.g., brand, style, shipping)"
    )
class GetPreferencesTool(BaseTool):
    """Tool for retrieving stored user preferences."""

    name: str = "get_user_preferences"
    description: str = """
    Get stored preferences for a user. Use this to personalize recommendations
    based on what the user has previously stated they like or dislike.
    """
    args_schema: type[BaseModel] = GetPreferencesInput
    memory_client: MemoryClient

    @staticmethod
    def _serialize(preferences) -> str:
        """Render preferences as a JSON string for the LLM."""
        # Local import: json was never imported in the original snippet,
        # which raised NameError at call time.
        import json

        return json.dumps([
            {"category": p.category, "preference": p.preference}
            for p in preferences
        ])

    def _run(self, user_id: str, category: Optional[str] = None) -> str:
        """Sync entry point; blocks on the async lookup."""
        import asyncio

        # asyncio.run() replaces the deprecated get_event_loop()/
        # run_until_complete() pair.
        preferences = asyncio.run(
            self.memory_client.long_term.get_preferences(
                user_id=user_id,
                category=category,
            )
        )
        return self._serialize(preferences)

    async def _arun(self, user_id: str, category: Optional[str] = None) -> str:
        """Async entry point used by async agents/chains."""
        preferences = await self.memory_client.long_term.get_preferences(
            user_id=user_id,
            category=category,
        )
        return self._serialize(preferences)
# Function-based tools using @tool decorator
@tool
def search_conversation_history(
    query: str,
    session_id: str,
    limit: int = 5,
) -> str:
    """Search past conversation messages for relevant context.

    Args:
        query: Free-text search over stored messages.
        session_id: Conversation to search within.
        limit: Maximum number of messages to return.

    Returns:
        JSON list of ``{"role", "content"}`` objects; content is
        truncated to 200 characters to keep the prompt small.
    """
    import asyncio
    import json  # json was never imported in the original snippet

    # NOTE(review): relies on a module-level `memory_client`; confirm it
    # is initialized before this tool is invoked.
    # asyncio.run() replaces deprecated get_event_loop().run_until_complete().
    messages = asyncio.run(
        memory_client.short_term.search_messages(
            query=query,
            session_id=session_id,
            limit=limit,
        )
    )
    return json.dumps([
        {"role": m.role, "content": m.content[:200]}
        for m in messages
    ])
LangChain Agent with Memory
Create Memory-Enabled Agent
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferMemory
def create_memory_agent(
    memory_client: MemoryClient,
    user_id: str,
    session_id: str,
) -> AgentExecutor:
    """Create a LangChain agent with neo4j-agent-memory.

    Wires context-graph tools, a Neo4j-backed message history, and a
    prompt that expects ``input`` and ``user_id`` variables at invoke
    time.
    """
    # Create tools backed by the context graph
    tools = [
        SearchEntitesTool(memory_client=memory_client),
        GetPreferencesTool(memory_client=memory_client),
    ]
    # Create message history persisted to Neo4j for this session
    message_history = Neo4jChatMessageHistory(
        memory_client=memory_client,
        session_id=session_id,
        user_id=user_id,
    )
    # Create memory; memory_key must match the MessagesPlaceholder below
    memory = ConversationBufferMemory(
        chat_memory=message_history,
        memory_key="chat_history",
        return_messages=True,
    )
    # Create prompt with memory placeholder
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a helpful assistant with access to a knowledge graph.
Use the available tools to:
1. Search for entities (products, people, organizations) in the knowledge graph
2. Retrieve user preferences to personalize responses
Current user ID: {user_id}
Always use the user's preferences when making recommendations."""),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    # Create LLM (temperature 0 for more deterministic tool selection)
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    # Create agent
    agent = create_openai_functions_agent(llm, tools, prompt)
    # Create executor
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        handle_parsing_errors=True,
    )
# Usage
memory_client = MemoryClient(
    neo4j_uri="bolt://localhost:7687",
    neo4j_user="neo4j",
    neo4j_password="password",
)
agent = create_memory_agent(
    memory_client=memory_client,
    user_id="CUST-12345",
    session_id="agent-session-001",
)
# Run the agent; user_id is required because the system prompt interpolates it
result = agent.invoke({
    "input": "Find me some running shoes based on my preferences",
    "user_id": "CUST-12345",
})
print(result["output"])
Ecommerce Example
Complete example for an ecommerce shopping assistant:
import json
from datetime import datetime
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools import BaseTool
from langchain.memory import ConversationBufferMemory
from pydantic import BaseModel, Field
from neo4j_agent_memory import MemoryClient
# --- Tools ---
class ProductSearchInput(BaseModel):
    """Input schema for the product-search tool."""

    query: str = Field(description="Product search query")
    category: str | None = Field(default=None, description="Product category")
    brand: str | None = Field(default=None, description="Brand filter")
    max_price: float | None = Field(default=None, description="Maximum price")
class ProductSearchTool(BaseTool):
    """Search the product catalog stored in the context graph."""

    name: str = "search_products"
    description: str = """
    Search for products in the catalog. Can filter by category, brand, and price.
    Use customer preferences to choose appropriate filters.
    """
    args_schema: type[BaseModel] = ProductSearchInput
    memory_client: MemoryClient

    def _run(
        self,
        query: str,
        category: str | None = None,
        brand: str | None = None,
        max_price: float | None = None,
    ) -> str:
        """Sync entry point; blocks on the async entity search."""
        import asyncio

        filters = {}
        # Bug fix: category was accepted but silently ignored before.
        if category:
            filters["category"] = category
        if brand:
            filters["brand"] = brand
        # `is not None` so an explicit 0.0 price cap is not dropped.
        if max_price is not None:
            filters["price"] = {"$lte": max_price}
        # asyncio.run() replaces deprecated get_event_loop().run_until_complete().
        products = asyncio.run(
            self.memory_client.long_term.search_entities(
                query=query,
                entity_type="PRODUCT",
                property_filter=filters if filters else None,
                limit=10,
            )
        )
        return json.dumps([
            {
                "name": p.name,
                "brand": p.properties.get("brand"),
                "price": p.properties.get("price"),
                "rating": p.properties.get("rating"),
                # Truncate long descriptions for prompt economy.
                "description": p.description[:100] if p.description else None,
            }
            for p in products
        ])
class CustomerPreferencesTool(BaseTool):
    """Fetch all stored preferences for the bound customer."""

    name: str = "get_customer_preferences"
    description: str = """
    Get stored preferences for the current customer.
    Returns brand preferences, size, style, and shipping preferences.
    """
    memory_client: MemoryClient
    user_id: str

    def _run(self) -> str:
        """Return the customer's preferences grouped by category (JSON)."""
        import asyncio

        # asyncio.run() replaces deprecated get_event_loop().run_until_complete().
        prefs = asyncio.run(
            self.memory_client.long_term.get_preferences(
                user_id=self.user_id,
            )
        )
        # Group preference strings under their category.
        by_category: dict[str, list[str]] = {}
        for p in prefs:
            by_category.setdefault(p.category, []).append(p.preference)
        return json.dumps(by_category)
class SavePreferenceInput(BaseModel):
    """Input schema for saving a newly learned preference."""

    preference: str = Field(description="The preference to save")
    category: str = Field(description="Category: brand, style, size, shipping, etc.")
class SavePreferenceTool(BaseTool):
    """Persist a preference the customer expressed mid-conversation."""

    name: str = "save_preference"
    description: str = """
    Save a new customer preference learned during the conversation.
    Use when customer expresses a like, dislike, or preference.
    """
    args_schema: type[BaseModel] = SavePreferenceInput
    memory_client: MemoryClient
    user_id: str

    def _run(self, preference: str, category: str) -> str:
        """Store the preference and return a confirmation for the LLM."""
        import asyncio

        # asyncio.run() replaces deprecated get_event_loop().run_until_complete().
        asyncio.run(
            self.memory_client.long_term.add_preference(
                user_id=self.user_id,
                preference=preference,
                category=category,
                # Fixed confidence for LLM-inferred preferences; tune as needed.
                confidence=0.85,
            )
        )
        return f"Saved preference: {category} - {preference}"
class PurchaseHistoryTool(BaseTool):
    """Read the customer's recent purchases from the context graph."""

    name: str = "get_purchase_history"
    description: str = """
    Get customer's recent purchase history from the context graph.
    Useful for understanding what they've bought before.
    """
    memory_client: MemoryClient
    user_id: str

    def _run(self) -> str:
        """Return the 10 most recent purchases as a JSON string."""
        import asyncio

        # asyncio.run() replaces deprecated get_event_loop().run_until_complete().
        # The Cypher query is parameterized, so user_id is never interpolated
        # into the query text.
        purchases = asyncio.run(
            self.memory_client.long_term.execute_query(
                """
                MATCH (c:Customer {id: $user_id})-[p:PURCHASED]->(product:Product)
                RETURN product.name as name, product.brand as brand,
                p.purchase_date as date, p.price as price
                ORDER BY p.purchase_date DESC
                LIMIT 10
                """,
                parameters={"user_id": self.user_id},
            )
        )
        return json.dumps(purchases)
# --- Agent Factory ---
def create_shopping_agent(
    memory_client: MemoryClient,
    user_id: str,
    session_id: str,
) -> AgentExecutor:
    """Create shopping assistant agent with context graph memory.

    The returned executor expects ``input`` and ``user_id`` keys at
    invoke time; chat history is persisted to Neo4j.
    """
    # Create tools; per-customer tools are bound to user_id up front
    tools = [
        ProductSearchTool(memory_client=memory_client),
        CustomerPreferencesTool(memory_client=memory_client, user_id=user_id),
        SavePreferenceTool(memory_client=memory_client, user_id=user_id),
        PurchaseHistoryTool(memory_client=memory_client, user_id=user_id),
    ]
    # Message history persisted to Neo4j for this session
    message_history = Neo4jChatMessageHistory(
        memory_client=memory_client,
        session_id=session_id,
        user_id=user_id,
    )
    # memory_key must match the MessagesPlaceholder below
    memory = ConversationBufferMemory(
        chat_memory=message_history,
        memory_key="chat_history",
        return_messages=True,
    )
    # Prompt
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a helpful shopping assistant for an online retailer.
Your capabilities:
1. Search for products using customer preferences
2. Access and use customer preferences for personalization
3. Save new preferences when customers express them
4. Review purchase history for context
Guidelines:
- Always check customer preferences before making recommendations
- When a customer expresses a preference, save it for future use
- Reference past purchases when relevant
- Be helpful and personalized in your recommendations
Current customer: {user_id}"""),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    # Higher temperature for friendlier, more varied replies
    llm = ChatOpenAI(model="gpt-4o", temperature=0.7)
    agent = create_openai_functions_agent(llm, tools, prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
    )
# --- Main ---
def main() -> None:
    """Interactive REPL for the shopping assistant.

    Deliberately synchronous: the function contains no awaits, and the
    tools block on the event loop internally, which raises if they are
    called from inside an already-running loop (as they would be if this
    ran under ``asyncio.run``).
    """
    memory_client = MemoryClient(
        neo4j_uri="bolt://localhost:7687",
        neo4j_user="neo4j",
        neo4j_password="password",
    )
    user_id = "CUST-12345"
    # Fresh timestamped session per run keeps conversations isolated.
    session_id = f"shop-{datetime.now().strftime('%Y%m%d%H%M%S')}"
    agent = create_shopping_agent(memory_client, user_id, session_id)

    # Conversation loop
    print("Shopping Assistant (type 'quit' to exit)\n")
    while True:
        user_input = input("You: ").strip()
        if user_input.lower() == "quit":
            break
        # user_id is passed because the system prompt interpolates it.
        result = agent.invoke({
            "input": user_input,
            "user_id": user_id,
        })
        print(f"\nAssistant: {result['output']}\n")


if __name__ == "__main__":
    main()
Callbacks for Reasoning Traces
Track LangChain execution in reasoning traces:
from langchain.callbacks.base import BaseCallbackHandler
from neo4j_agent_memory import MemoryClient
class ReasoningTraceCallback(BaseCallbackHandler):
    """Record LangChain execution as reasoning traces in the context graph.

    Lifecycle: ``on_chain_start`` opens a trace, ``on_tool_start`` /
    ``on_tool_end`` record tool steps, and ``on_chain_end`` /
    ``on_chain_error`` close the trace with an outcome.
    """

    def __init__(self, memory_client: MemoryClient, user_id: str):
        self.memory_client = memory_client
        self.user_id = user_id
        self.trace_id = None         # set by on_chain_start
        self.current_step_id = None  # set by on_tool_start

    @staticmethod
    def _run_sync(coro):
        """Block on *coro*; these callbacks are invoked from sync code.

        asyncio.run() replaces the deprecated get_event_loop()/
        run_until_complete() pattern repeated in each handler.
        """
        import asyncio

        return asyncio.run(coro)

    def on_chain_start(self, serialized, inputs, **kwargs):
        """Start a trace when the chain begins."""
        trace = self._run_sync(
            self.memory_client.reasoning.start_trace(
                task=str(inputs.get("input", ""))[:200],
                user_id=self.user_id,
                metadata={"chain_type": serialized.get("name", "unknown")},
            )
        )
        self.trace_id = trace.id

    def on_tool_start(self, serialized, input_str, **kwargs):
        """Record a tool invocation as a reasoning step."""
        if not self.trace_id:
            return
        tool_name = serialized.get("name", "unknown")
        step = self._run_sync(
            self.memory_client.reasoning.add_step(
                trace_id=self.trace_id,
                description=f"Using tool: {tool_name}",
                reasoning=f"Input: {input_str[:100]}",
            )
        )
        self.current_step_id = step.id

    def on_tool_end(self, output, **kwargs):
        """Attach the tool result to the current step."""
        if not self.current_step_id:
            return
        self._run_sync(
            self.memory_client.reasoning.add_tool_call(
                step_id=self.current_step_id,
                tool_name=kwargs.get("name", "unknown"),
                arguments={},
                result={"output": str(output)[:500]},
                success=True,
            )
        )
        # Reset so a stray on_tool_end cannot attach to a stale step.
        self.current_step_id = None

    def on_chain_end(self, outputs, **kwargs):
        """Mark the trace successful when the chain finishes."""
        if not self.trace_id:
            return
        self._run_sync(
            self.memory_client.reasoning.complete_trace(
                trace_id=self.trace_id,
                outcome="success",
                result={"output": str(outputs)[:500]},
            )
        )

    def on_chain_error(self, error, **kwargs):
        """Mark the trace failed when the chain raises."""
        if not self.trace_id:
            return
        self._run_sync(
            self.memory_client.reasoning.complete_trace(
                trace_id=self.trace_id,
                outcome="failure",
                error=str(error),
            )
        )
# Usage: pass the callback per-invocation via the run config
callback = ReasoningTraceCallback(memory_client, user_id)
result = agent.invoke(
    {"input": "Find running shoes", "user_id": user_id},
    config={"callbacks": [callback]},
)
Best Practices
1. Use Session-Based History
Create unique sessions for conversation isolation:
# Good: Unique session per conversation
session_id = f"user-{user_id}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
# Avoid: Shared session for all users
session_id = "global" # Don't do this
2. Handle Async Properly
Use async versions when running in async context:
# In async context
history = AsyncNeo4jChatMessageHistory(memory_client, session_id)
messages = await history.aget_messages()
# In sync context
history = Neo4jChatMessageHistory(memory_client, session_id)
messages = history.messages
3. Limit Memory Size
Prevent context from growing too large:
from langchain.memory import ConversationBufferWindowMemory
# Keep only last k messages in context
memory = ConversationBufferWindowMemory(
chat_memory=message_history,
k=10, # Last 10 messages
return_messages=True,
)
4. Add Error Handling
Gracefully handle memory failures:
class RobustNeo4jChatMessageHistory(BaseChatMessageHistory):
def add_message(self, message: BaseMessage) -> None:
try:
# Attempt to save
asyncio.get_event_loop().run_until_complete(
self.memory_client.short_term.add_message(...)
)
except Exception as e:
# Log but don't fail the chain
logger.warning(f"Failed to save message to memory: {e}")