Use with PydanticAI
How to integrate neo4j-agent-memory with PydanticAI agents to build memory-enabled applications with a persistent context graph.
Overview
PydanticAI is a Python agent framework that provides type-safe tools, dependency injection, and structured outputs. Integrating with neo4j-agent-memory adds persistent memory capabilities, enabling agents to remember past interactions and build knowledge over time.
**PydanticAI + Context Graph Architecture** *(architecture diagram)*
Prerequisites
- Python 3.10+
- `neo4j-agent-memory` and `pydantic-ai` installed
- Neo4j database running
- OpenAI API key (or other LLM provider)
pip install neo4j-agent-memory pydantic-ai
Basic Integration
Define Agent Dependencies
Create a dependency class that includes the memory client:
from dataclasses import dataclass
from neo4j_agent_memory import MemoryClient
@dataclass
class AgentDeps:
    """Dependencies available to all agent tools and prompts."""
    # Client used by tools to read/write short- and long-term memory.
    memory_client: MemoryClient
    # Identifier of the end user whose preferences/history are queried.
    user_id: str
    # Identifier scoping short-term (conversation) memory lookups.
    session_id: str
    # The query currently being handled; used by the dynamic system prompt.
    current_query: str | None = None
Create Memory-Enabled Agent
from pydantic_ai import Agent, RunContext
agent = Agent(
"openai:gpt-4o",
deps_type=AgentDeps,
system_prompt="""
You are a helpful shopping assistant. Use the available tools to:
1. Search for products matching customer requests
2. Retrieve customer preferences to personalize recommendations
3. Access conversation history for context
Always consider the customer's stated preferences when making recommendations.
""",
)
Add Memory Tools
import json
@agent.tool
async def search_messages(
    ctx: RunContext[AgentDeps],
    query: str,
    limit: int = 5,
) -> str:
    """Search past conversation messages for relevant context."""
    deps = ctx.deps
    matches = await deps.memory_client.short_term.search_messages(
        query=query,
        session_id=deps.session_id,
        limit=limit,
    )
    # Truncate each message body to 200 chars to keep the tool output compact.
    summarized = []
    for message in matches:
        summarized.append({"role": message.role, "content": message.content[:200]})
    return json.dumps(summarized)
@agent.tool
async def get_preferences(
    ctx: RunContext[AgentDeps],
    category: str | None = None,
) -> str:
    """Get customer preferences from memory."""
    # A None category returns preferences across all categories.
    preferences = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.user_id,
        category=category,
    )
    # Serialize as JSON so the model receives a structured, parseable payload.
    return json.dumps([
        {"category": p.category, "preference": p.preference}
        for p in preferences
    ])
@agent.tool
async def search_entities(
    ctx: RunContext[AgentDeps],
    query: str,
    entity_type: str | None = None,
    limit: int = 10,
) -> str:
    """Search the context graph for relevant entities."""
    # Semantic search over graph entities; entity_type=None searches all types.
    entities = await ctx.deps.memory_client.long_term.search_entities(
        query=query,
        entity_type=entity_type,
        limit=limit,
    )
    # Return JSON so the model gets a structured tool result.
    return json.dumps([
        {
            "name": e.name,
            "type": e.type,
            "description": e.description,
        }
        for e in entities
    ])
Dynamic System Prompt with Memory
Inject memory context into the system prompt:
@agent.system_prompt
async def add_memory_context(ctx: RunContext[AgentDeps]) -> str:
    """Dynamically add memory context to system prompt.

    Builds up to two sections: known customer preferences and, when a
    current query is set, summaries of similar successful past traces.
    """
    parts = []
    # Add user preferences
    preferences = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.user_id,
        limit=10,
    )
    if preferences:
        # chr(10) is "\n": f-string expressions could not contain
        # backslashes before Python 3.12.
        pref_list = [f"- {p.category}: {p.preference}" for p in preferences]
        parts.append(f"## Customer Preferences\n{chr(10).join(pref_list)}")
    # Add relevant past interactions
    if ctx.deps.current_query:
        traces = await ctx.deps.memory_client.reasoning.get_similar_traces(
            task=ctx.deps.current_query,
            user_id=ctx.deps.user_id,
            limit=2,
            success_only=True,
        )
        if traces:
            # NOTE(review): assumes t.result is a dict on success-only
            # traces — a None result would raise here; confirm.
            trace_items = [
                f"- Query: {t.task[:50]}..., Outcome: {t.result.get('summary', 'success')}"
                for t in traces
            ]
            parts.append(f"## Relevant Past Interactions\n{chr(10).join(trace_items)}")
    # An empty string contributes nothing to the static system prompt.
    return "\n\n".join(parts) if parts else ""
