Persisting and sharing machine learning models

Follow along with a notebook in Google Colab.

This example shows how to train, save, publish, and drop a machine learning model using the Model Catalog.

Setup

For more information on how to get started using Python, refer to the Connecting with Python tutorial.

pip install graphdatascience
# Import the client
from graphdatascience import GraphDataScience

# Replace with the actual URI, username, and password
AURA_CONNECTION_URI = "neo4j+s://xxxxxxxx.databases.neo4j.io"
AURA_USERNAME = "neo4j"
AURA_PASSWORD = ""

# Configure the client with AuraDS-recommended settings
gds = GraphDataScience(
    AURA_CONNECTION_URI,
    auth=(AURA_USERNAME, AURA_PASSWORD),
    aura_ds=True
)

In the following code examples we use the print function to print Pandas DataFrame and Series objects. You can try different ways to print a Pandas object, for instance via the to_string and to_json methods; if you use a JSON representation, in some cases you may need to include a default handler to handle Neo4j DateTime objects. Check the Python connection section for some examples.

For more information on how to get started using the Cypher Shell, refer to the Neo4j Cypher Shell tutorial.

Run the following commands from the directory where the Cypher shell is installed.
# Replace with the actual URI, username, and password
export AURA_CONNECTION_URI="neo4j+s://xxxxxxxx.databases.neo4j.io"
export AURA_USERNAME="neo4j"
export AURA_PASSWORD=""

# Quote every expansion so that an empty value, or one containing spaces or
# shell metacharacters, is still passed to cypher-shell as a single argument
./cypher-shell -a "$AURA_CONNECTION_URI" -u "$AURA_USERNAME" -p "$AURA_PASSWORD"

For more information on how to get started using Python, refer to the Connecting with Python tutorial.

pip install neo4j
# Import the driver
from neo4j import GraphDatabase

# Replace with the actual URI, username, and password
AURA_CONNECTION_URI = "neo4j+s://xxxxxxxx.databases.neo4j.io"
AURA_USERNAME = "neo4j"
AURA_PASSWORD = ""

# Instantiate the driver
driver = GraphDatabase.driver(
    AURA_CONNECTION_URI,
    auth=(AURA_USERNAME, AURA_PASSWORD)
)
# Import to prettify results
import json

# Import for the JSON helper function
from neo4j.time import DateTime

# Helper function for serializing Neo4j DateTime in JSON dumps
def default(o):
    """Serialize a Neo4j ``DateTime`` for ``json.dumps``.

    Meant to be passed as the ``default=`` argument of ``json.dumps``, which
    calls it for every object it cannot serialize natively.

    Args:
        o: The object that ``json`` could not serialize on its own.

    Returns:
        The string returned by ``o.isoformat()`` when ``o`` is a ``DateTime``.

    Raises:
        TypeError: If ``o`` is not a ``DateTime``.
    """
    if isinstance(o, DateTime):
        return o.isoformat()

    # Falling through would return None, which json.dumps writes out as
    # "null", silently hiding values of any other unsupported type.
    raise TypeError(
        f"Object of type {type(o).__name__} is not JSON serializable"
    )

Create an example graph

We start by creating some basic graph data first.

gds.run_cypher("""
    MERGE (dan:Person:ExampleData   {name: 'Dan',   age: 20, heightAndWeight: [185, 75]})
    MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
    MERGE (matt:Person:ExampleData  {name: 'Matt',  age: 67, heightAndWeight: [170, 80]})
    MERGE (jeff:Person:ExampleData  {name: 'Jeff',  age: 45, heightAndWeight: [192, 85]})
    MERGE (brie:Person:ExampleData  {name: 'Brie',  age: 27, heightAndWeight: [176, 57]})
    MERGE (elsa:Person:ExampleData  {name: 'Elsa',  age: 32, heightAndWeight: [158, 55]})
    MERGE (john:Person:ExampleData  {name: 'John',  age: 35, heightAndWeight: [172, 76]})

    MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
    MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
    MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
    MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
    MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
    MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
    MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
    MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
    MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)

    RETURN True AS exampleDataCreated
""")
MERGE (dan:Person:ExampleData   {name: 'Dan',   age: 20, heightAndWeight: [185, 75]})
MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
MERGE (matt:Person:ExampleData  {name: 'Matt',  age: 67, heightAndWeight: [170, 80]})
MERGE (jeff:Person:ExampleData  {name: 'Jeff',  age: 45, heightAndWeight: [192, 85]})
MERGE (brie:Person:ExampleData  {name: 'Brie',  age: 27, heightAndWeight: [176, 57]})
MERGE (elsa:Person:ExampleData  {name: 'Elsa',  age: 32, heightAndWeight: [158, 55]})
MERGE (john:Person:ExampleData  {name: 'John',  age: 35, heightAndWeight: [172, 76]})

MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)

RETURN True AS exampleDataCreated
# Cypher query
create_example_graph_on_disk_query = """
    MERGE (dan:Person:ExampleData   {name: 'Dan',   age: 20, heightAndWeight: [185, 75]})
    MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
    MERGE (matt:Person:ExampleData  {name: 'Matt',  age: 67, heightAndWeight: [170, 80]})
    MERGE (jeff:Person:ExampleData  {name: 'Jeff',  age: 45, heightAndWeight: [192, 85]})
    MERGE (brie:Person:ExampleData  {name: 'Brie',  age: 27, heightAndWeight: [176, 57]})
    MERGE (elsa:Person:ExampleData  {name: 'Elsa',  age: 32, heightAndWeight: [158, 55]})
    MERGE (john:Person:ExampleData  {name: 'John',  age: 35, heightAndWeight: [172, 76]})

    MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
    MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
    MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
    MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
    MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
    MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
    MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
    MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
    MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)

    RETURN True AS exampleDataCreated
"""

# Create the driver session
with driver.session() as session:
    # Run query
    result = session.run(create_example_graph_on_disk_query).data()

    # Prettify the result
    print(json.dumps(result, indent=2, sort_keys=True))

We then project an in-memory graph from the data just created.

g, result = gds.graph.project(
    "example_graph_for_graphsage",
    {
        "Person": {
            "label": "ExampleData",
            "properties": ["age", "heightAndWeight"]
        }
    },
    {
        "KNOWS": {
            "type": "KNOWS",
            "orientation": "UNDIRECTED",
            "properties": ["relWeight"]
        }
    }
)

print(result)
CALL gds.graph.project(
  'example_graph_for_graphsage',
  {
    Person: {
      label: 'ExampleData',
      properties: ['age', 'heightAndWeight']
    }
  },
  {
    KNOWS: {
      type: 'KNOWS',
      orientation: 'UNDIRECTED',
      properties: ['relWeight']
    }
  }
)
# Cypher query
create_example_graph_in_memory_query = """
    CALL gds.graph.project(
      'example_graph_for_graphsage',
      {
        Person: {
          label: 'ExampleData',
          properties: ['age', 'heightAndWeight']
        }
      },
      {
        KNOWS: {
          type: 'KNOWS',
          orientation: 'UNDIRECTED',
          properties: ['relWeight']
        }
      }
    )
"""

# Create the driver session
with driver.session() as session:
    # Run query
    result = session.run(create_example_graph_in_memory_query).data()

    # Prettify the result
    print(json.dumps(result, indent=2, sort_keys=True))

Train a model

Machine learning algorithms that support the train mode produce trained models which are stored in the Model Catalog. Similarly, predict procedures can use such trained models to produce predictions. In this example we train a model for the GraphSAGE algorithm using the train mode.

# Train a GraphSAGE model on the projected graph; the call returns the model
# object and the result of the training procedure
model, result = gds.beta.graphSage.train(
    g,
    modelName="example_graph_model_for_graphsage",
    featureProperties=["age", "heightAndWeight"],
    aggregator="mean",
    activationFunction="sigmoid",
    sampleSizes=[25, 10]
)

# Show the training result, as the other client examples do with theirs
print(result)
  CALL gds.beta.graphSage.train(
    'example_graph_for_graphsage',
    {
      modelName: 'example_graph_model_for_graphsage',
      featureProperties: ['age', 'heightAndWeight'],
      aggregator: 'mean',
      activationFunction: 'sigmoid',
      sampleSizes: [25, 10]
    }
  )
  YIELD modelInfo as info
  RETURN
    info.name as modelName,
    info.metrics.didConverge as didConverge,
    info.metrics.ranEpochs as ranEpochs,
    info.metrics.epochLosses as epochLosses
# Cypher query
train_graph_sage_on_in_memory_graph_query  = """
    CALL gds.beta.graphSage.train(
      'example_graph_for_graphsage',
      {
        modelName: 'example_graph_model_for_graphsage',
        featureProperties: ['age', 'heightAndWeight'],
        aggregator: 'mean',
        activationFunction: 'sigmoid',
        sampleSizes: [25, 10]
      }
    )
    YIELD modelInfo as info
    RETURN
      info.name as modelName,
      info.metrics.didConverge as didConverge,
      info.metrics.ranEpochs as ranEpochs,
      info.metrics.epochLosses as epochLosses
"""

