Use with PydanticAI

How to integrate neo4j-agent-memory with PydanticAI agents to build memory-enabled applications with a persistent context graph.

Overview

PydanticAI is a Python agent framework that provides type-safe tools, dependency injection, and structured outputs. Integrating it with neo4j-agent-memory adds persistent memory, enabling agents to remember past interactions and accumulate knowledge across sessions.

PydanticAI + Context Graph Architecture
┌─────────────────────────────────────────────────────┐
│                  PydanticAI Agent                   │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐ │
│  │   Tools     │  │  System     │  │  Response   │ │
│  │             │  │  Prompt     │  │  Model      │ │
│  └──────┬──────┘  └──────┬──────┘  └─────────────┘ │
│         │                │                          │
│         ▼                ▼                          │
│  ┌─────────────────────────────────────────────┐   │
│  │           AgentDeps (Dependencies)          │   │
│  │  • MemoryClient                             │   │
│  │  • user_id, session_id                      │   │
│  │  • current_trace_id                         │   │
│  └──────────────────────┬──────────────────────┘   │
└─────────────────────────┼───────────────────────────┘
                          │
                          ▼
┌─────────────────────────────────────────────────────┐
│               Neo4j Context Graph                   │
│  ┌───────────┐  ┌───────────┐  ┌───────────┐       │
│  │ Messages  │  │ Entities  │  │ Traces    │       │
│  │ Sessions  │  │ Prefs     │  │ Steps     │       │
│  └───────────┘  └───────────┘  └───────────┘       │
└─────────────────────────────────────────────────────┘

Prerequisites

  • Python 3.10+

  • neo4j-agent-memory and pydantic-ai installed

  • Neo4j database running

  • OpenAI API key (or other LLM provider)

pip install neo4j-agent-memory pydantic-ai

Basic Integration

Define Agent Dependencies

Create a dependency class that includes the memory client:

from dataclasses import dataclass
from neo4j_agent_memory import MemoryClient

@dataclass
class AgentDeps:
    """Dependencies available to all agent tools and prompts."""
    # Client used by the tools and the dynamic system prompt to read/write memory.
    memory_client: MemoryClient
    # Scopes preference and reasoning-trace lookups to one user.
    user_id: str
    # Scopes short-term message search to the current conversation.
    session_id: str
    # The message being answered; add_memory_context uses it to look up similar
    # past traces and skips that lookup when it is None or empty.
    current_query: str | None = None

Create Memory-Enabled Agent

from pydantic_ai import Agent, RunContext

# Memory-enabled agent. `deps_type` declares the dependency object that tools
# and system-prompt functions receive via `RunContext[AgentDeps]`.
agent = Agent(
    "openai:gpt-4o",  # "provider:model" string; swap in another provider if needed
    deps_type=AgentDeps,
    system_prompt="""
    You are a helpful shopping assistant. Use the available tools to:
    1. Search for products matching customer requests
    2. Retrieve customer preferences to personalize recommendations
    3. Access conversation history for context

    Always consider the customer's stated preferences when making recommendations.
    """,
)

Add Memory Tools

import json

@agent.tool
async def search_messages(
    ctx: RunContext[AgentDeps],
    query: str,
    limit: int = 5,
) -> str:
    """Search past conversation messages for relevant context."""
    # The docstring above is the tool description the model sees; keep it short.
    # Results are limited to the current session and each message body is cut
    # to 200 characters so the tool output stays compact.
    short_term = ctx.deps.memory_client.short_term
    hits = await short_term.search_messages(
        query=query,
        session_id=ctx.deps.session_id,
        limit=limit,
    )

    snippets = []
    for msg in hits:
        snippets.append({"role": msg.role, "content": msg.content[:200]})
    return json.dumps(snippets)


@agent.tool
async def get_preferences(
    ctx: RunContext[AgentDeps],
    category: str | None = None,
) -> str:
    """Get customer preferences from memory."""
    # Returns a JSON array of {"category", "preference"} objects for the current
    # user; `category=None` is forwarded unchanged to the memory client.
    long_term = ctx.deps.memory_client.long_term
    stored = await long_term.get_preferences(
        user_id=ctx.deps.user_id,
        category=category,
    )

    rows = [dict(category=p.category, preference=p.preference) for p in stored]
    return json.dumps(rows)


@agent.tool
async def search_entities(
    ctx: RunContext[AgentDeps],
    query: str,
    entity_type: str | None = None,
    limit: int = 10,
) -> str:
    """Search the context graph for relevant entities."""
    # Returns a JSON array of {"name", "type", "description"} objects.
    long_term = ctx.deps.memory_client.long_term
    found = await long_term.search_entities(
        query=query,
        entity_type=entity_type,
        limit=limit,
    )

    summaries = []
    for entity in found:
        summaries.append(
            dict(
                name=entity.name,
                type=entity.type,
                description=entity.description,
            )
        )
    return json.dumps(summaries)

Dynamic System Prompt with Memory

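Use an `@agent.system_prompt` function to inject memory context (stored preferences and similar past interactions) into the system prompt on every run: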

@agent.system_prompt
async def add_memory_context(ctx: RunContext[AgentDeps]) -> str:
    """Dynamically add memory context to system prompt.

    Adds up to 10 stored preferences for the user and, when
    ``ctx.deps.current_query`` is set, up to two similar successful reasoning
    traces. Returns an empty string when neither source yields anything.
    """
    parts = []

    # Add user preferences
    preferences = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.user_id,
        limit=10,
    )

    if preferences:
        pref_list = [f"- {p.category}: {p.preference}" for p in preferences]
        parts.append("## Customer Preferences\n" + "\n".join(pref_list))

    # Add relevant past interactions
    if ctx.deps.current_query:
        traces = await ctx.deps.memory_client.reasoning.get_similar_traces(
            task=ctx.deps.current_query,
            user_id=ctx.deps.user_id,
            limit=2,
            success_only=True,
        )

        if traces:
            trace_items = []
            for t in traces:
                # A trace may carry no result payload; treat it as empty instead
                # of calling .get() on None.
                outcome = (t.result or {}).get("summary", "success")
                # Only mark the task as truncated when it actually was.
                task = t.task if len(t.task) <= 50 else f"{t.task[:50]}..."
                trace_items.append(f"- Query: {task}, Outcome: {outcome}")
            parts.append("## Relevant Past Interactions\n" + "\n".join(trace_items))

    # Joining an empty list already yields "".
    return "\n\n".join(parts)

Run the Agent

async def main():
    """Run a single memory-enabled agent turn end to end.

    Stores the user's message, runs ``agent`` with ``AgentDeps``, stores the
    assistant's reply in the same session, and prints it.
    """
    # Imported here so this snippet runs as written; it builds the session id.
    from datetime import datetime

    # Initialize memory client
    memory_client = MemoryClient(
        neo4j_uri="bolt://localhost:7687",
        neo4j_user="neo4j",
        neo4j_password="password",
    )

    user_id = "CUST-12345"
    session_id = f"session-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
    user_query = "I need new running shoes for marathon training"

    # Create dependencies
    deps = AgentDeps(
        memory_client=memory_client,
        user_id=user_id,
        session_id=session_id,
        current_query=user_query,
    )

    # Store user message
    await memory_client.short_term.add_message(
        role="user",
        content=user_query,
        session_id=session_id,
        metadata={"user_id": user_id},
    )

    # Run agent
    result = await agent.run(user_query, deps=deps)

    # Store assistant response (a plain string: this agent has no result_type)
    await memory_client.short_term.add_message(
        role="assistant",
        content=result.data,
        session_id=session_id,
        metadata={"user_id": user_id},
    )

    print(result.data)

if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

E-commerce Shopping Assistant

A complete example of a memory-enabled shopping assistant:

from dataclasses import dataclass, field
from datetime import datetime
import json

from pydantic_ai import Agent, RunContext
from pydantic import BaseModel
from neo4j_agent_memory import MemoryClient


# --- Response Models ---

class ProductRecommendation(BaseModel):
    """Structured product recommendation."""
    # NOTE(review): pydantic presumably uses the class docstring as the schema
    # description the model sees; reword it deliberately, not cosmetically.
    # Free-form product dicts; run_shopping_assistant reads "name" and "price".
    products: list[dict]
    # Model's explanation, appended to the printed reply.
    reasoning: str
    # What personal context influenced the picks; stored in the trace result.
    personalization_used: list[str]


# --- Dependencies ---

@dataclass
class ShoppingDeps:
    """Per-turn dependencies for shopping_agent's tools and system prompt."""
    # Client used for preference, product, purchase and trace lookups.
    memory_client: MemoryClient
    # Customer whose preferences, purchases and traces are looked up.
    user_id: str
    # Conversation id; also recorded as provenance on saved preferences.
    session_id: str
    # Current user message; used as the task for similar-trace lookup.
    current_query: str = ""


# --- Agent Definition ---

# Shopping agent whose runs return a validated ProductRecommendation.
# NOTE(review): newer pydantic-ai releases rename `result_type` to `output_type`
# (and `result.data` to `result.output`); confirm against the installed version.
shopping_agent = Agent(
    "openai:gpt-4o",
    deps_type=ShoppingDeps,
    result_type=ProductRecommendation,
    system_prompt="""
    You are a personal shopping assistant for an online retailer.

    Your goals:
    1. Understand what the customer is looking for
    2. Use their preferences and history to personalize recommendations
    3. Provide helpful, relevant product suggestions
    4. Remember important details for future interactions

    Always explain why you're recommending specific products based on
    the customer's preferences and past behavior.
    """,
)


# --- Memory Tools ---

@shopping_agent.tool
async def search_products(
    ctx: RunContext[ShoppingDeps],
    query: str,
    category: str | None = None,
    brand: str | None = None,
    max_price: float | None = None,
    limit: int = 10,
) -> str:
    """Search product catalog. Use customer preferences to filter when relevant."""
    # NOTE(review): `category` is accepted but not applied to the filter below;
    # confirm whether it should map to a property filter.

    # Build property filter
    filters = {}
    if brand:
        filters["brand"] = brand
    # Compare against None: a max_price of 0 is a real bound and must not be
    # silently dropped the way a truthiness check would drop it.
    if max_price is not None:
        filters["price"] = {"$lte": max_price}

    products = await ctx.deps.memory_client.long_term.search_entities(
        query=query,
        entity_type="PRODUCT",
        property_filter=filters or None,  # None means "no property filter"
        limit=limit,
    )

    return json.dumps([
        {
            "name": p.name,
            "brand": p.properties.get("brand"),
            "price": p.properties.get("price"),
            "rating": p.properties.get("rating"),
            "description": p.description,
        }
        for p in products
    ])


@shopping_agent.tool
async def get_customer_preferences(
    ctx: RunContext[ShoppingDeps],
) -> str:
    """Get all stored preferences for the current customer."""
    # Returns a JSON object mapping each category to the list of its
    # preference values, in the order the memory client returned them.
    stored = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.user_id,
    )

    grouped = {}
    for pref in stored:
        grouped.setdefault(pref.category, []).append(pref.preference)
    return json.dumps(grouped)


@shopping_agent.tool
async def get_purchase_history(
    ctx: RunContext[ShoppingDeps],
    limit: int = 10,
) -> str:
    """Get customer's recent purchase history from the context graph."""
    # Newest purchases first, capped at `limit`; values are bound as query
    # parameters rather than interpolated into the Cypher text.
    cypher = """
        MATCH (c:Customer {id: $user_id})-[p:PURCHASED]->(product:Product)
        RETURN product.name as name, product.brand as brand,
               product.category as category, p.purchase_date as date
        ORDER BY p.purchase_date DESC
        LIMIT $limit
        """
    params = {"user_id": ctx.deps.user_id, "limit": limit}
    rows = await ctx.deps.memory_client.long_term.execute_query(
        cypher,
        parameters=params,
    )

    # NOTE(review): if purchase_date is stored as a Neo4j temporal value,
    # json.dumps may not serialize it as-is; confirm what execute_query returns.
    return json.dumps(rows)


@shopping_agent.tool
async def save_preference(
    ctx: RunContext[ShoppingDeps],
    preference: str,
    category: str,
) -> str:
    """Save a new customer preference learned during conversation."""
    # Stored with a fixed confidence of 0.85 and tagged with the session it
    # came from; returns a short confirmation string for the model.
    deps = ctx.deps
    provenance = {
        "source": "conversation",
        "session_id": deps.session_id,
    }
    await deps.memory_client.long_term.add_preference(
        user_id=deps.user_id,
        preference=preference,
        category=category,
        confidence=0.85,
        metadata=provenance,
    )
    return f"Saved preference: {category} - {preference}"


@shopping_agent.tool
async def get_similar_recommendations(
    ctx: RunContext[ShoppingDeps],
) -> str:
    """Find past successful recommendations for similar queries."""
    # Only traces whose result contains a "products" list are reported; each
    # entry shows at most the first three recommended products.
    deps = ctx.deps
    similar = await deps.memory_client.reasoning.get_similar_traces(
        task=deps.current_query,
        user_id=deps.user_id,
        limit=3,
        success_only=True,
    )

    past_recs = [
        {
            "query": t.task,
            "products_recommended": t.result["products"][:3],
            "customer_response": t.result.get("feedback", "unknown"),
        }
        for t in similar
        if t.result and "products" in t.result
    ]
    return json.dumps(past_recs)


# --- System Prompt Enhancement ---

@shopping_agent.system_prompt
async def inject_customer_context(ctx: RunContext[ShoppingDeps]) -> str:
    """Append up to five known customer preferences to the system prompt.

    Returns an empty string when the customer has no stored preferences.
    """
    known = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.user_id,
        limit=5,
    )

    if known:
        pref_text = "\n".join(f"- {p.category}: {p.preference}" for p in known)
        return f"""
## Known Customer Preferences
{pref_text}

Use these preferences to personalize your recommendations.
"""
    return ""


# --- Main Application ---

async def run_shopping_assistant():
    """Run the shopping assistant with memory.

    Each turn stores the user's message, opens a reasoning trace, runs
    ``shopping_agent``, prints and stores the formatted reply, and completes
    the trace as ``"success"`` or ``"failure"``. Typing ``quit`` or closing
    stdin (EOF) ends the loop; blank lines are ignored.
    """
    import asyncio  # local import keeps this snippet self-contained

    memory_client = MemoryClient(
        neo4j_uri="bolt://localhost:7687",
        neo4j_user="neo4j",
        neo4j_password="password",
    )

    user_id = "CUST-12345"
    session_id = f"shop-{datetime.now().strftime('%Y%m%d%H%M%S')}"

    print("Shopping Assistant Ready! Type 'quit' to exit.\n")

    while True:
        # input() blocks; running it in a worker thread keeps the event loop
        # free while waiting for the user.
        try:
            user_input = (await asyncio.to_thread(input, "You: ")).strip()
        except EOFError:
            # Ctrl-D / closed stdin: leave the loop instead of crashing.
            break

        if user_input.lower() == "quit":
            break
        if not user_input:
            # Nothing to answer; don't store or trace an empty turn.
            continue

        # Store user message (tagged with the user id, as in the basic example)
        await memory_client.short_term.add_message(
            role="user",
            content=user_input,
            session_id=session_id,
            metadata={"user_id": user_id},
        )

        # Start reasoning trace
        trace = await memory_client.reasoning.start_trace(
            task=user_input,
            user_id=user_id,
            session_id=session_id,
        )

        deps = ShoppingDeps(
            memory_client=memory_client,
            user_id=user_id,
            session_id=session_id,
            current_query=user_input,
        )

        try:
            result = await shopping_agent.run(user_input, deps=deps)
            recommendation = result.data

            # Format response. Product dicts are produced by the model, so a
            # missing key must not turn a successful run into a failed trace.
            lines = ["Here are my recommendations:", ""]
            for i, product in enumerate(recommendation.products, 1):
                name = product.get("name", "Unknown product")
                price = product.get("price", "N/A")
                lines.append(f"{i}. {name} - ${price}")
            lines += ["", recommendation.reasoning]
            response = "\n".join(lines)

            print(f"\nAssistant: {response}\n")

            # Store assistant response
            await memory_client.short_term.add_message(
                role="assistant",
                content=response,
                session_id=session_id,
                metadata={"user_id": user_id},
            )

            # Complete trace
            await memory_client.reasoning.complete_trace(
                trace_id=trace.id,
                outcome="success",
                result={
                    "products": recommendation.products,
                    "personalization": recommendation.personalization_used,
                },
            )

        except Exception as e:
            # REPL boundary: report the error, record the failed trace, and
            # keep the session alive for the next input.
            print(f"\nAssistant: Sorry, I encountered an error: {e}\n")

            await memory_client.reasoning.complete_trace(
                trace_id=trace.id,
                outcome="failure",
                error=str(e),
            )


if __name__ == "__main__":
    import asyncio
    asyncio.run(run_shopping_assistant())

Financial Advisory Agent

Example for a financial services use case:

import json
from dataclasses import dataclass
from datetime import datetime, timezone

from neo4j_agent_memory import MemoryClient
from pydantic import BaseModel
from pydantic_ai import Agent, RunContext


class InvestmentRecommendation(BaseModel):
    """Structured investment recommendation."""
    # NOTE(review): pydantic presumably uses the class docstring as the schema
    # description the model sees; reword it deliberately, not cosmetically.
    # Summary of the recommendation.
    recommendation: str
    # Free-form security dicts chosen by the model.
    securities: list[dict]
    # Reasoning behind the recommendation (kept for audit, per the prompt).
    rationale: str
    # Risks the advisor should weigh.
    risk_considerations: list[str]
    # Compliance considerations the prompt asks the model to note.
    compliance_notes: list[str]


@dataclass
class AdvisoryDeps:
    """Per-run dependencies for advisory_agent's tools and system prompt."""
    # Client used for profile, security, trace and compliance-note access.
    memory_client: MemoryClient
    # Client whose profile and preferences are used; passed as `user_id`
    # to the memory client's preference and trace lookups.
    client_id: str
    # Advisor running the session; recorded on compliance notes.
    advisor_id: str
    # Session id; recorded on compliance notes.
    session_id: str


# Advisory agent whose runs return a validated InvestmentRecommendation.
# NOTE(review): newer pydantic-ai releases rename `result_type` to `output_type`;
# confirm against the installed version.
advisory_agent = Agent(
    "openai:gpt-4o",
    deps_type=AdvisoryDeps,
    result_type=InvestmentRecommendation,
    system_prompt="""
    You are a financial advisory assistant helping wealth managers serve their clients.

    Important guidelines:
    1. Always consider the client's risk profile and investment objectives
    2. Reference specific securities with ticker symbols when possible
    3. Note any compliance considerations for recommendations
    4. Use the client's stated preferences and constraints
    5. Document your reasoning for audit purposes

    Never provide specific buy/sell recommendations without noting that final
    decisions should be reviewed by the advisor.
    """,
)


@advisory_agent.tool
async def get_client_profile(ctx: RunContext[AdvisoryDeps]) -> str:
    """Get comprehensive client profile from context graph."""
    # Returns JSON with the raw graph profile (client, accounts, holdings) and
    # the client's investment-related preferences.
    long_term = ctx.deps.memory_client.long_term
    client_id = ctx.deps.client_id

    # Client node plus its accounts and held securities with share counts.
    cypher = """
        MATCH (c:Client {id: $client_id})
        OPTIONAL MATCH (c)-[:HAS_ACCOUNT]->(a:Account)
        OPTIONAL MATCH (a)-[:HOLDS]->(p:Position)-[:IN]->(s:Security)
        RETURN c, collect(DISTINCT a) as accounts,
               collect(DISTINCT {security: s, shares: p.shares}) as holdings
        """
    profile = await long_term.execute_query(
        cypher,
        parameters={"client_id": client_id},
    )

    # NOTE(review): other calls in this guide pass `category=`; confirm that
    # get_preferences accepts a `categories` list.
    wanted = ["risk_profile", "investment_style", "exclusions", "time_horizon"]
    preferences = await long_term.get_preferences(
        user_id=client_id,
        categories=wanted,
    )

    # NOTE(review): `profile` may contain graph node objects that json.dumps
    # cannot serialize directly; confirm what execute_query returns.
    pref_rows = [{"category": p.category, "value": p.preference} for p in preferences]
    return json.dumps({
        "profile": profile,
        "preferences": pref_rows,
    })


@advisory_agent.tool
async def search_securities(
    ctx: RunContext[AdvisoryDeps],
    query: str,
    sector: str | None = None,
    asset_class: str | None = None,
) -> str:
    """Search for securities matching criteria."""
    # Only non-empty criteria become property filters; at most 15 results.
    criteria = {
        key: value
        for key, value in (("sector", sector), ("asset_class", asset_class))
        if value
    }

    matches = await ctx.deps.memory_client.long_term.search_entities(
        query=query,
        entity_type="SECURITY",
        property_filter=criteria or None,
        limit=15,
    )

    summaries = []
    for sec in matches:
        props = sec.properties
        summaries.append({
            "name": sec.name,
            "ticker": props.get("ticker"),
            "sector": props.get("sector"),
            "asset_class": props.get("asset_class"),
            "description": sec.description,
        })
    return json.dumps(summaries)


@advisory_agent.tool
async def get_past_recommendations(
    ctx: RunContext[AdvisoryDeps],
    topic: str,
) -> str:
    """Find past successful recommendations for similar situations."""
    # Returns a JSON array of {"topic", "recommendation", "outcome"} objects,
    # one per similar successful trace (at most three).
    traces = await ctx.deps.memory_client.reasoning.get_similar_traces(
        task=topic,
        user_id=ctx.deps.client_id,
        limit=3,
        success_only=True,
    )

    past = []
    for t in traces:
        # A trace may have no result payload; treat it as empty rather than
        # raising AttributeError on None.get(...).
        result = t.result or {}
        past.append({
            "topic": t.task,
            "recommendation": result.get("recommendation"),
            "outcome": result.get("client_response"),
        })
    return json.dumps(past)


@advisory_agent.tool
async def record_compliance_note(
    ctx: RunContext[AdvisoryDeps],
    note: str,
    category: str,
) -> str:
    """Record a compliance-relevant note in the context graph."""
    # Take a single timestamp so the entity name and its "timestamp" property
    # always agree, and make it timezone-aware (UTC) so audit records are
    # unambiguous regardless of the server's local time zone.
    recorded_at = datetime.now(timezone.utc).isoformat()

    await ctx.deps.memory_client.long_term.add_entity(
        name=f"Compliance Note - {recorded_at}",
        entity_type="COMPLIANCE_NOTE",
        properties={
            "client_id": ctx.deps.client_id,
            "advisor_id": ctx.deps.advisor_id,
            "session_id": ctx.deps.session_id,
            "category": category,
            "note": note,
            "timestamp": recorded_at,
        },
    )

    return f"Recorded compliance note: {category}"


@advisory_agent.system_prompt
async def inject_client_context(ctx: RunContext[AdvisoryDeps]) -> str:
    """Append the client's stored investment preferences to the system prompt.

    Returns an empty string when the client has no stored preferences.
    """
    stored = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.client_id,
    )

    if stored:
        profile_text = "\n".join(f"- {p.category}: {p.preference}" for p in stored)
        return f"""
## Client Investment Profile
{profile_text}

All recommendations must align with this profile.
"""
    return ""

Best Practices

1. Use Dependency Injection

Keep the memory client in the agent's dependencies rather than in global state, so each run receives it explicitly through `ctx.deps`:

# Good: Memory client in deps
@dataclass
class AgentDeps:
    # Passed per run via `agent.run(..., deps=...)`; tools reach it as ctx.deps.memory_client.
    memory_client: MemoryClient
    user_id: str

# Avoid: Global memory client
# A module-level client is implicitly shared by every agent and run, which
# makes it harder to swap per run (e.g. a fake client in tests).
memory_client = MemoryClient(...)  # Don't do this

2. Store Messages for Every Turn

Capture the full conversation:

# Store user message before running agent
# (This snippet runs inside an async function; `await` is not valid at module level.)
await memory_client.short_term.add_message(
    role="user",
    content=user_input,
    session_id=session_id,
)

# Run agent
result = await agent.run(user_input, deps=deps)

# Store assistant response after
# NOTE: result.data is a plain string only for agents without a structured
# result type; for structured agents (e.g. shopping_agent) format it into text
# first, as run_shopping_assistant does.
await memory_client.short_term.add_message(
    role="assistant",
    content=result.data,
    session_id=session_id,
)

3. Use Reasoning Traces

Wrap each agent run in a reasoning trace so that both successful decisions and failures are recorded for later review:

# Open the trace before the run so failures are recorded too.
# start_trace also accepts session_id=... (as in run_shopping_assistant) to
# tie the trace to the conversation.
trace = await memory_client.reasoning.start_trace(
    task=user_query,
    user_id=user_id,
)

try:
    result = await agent.run(user_query, deps=deps)
    await memory_client.reasoning.complete_trace(
        trace_id=trace.id,
        outcome="success",
        result={"response": result.data},
    )
except Exception as e:
    # Close the trace as a failure, then re-raise so the caller still sees the error.
    await memory_client.reasoning.complete_trace(
        trace_id=trace.id,
        outcome="failure",
        error=str(e),
    )
    raise

4. Leverage Dynamic System Prompts

Inject relevant context automatically:

@agent.system_prompt
async def dynamic_context(ctx: RunContext[AgentDeps]) -> str:
    """Add stored preferences to the system prompt only when the query mentions "preference".

    Returns an empty string otherwise, so no memory lookup is made.
    """
    # Only fetch what's relevant to current query.
    # current_query defaults to None in AgentDeps, so guard before .lower().
    query = ctx.deps.current_query or ""
    if "preference" not in query.lower():
        return ""

    prefs = await ctx.deps.memory_client.long_term.get_preferences(
        user_id=ctx.deps.user_id,
    )
    # Same "- category: preference" lines used by add_memory_context.
    return "\n".join(f"- {p.category}: {p.preference}" for p in prefs)