Run the Agent
async def main():
    """Run one memory-enabled agent turn.

    Stores the user message, invokes the agent, then persists the
    assistant's reply so future sessions have the full transcript.
    """
    # Needed for the session-id timestamp below; this snippet did not
    # import it anywhere, which raised NameError at runtime.
    from datetime import datetime

    # Initialize memory client (update credentials for your deployment)
    memory_client = MemoryClient(
        neo4j_uri="bolt://localhost:7687",
        neo4j_user="neo4j",
        neo4j_password="password",
    )
    user_id = "CUST-12345"
    session_id = f"session-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
    user_query = "I need new running shoes for marathon training"
    # Create dependencies shared by all tools and system prompts
    deps = AgentDeps(
        memory_client=memory_client,
        user_id=user_id,
        session_id=session_id,
        current_query=user_query,
    )
    # Store the user message before running so memory tools can see it
    await memory_client.short_term.add_message(
        role="user",
        content=user_query,
        session_id=session_id,
        metadata={"user_id": user_id},
    )
    # Run agent
    result = await agent.run(user_query, deps=deps)
    # Store assistant response
    await memory_client.short_term.add_message(
        role="assistant",
        content=result.data,
        session_id=session_id,
        metadata={"user_id": user_id},
    )
    print(result.data)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
Ecommerce Shopping Assistant
A complete example of a memory-enabled shopping assistant:
from dataclasses import dataclass, field
from datetime import datetime
import json
from pydantic_ai import Agent, RunContext
from pydantic import BaseModel
from neo4j_agent_memory import MemoryClient
# --- Response Models ---
class ProductRecommendation(BaseModel):
    """Structured product recommendation."""
    # Recommended products (name/price/etc. dictionaries).
    products: list[dict]
    # Explanation of why these products were chosen.
    reasoning: str
    # Which stored preferences were applied to personalize the result.
    personalization_used: list[str]
# --- Dependencies ---
@dataclass
class ShoppingDeps:
    """Dependencies injected into every shopping-agent tool and prompt."""
    # Client for short-term, long-term, and reasoning memory stores.
    memory_client: MemoryClient
    # Customer identifier used for preferences and purchase history.
    user_id: str
    # Conversation/session identifier for short-term memory.
    session_id: str
    # The query currently being handled (used for similar-trace lookup).
    current_query: str = ""
# --- Agent Definition ---
shopping_agent = Agent(
"openai:gpt-4o",
deps_type=ShoppingDeps,
result_type=ProductRecommendation,
system_prompt="""
You are a personal shopping assistant for an online retailer.
Your goals:
1. Understand what the customer is looking for
2. Use their preferences and history to personalize recommendations
3. Provide helpful, relevant product suggestions
4. Remember important details for future interactions
Always explain why you're recommending specific products based on
the customer's preferences and past behavior.
""",
)
# --- Memory Tools ---
@shopping_agent.tool
async def search_products(
    ctx: RunContext[ShoppingDeps],
    query: str,
    category: str | None = None,
    brand: str | None = None,
    max_price: float | None = None,
    limit: int = 10,
) -> str:
    """Search product catalog. Use customer preferences to filter when relevant."""
    # Build property filter from the optional criteria
    filters = {}
    if brand:
        filters["brand"] = brand
    # Fix: `category` was accepted but never forwarded to the search.
    if category:
        filters["category"] = category
    # Fix: compare against None explicitly — a 0.0 price cap is falsy
    # but still a legitimate filter value.
    if max_price is not None:
        filters["price"] = {"$lte": max_price}
    products = await ctx.deps.memory_client.long_term.search_entities(
        query=query,
        entity_type="PRODUCT",
        property_filter=filters if filters else None,
        limit=limit,
    )
    # Return key attributes as JSON for the model.
    return json.dumps([
        {
            "name": p.name,
            "brand": p.properties.get("brand"),
            "price": p.properties.get("price"),
            "rating": p.properties.get("rating"),
            "description": p.description,
        }
        for p in products
    ])
@shopping_agent.tool
async def get_customer_preferences(
    ctx: RunContext[ShoppingDeps],
) -> str:
    """Get all stored preferences for the current customer."""
    stored = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.user_id,
    )
    # Group preference strings under their category name.
    grouped = {}
    for item in stored:
        grouped.setdefault(item.category, []).append(item.preference)
    return json.dumps(grouped)