# Create the driver session
with driver.session() as session:
    # Run query
    result = session.run(train_graph_sage_on_in_memory_graph_query).data()

    # Prettify the result
    print(json.dumps(result, indent=2, sort_keys=True))

View the model catalog

We can use the gds.beta.model.list procedure to get information on all the models currently available in the catalog. Along with information on the graph schema, the model name, and the training configuration, the result of the call contains the following fields:

  • loaded: flag denoting if the model is in memory (true) or available on disk (false)

  • stored: flag denoting whether the model has been persisted to disk

  • shared: flag denoting whether the model has been published, making it accessible to all users

results = gds.beta.model.list()

print(results)
CALL gds.beta.model.list()
# Cypher statement that lists the models in the catalog
list_model_catalog_query = """
    CALL gds.beta.model.list()
"""

# Run the statement inside a driver session
with driver.session() as session:
    # Collect the returned records
    catalog_response = session.run(list_model_catalog_query)
    results = catalog_response.data()

    # Render the records as indented JSON with sorted keys; values that
    # json cannot serialize natively are handed to the `default` helper
    pretty_results = json.dumps(
        results,
        default=default,
        sort_keys=True,
        indent=2,
    )
    print(pretty_results)

Save a model to disk

The gds.alpha.model.store procedure can be used to persist a model to disk. This is useful both to keep models for later reuse and to free up memory.

Not all models can be saved to disk. A list of the supported models can be found in the GDS manual.

If a model cannot be saved to disk, it will be lost when the AuraDS instance is restarted.

result = gds.alpha.model.store(model)

print(result)
CALL gds.alpha.model.store("example_graph_model_for_graphsage")
# Cypher statement that stores the trained GraphSAGE model
save_graph_sage_model_to_disk_query = """
    CALL gds.alpha.model.store("example_graph_model_for_graphsage")
"""

# Run the statement inside a driver session
with driver.session() as session:
    # Collect the returned records
    store_response = session.run(save_graph_sage_model_to_disk_query)
    result = store_response.data()

    # Render the records as indented JSON with sorted keys
    pretty_result = json.dumps(result, sort_keys=True, indent=2)
    print(pretty_result)

If we list the model catalog again after persisting a model, we can see that the stored flag for that model has been set to true.

results = gds.beta.model.list()

print(results)
CALL gds.beta.model.list()
# Cypher query
list_model_catalog_query  = """
    CALL gds.beta.model.list()
"""

# Create the driver session
with driver.session() as session:
    # Run query
    results = session.run(list_model_catalog_query).data()

    # Prettify the results
    print(json.dumps(results, indent=2, sort_keys=True, default=default))

Share a model with other users

After a model has been created, it can be useful to make it available to other users for different use cases.

A model can only be shared with other users of the same AuraDS instance.

Create a new user

To see how this works in practice on AuraDS, we first need to create another user to share the model with.

# Switch to the "system" database to run the
# "CREATE USER" admin command
gds.set_database("system")

gds.run_cypher("""
    CREATE USER testUser IF NOT EXISTS
    SET PASSWORD 'password'
    SET PASSWORD CHANGE NOT REQUIRED
""")
:connect system

CREATE USER testUser IF NOT EXISTS
SET PASSWORD 'password'
SET PASSWORD CHANGE NOT REQUIRED
# Cypher query
create_a_new_user_query = """
    CREATE USER testUser IF NOT EXISTS
    SET PASSWORD 'password'
    SET PASSWORD CHANGE NOT REQUIRED
"""

# Create the driver session using the "system" database
with driver.session(database="system") as session:
    # Run query
    result = session.run(create_a_new_user_query).data()

    # Prettify the result
    print(json.dumps(result, indent=2, sort_keys=True))

Publish the model

A model can be published (made accessible to other users) using the gds.alpha.model.publish procedure. Upon publication, the model name is updated by appending _public to its original name.

# Switch back to the default "neo4j" database
# to publish the model
gds.set_database("neo4j")

model_public = gds.alpha.model.publish(model)

print(model_public)
:connect neo4j

CALL gds.alpha.model.publish('example_graph_model_for_graphsage')
# Cypher statement that publishes the trained GraphSAGE model
publish_graph_sage_model_to_disk_query = """
    CALL gds.alpha.model.publish('example_graph_model_for_graphsage')
"""