@shopping_agent.tool
async def get_purchase_history(
    ctx: RunContext[ShoppingDeps],
    limit: int = 10,
) -> str:
    """Get customer's recent purchase history from the context graph."""
    # Query the context graph for purchases, most recent first,
    # capped at `limit` rows.
    purchases = await ctx.deps.memory_client.long_term.execute_query(
        """
        MATCH (c:Customer {id: $user_id})-[p:PURCHASED]->(product:Product)
        RETURN product.name as name, product.brand as brand,
               product.category as category, p.purchase_date as date
        ORDER BY p.purchase_date DESC
        LIMIT $limit
        """,
        parameters={"user_id": ctx.deps.user_id, "limit": limit},
    )
    # NOTE(review): assumes execute_query returns JSON-serializable rows
    # (list of dicts) — confirm against the client API.
    return json.dumps(purchases)
@shopping_agent.tool
async def save_preference(
    ctx: RunContext[ShoppingDeps],
    preference: str,
    category: str,
) -> str:
    """Save a new customer preference learned during conversation."""
    # Fixed 0.85 confidence marks this as conversation-derived rather
    # than explicitly confirmed by the customer.
    await ctx.deps.memory_client.long_term.add_preference(
        user_id=ctx.deps.user_id,
        preference=preference,
        category=category,
        confidence=0.85,
        metadata={
            "source": "conversation",
            "session_id": ctx.deps.session_id,
        },
    )
    # Confirmation string the model can relay to the user.
    return f"Saved preference: {category} - {preference}"
@shopping_agent.tool
async def get_similar_recommendations(
    ctx: RunContext[ShoppingDeps],
) -> str:
    """Find past successful recommendations for similar queries."""
    traces = await ctx.deps.memory_client.reasoning.get_similar_traces(
        task=ctx.deps.current_query,
        user_id=ctx.deps.user_id,
        limit=3,
        success_only=True,
    )
    past_recs = []
    for trace in traces:
        # Skip traces with no stored result or no product list.
        if trace.result and "products" in trace.result:
            past_recs.append({
                "query": trace.task,
                # Cap at 3 products to keep tool output compact.
                "products_recommended": trace.result["products"][:3],
                "customer_response": trace.result.get("feedback", "unknown"),
            })
    return json.dumps(past_recs)
# --- System Prompt Enhancement ---
@shopping_agent.system_prompt
async def inject_customer_context(ctx: RunContext[ShoppingDeps]) -> str:
    """Inject customer context into system prompt."""
    # Get preferences (top 5 only, to keep the prompt short)
    prefs = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.user_id,
        limit=5,
    )
    # Without stored preferences, add nothing to the static prompt.
    if not prefs:
        return ""
    pref_text = "\n".join([f"- {p.category}: {p.preference}" for p in prefs])
    return f"""
## Known Customer Preferences
{pref_text}
Use these preferences to personalize your recommendations.
"""
# --- Main Application ---
async def run_shopping_assistant():
    """Run the shopping assistant REPL with memory.

    Each turn: store the user message, open a reasoning trace, run the
    agent, print/store the reply, and close the trace as success/failure.
    """
    memory_client = MemoryClient(
        neo4j_uri="bolt://localhost:7687",
        neo4j_user="neo4j",
        neo4j_password="password",
    )
    user_id = "CUST-12345"
    session_id = f"shop-{datetime.now().strftime('%Y%m%d%H%M%S')}"
    print("Shopping Assistant Ready! Type 'quit' to exit.\n")
    while True:
        user_input = input("You: ").strip()
        if user_input.lower() == "quit":
            break
        # Store user message first so memory tools can see the current turn
        await memory_client.short_term.add_message(
            role="user",
            content=user_input,
            session_id=session_id,
        )
        # Start reasoning trace before the run so failures are captured too
        trace = await memory_client.reasoning.start_trace(
            task=user_input,
            user_id=user_id,
            session_id=session_id,
        )
        deps = ShoppingDeps(
            memory_client=memory_client,
            user_id=user_id,
            session_id=session_id,
            current_query=user_input,
        )
        try:
            result = await shopping_agent.run(user_input, deps=deps)
            # Format the structured recommendation for display
            # (plain literal, not an f-string: it has no placeholders).
            response = "Here are my recommendations:\n\n"
            for i, product in enumerate(result.data.products, 1):
                response += f"{i}. {product['name']} - ${product.get('price', 'N/A')}\n"
            response += f"\n{result.data.reasoning}"
            print(f"\nAssistant: {response}\n")
            # Store assistant response for future context
            await memory_client.short_term.add_message(
                role="assistant",
                content=response,
                session_id=session_id,
            )
            # Complete trace with the structured outcome for later reuse
            await memory_client.reasoning.complete_trace(
                trace_id=trace.id,
                outcome="success",
                result={
                    "products": result.data.products,
                    "personalization": result.data.personalization_used,
                },
            )
        except Exception as e:
            # Broad catch is acceptable at this top-level REPL boundary;
            # the failure is still recorded on the reasoning trace.
            print(f"\nAssistant: Sorry, I encountered an error: {e}\n")
            await memory_client.reasoning.complete_trace(
                trace_id=trace.id,
                outcome="failure",
                error=str(e),
            )


if __name__ == "__main__":
    import asyncio

    asyncio.run(run_shopping_assistant())
Financial Advisory Agent
Example for a financial services use case:
from dataclasses import dataclass
from pydantic_ai import Agent, RunContext
from pydantic import BaseModel
from neo4j_agent_memory import MemoryClient
class InvestmentRecommendation(BaseModel):
    """Structured investment recommendation."""
    # Summary recommendation text for the advisor.
    recommendation: str
    # Candidate securities (name/ticker/etc. dictionaries).
    securities: list[dict]
    # Reasoning behind the recommendation (for audit purposes).
    rationale: str
    # Risk factors the advisor should weigh.
    risk_considerations: list[str]
    # Compliance considerations flagged during reasoning.
    compliance_notes: list[str]
@dataclass
class AdvisoryDeps:
    """Dependencies injected into the advisory agent's tools and prompts."""
    # Client for long-term graph memory and reasoning traces.
    memory_client: MemoryClient
    # Client (end-customer) identifier for profile and preference lookups.
    client_id: str
    # Advisor identifier recorded on compliance notes.
    advisor_id: str
    # Session identifier for this advisory conversation.
    session_id: str
advisory_agent = Agent(
"openai:gpt-4o",
deps_type=AdvisoryDeps,
result_type=InvestmentRecommendation,
system_prompt="""
You are a financial advisory assistant helping wealth managers serve their clients.
Important guidelines:
1. Always consider the client's risk profile and investment objectives
2. Reference specific securities with ticker symbols when possible
3. Note any compliance considerations for recommendations
4. Use the client's stated preferences and constraints
5. Document your reasoning for audit purposes
Never provide specific buy/sell recommendations without noting that final
decisions should be reviewed by the advisor.
""",
)
@advisory_agent.tool
async def get_client_profile(ctx: RunContext[AdvisoryDeps]) -> str:
    """Get comprehensive client profile from context graph.

    Combines the client's accounts/holdings (via a Cypher query) with
    their stored investment preferences into one JSON payload.
    """
    # Fix: this example's imports never brought `json` into scope;
    # keep the snippet self-contained with a function-scope import.
    import json

    # Get client entity with related data
    profile = await ctx.deps.memory_client.long_term.execute_query(
        """
        MATCH (c:Client {id: $client_id})
        OPTIONAL MATCH (c)-[:HAS_ACCOUNT]->(a:Account)
        OPTIONAL MATCH (a)-[:HOLDS]->(p:Position)-[:IN]->(s:Security)
        RETURN c, collect(DISTINCT a) as accounts,
               collect(DISTINCT {security: s, shares: p.shares}) as holdings
        """,
        parameters={"client_id": ctx.deps.client_id},
    )
    # Get the investment preferences relevant to advice
    preferences = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.client_id,
        categories=["risk_profile", "investment_style", "exclusions", "time_horizon"],
    )
    return json.dumps({
        "profile": profile,
        "preferences": [{"category": p.category, "value": p.preference} for p in preferences],
    })
@advisory_agent.tool
async def search_securities(
    ctx: RunContext[AdvisoryDeps],
    query: str,
    sector: str | None = None,
    asset_class: str | None = None,
) -> str:
    """Search for securities matching criteria."""
    # Build the optional property filter from the provided criteria.
    filters = {}
    if sector:
        filters["sector"] = sector
    if asset_class:
        filters["asset_class"] = asset_class
    securities = await ctx.deps.memory_client.long_term.search_entities(
        query=query,
        entity_type="SECURITY",
        property_filter=filters if filters else None,
        limit=15,
    )
    # NOTE(review): relies on a top-level `import json` that this
    # example section does not show — confirm it exists in the file.
    return json.dumps([
        {
            "name": s.name,
            "ticker": s.properties.get("ticker"),
            "sector": s.properties.get("sector"),
            "asset_class": s.properties.get("asset_class"),
            "description": s.description,
        }
        for s in securities
    ])