# Run the statement inside a driver session
with driver.session() as session:
    # Collect the returned records
    publish_response = session.run(publish_graph_sage_model_to_disk_query)
    result = publish_response.data()

    # Render the records as indented JSON with sorted keys; values that
    # json cannot serialize natively are handed to the `default` helper
    pretty_result = json.dumps(
        result,
        default=default,
        sort_keys=True,
        indent=2,
    )
    print(pretty_result)

View the model as a different user

In order to verify that the published model is visible to the user we have just created, we need to create a new client (or driver) session. We can then use it to run the gds.beta.model.list procedure again under the new user and verify that the model is included in the list.

test_user_gds = GraphDataScience(
    AURA_CONNECTION_URI,
    auth=("testUser", "password"),
    aura_ds=True
)

results = test_user_gds.beta.model.list()

print(results)
// First, open a new Cypher shell as the test user with the following command
// (the URI is quoted so the shell passes it as a single argument):
//
// ./cypher-shell -a "$AURA_CONNECTION_URI" -u testUser -p password
//
// Then list the model catalog from that new shell:

CALL gds.beta.model.list()
# Credentials of the test user created earlier
test_user_auth = ("testUser", "password")

# Dedicated driver authenticated as the test user
test_user_driver = GraphDatabase.driver(AURA_CONNECTION_URI, auth=test_user_auth)

# List the model catalog from a session owned by the test user
with test_user_driver.session() as session:
    # Collect the returned records
    catalog_response = session.run(list_model_catalog_query)
    results = catalog_response.data()

    # Render the records as indented JSON with sorted keys; values that
    # json cannot serialize natively are handed to the `default` helper
    pretty_results = json.dumps(
        results,
        default=default,
        sort_keys=True,
        indent=2,
    )
    print(pretty_results)

Cleanup

The in-memory graphs, the data in the Neo4j database, the models, and the test user can now all be deleted.

# Delete the example dataset
gds.run_cypher("""
    MATCH (example:ExampleData)
    DETACH DELETE example
""")

# Delete the projected graph from memory
gds.graph.drop(g)

# Drop the model from memory
gds.beta.model.drop(model_public)

# Delete the model from disk
gds.alpha.model.delete(model_public)

# Switch to the "system" database to delete the example user
gds.set_database("system")

gds.run_cypher("""
    DROP USER testUser
""")
// Delete the example dataset from the database
MATCH (example:ExampleData)
DETACH DELETE example;

// Delete the projected graph from memory
CALL gds.graph.drop("example_graph_for_graphsage");

// Drop the model from memory
CALL gds.beta.model.drop("example_graph_model_for_graphsage_public");

// Delete the model from disk
CALL gds.alpha.model.delete("example_graph_model_for_graphsage_public");

// Delete the example user
DROP USER testUser;
# Delete the example dataset from the database
delete_example_graph_query = """
    MATCH (example:ExampleData)
    DETACH DELETE example
"""

# Delete the projected graph from memory
drop_in_memory_graph_query = """
    CALL gds.graph.drop("example_graph_for_graphsage")
"""

# Drop the model from memory
drop_example_models_query = """
    CALL gds.beta.model.drop("example_graph_model_for_graphsage_public")
"""

# Delete the model from disk
delete_example_models_query = """
    CALL gds.alpha.model.delete("example_graph_model_for_graphsage_public")
"""

# Delete the example user
drop_example_user_query = """
    DROP USER testUser
"""

# Cleanup statements for the default database, in the order they must run
default_database_cleanup = (
    delete_example_graph_query,
    drop_in_memory_graph_query,
    drop_example_models_query,
    delete_example_models_query,
)

# Run each cleanup statement in one session and print what it returns
with driver.session() as session:
    for cleanup_query in default_database_cleanup:
        cleanup_records = session.run(cleanup_query).data()
        print(cleanup_records)

# The test user is dropped from a separate session
# opened on the "system" database
with driver.session(database='system') as session:
    drop_user_records = session.run(drop_example_user_query).data()
    print(drop_user_records)

# Close both drivers once the cleanup is done
driver.close()
test_user_driver.close()

Closing the connection

The connection should always be closed when no longer needed.

Although the GDS client automatically closes the connection when the object is deleted, it is good practice to close it explicitly.

# Close the client connection
gds.close()

# Also close the client that was opened for the test user, so that its
# connection is not left open
test_user_gds.close()
# Close the driver connection
driver.close()

References

Cypher