@advisory_agent.tool
async def get_past_recommendations(
    ctx: RunContext[AdvisoryDeps],
    topic: str,
) -> str:
    """Find past successful recommendations for similar situations."""
    traces = await ctx.deps.memory_client.reasoning.get_similar_traces(
        task=topic,
        user_id=ctx.deps.client_id,
        limit=3,
        success_only=True,
    )
    # Fix: guard traces with no stored result — calling .get on a None
    # result raised AttributeError (the ecommerce tool already guards this).
    return json.dumps([
        {
            "topic": t.task,
            "recommendation": t.result.get("recommendation"),
            "outcome": t.result.get("client_response"),
        }
        for t in traces
        if t.result
    ])
@advisory_agent.tool
async def record_compliance_note(
    ctx: RunContext[AdvisoryDeps],
    note: str,
    category: str,
) -> str:
    """Record a compliance-relevant note in the context graph."""
    # Fix: `datetime` was not imported anywhere in this example section.
    from datetime import datetime

    # Capture the timestamp once so the entity name and the `timestamp`
    # property match exactly (two datetime.now() calls would differ).
    recorded_at = datetime.now().isoformat()
    await ctx.deps.memory_client.long_term.add_entity(
        name=f"Compliance Note - {recorded_at}",
        entity_type="COMPLIANCE_NOTE",
        properties={
            "client_id": ctx.deps.client_id,
            "advisor_id": ctx.deps.advisor_id,
            "session_id": ctx.deps.session_id,
            "category": category,
            "note": note,
            "timestamp": recorded_at,
        },
    )
    return f"Recorded compliance note: {category}"
@advisory_agent.system_prompt
async def inject_client_context(ctx: RunContext[AdvisoryDeps]) -> str:
    """Inject client investment profile into system prompt."""
    preferences = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.client_id,
    )
    # Without a stored profile, contribute nothing to the static prompt.
    if not preferences:
        return ""
    profile_text = "\n".join([f"- {p.category}: {p.preference}" for p in preferences])
    return f"""
## Client Investment Profile
{profile_text}
All recommendations must align with this profile.
"""
Best Practices
1. Use Dependency Injection
Keep memory client in dependencies, not global state:
# Good: Memory client in deps
@dataclass
class AgentDeps:
memory_client: MemoryClient
user_id: str
# Avoid: Global memory client
memory_client = MemoryClient(...) # Don't do this
2. Store Messages for Every Turn
Capture the full conversation:
# Store user message before running agent
# (tools that search short-term memory can then see the current turn)
await memory_client.short_term.add_message(
    role="user",
    content=user_input,
    session_id=session_id,
)
# Run agent
result = await agent.run(user_input, deps=deps)
# Store assistant response after, so the transcript is complete
await memory_client.short_term.add_message(
    role="assistant",
    content=result.data,
    session_id=session_id,
)
3. Use Reasoning Traces
Track agent decisions for improvement:
# Open the trace before the run so failures are captured too.
trace = await memory_client.reasoning.start_trace(
    task=user_query,
    user_id=user_id,
)
try:
    result = await agent.run(user_query, deps=deps)
    await memory_client.reasoning.complete_trace(
        trace_id=trace.id,
        outcome="success",
        result={"response": result.data},
    )
except Exception as e:
    # Record the failure, then re-raise so callers still see the error.
    await memory_client.reasoning.complete_trace(
        trace_id=trace.id,
        outcome="failure",
        error=str(e),
    )
    raise
4. Leverage Dynamic System Prompts
Inject relevant context automatically:
@agent.system_prompt
async def dynamic_context(ctx: RunContext[AgentDeps]) -> str:
    # Only fetch what's relevant to current query.
    # NOTE(review): assumes current_query is always a string — if it can
    # be None (as in AgentDeps above), guard before calling .lower().
    if "preference" in ctx.deps.current_query.lower():
        prefs = await ctx.deps.memory_client.long_term.get_preferences(...)
        return format_preferences(prefs)
    return ""