Model objects from the model catalog

Models in the GDS Model Catalog are represented as Model objects in the Python client, much as projected graphs are represented as graph objects. A Model object is typically obtained by training a pipeline or a GraphSAGE model, which returns a reference to the trained model in the form of a Model object.

Once created, the Model objects can be passed as arguments to methods in the Python client, such as the model catalog operations. Additionally, the Model objects have convenience methods allowing for inspection of the models represented without explicitly involving the model catalog.

In the examples below we assume that we have an instantiated GraphDataScience object called gds. Read more about this in Getting started.

1. Constructing a model object

The primary way to construct a model object is through training a model. There are two types of models: pipeline models and GraphSAGE models. In order to train a pipeline model, a pipeline must first be created and configured. Read more about how to operate pipelines in Machine learning pipelines, including examples of using pipeline models. In this section, we will exemplify creating and using a GraphSAGE model object.

First, we introduce a small road-network graph:

gds.run_cypher(
  """
  CREATE
    (a:City {name: "New York City", settled: 1624}),
    (b:City {name: "Philadelphia", settled: 1682}),
    (c:City:Capital {name: "Washington D.C.", settled: 1790}),
    (d:City {name: "Baltimore", settled: 1729}),
    (e:City {name: "Atlantic City", settled: 1854}),
    (f:City {name: "Boston", settled: 1822}),

    (a)-[:ROAD {cost: 50}]->(b),
    (a)-[:ROAD {cost: 50}]->(c),
    (a)-[:ROAD {cost: 100}]->(d),
    (b)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 40}]->(d),
    (c)-[:ROAD {cost: 80}]->(e),
    (d)-[:ROAD {cost: 30}]->(e),
    (d)-[:ROAD {cost: 80}]->(f),
    (e)-[:ROAD {cost: 40}]->(f);
  """
)
G, project_result = gds.graph.project(
    "road_graph",
    {"City": {"properties": ["settled"]}},
    {"ROAD": {"properties": ["cost"]}}
)

assert G.relationship_count() == 9

Now we can use the graph G to train a GraphSAGE model.

model, train_result = gds.beta.graphSage.train(G, modelName="city-representation", featureProperties=["settled"], randomSeed=42)

assert train_result["modelInfo"]["metrics"]["ranEpochs"] == 1

Here, model is the model object, and train_result is a pandas Series containing metadata from the underlying procedure call.

Similarly, we can also get model objects from training machine learning pipelines.

To get a model object that represents a model that has already been trained and is present in the model catalog, one can call the client-side only get method and pass it a name:

model = gds.model.get("city-representation")

assert model.name() == "city-representation"

The get method does not use any tier prefix because it is not associated with any tier. It only exists in the client and does not have a corresponding Cypher procedure.
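Because a model name can only be used once in the model catalog, a script that is run repeatedly may want to reuse an already trained model rather than train it again. The following is a minimal sketch of that pattern, not part of the client itself. The helper name get_or_train_graphsage is hypothetical, and the sketch assumes that gds.beta.model.exists(name) is available in your client version and returns a Series with an "exists" entry.

```python
def get_or_train_graphsage(gds, G, model_name, **train_config):
    """Return a GraphSAGE Model object, training only if the catalog lacks it.

    Assumption: gds.beta.model.exists(name) returns a Series-like object
    with an "exists" entry. Check this against your client version.
    """
    if gds.beta.model.exists(model_name)["exists"]:
        # Client-side lookup only; no training is triggered
        return gds.model.get(model_name)

    # Train a new model; the second return value (training metadata) is ignored
    model, _ = gds.beta.graphSage.train(G, modelName=model_name, **train_config)
    return model
```

With the graph G from above, this could be called as get_or_train_graphsage(gds, G, "city-representation", featureProperties=["settled"], randomSeed=42).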

2. Inspecting a model object

There are convenience methods on all model objects that let us extract information about the represented model.

Table 1. Model object methods

Name          | Arguments                     | Return type         | Description
--------------|-------------------------------|---------------------|------------------------------------------------------------
name          | -                             | str                 | The name of the model as it appears in the model catalog.
type          | -                             | str                 | The type of model it is, e.g. "graphSage".
train_config  | -                             | Series              | The configuration used for training the model.
graph_schema  | -                             | Series              | The schema of the graph on which the model was trained.
loaded        | -                             | bool                | True if the model is loaded in the in-memory model catalog, False otherwise.
stored        | -                             | bool                | True if the model is stored on disk, False otherwise.
creation_time | -                             | neo4j.time.Datetime | Time when the model was created.
shared        | -                             | bool                | True if the model is shared between users, False otherwise.
exists        | -                             | bool                | True if the model exists in the GDS Model Catalog, False otherwise.
drop          | failIfMissing: Optional[bool] | Series              | Removes the model from the GDS Model Catalog.

For example, to get the train configuration of our model object model created above, we would do the following:

train_config = model.train_config()

assert train_config["concurrency"] == 4
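Several of the inspection methods can be combined into a single summary, which is handy for logging. The sketch below is illustrative only: the helper describe_model is a hypothetical name, and it calls nothing beyond the methods listed in Table 1.

```python
def describe_model(model):
    """Collect basic catalog facts about a Model object into a plain dict."""
    return {
        "name": model.name(),
        "type": model.type(),
        "loaded": model.loaded(),   # present in the in-memory model catalog?
        "stored": model.stored(),   # persisted on disk?
        "shared": model.shared(),   # visible to other users?
    }
```

Calling describe_model(model) on the model trained above would return a dict whose "name" entry is "city-representation".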

3. Using a model object

The primary way to use model objects is for prediction. How to do so for GraphSAGE is described below, and on the Machine learning pipelines page for pipelines.

Additionally, model objects can be used as input to GDS Model Catalog operations. For instance, supposing we have our model object model created above, we could:

# Store the model on disk (GDS Enterprise Edition)
_ = gds.alpha.model.store(model)

# Drop the model from the in-memory model catalog
gds.beta.model.drop(model)  # same as model.drop()

# Load the model again for further use
gds.alpha.model.load(model.name())
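When cleaning up at the end of a script, it may not be known whether a model is still in the catalog. One way to handle this, sketched here using only the exists and drop methods from Table 1, is to check before dropping. The helper name drop_if_exists is hypothetical.

```python
def drop_if_exists(model):
    """Drop the model from the catalog if present.

    Returns True if the model was dropped, False if it was not in the catalog.
    """
    if not model.exists():
        return False
    model.drop()
    return True
```

This keeps cleanup code idempotent: calling drop_if_exists(model) a second time simply returns False.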

3.1. GraphSAGE

As exemplified above in Constructing a model object, training a GraphSAGE model with the Python client is analogous to its Cypher counterpart.

Once trained, in addition to the methods above, the GraphSAGE model object will have the following methods.

Table 2. GraphSAGE model methods

Name           | Arguments                  | Return type | Description
---------------|----------------------------|-------------|------------------------------------------------------------
predict_mutate | G: Graph, config: **kwargs | Series      | Predict embeddings for nodes of the input graph and mutate the graph with the predictions.
predict_stream | G: Graph, config: **kwargs | DataFrame   | Predict embeddings for nodes of the input graph and stream the results.
predict_write  | G: Graph, config: **kwargs | Series      | Predict embeddings for nodes of the input graph and write the results back to the database.
metrics        | -                          | Series      | Returns values for the metrics computed when training.

Given the GraphSAGE model object model that we trained above, we could do the following:

# Make sure our training actually converged
metrics = model.metrics()
assert metrics["didConverge"]

# Predict on `G` and write embedding node properties back to the database
predict_result = model.predict_write(G, writeProperty="embedding")
assert predict_result["nodePropertiesWritten"] == G.node_count()
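In a longer script, the same check can be wrapped so that a mismatch raises a descriptive error instead of a bare assertion failure. This is a minimal sketch: the helper name write_embeddings is hypothetical, and it relies only on predict_write, the "nodePropertiesWritten" result entry, and G.node_count() as used above.

```python
def write_embeddings(model, G, write_property):
    """Write GraphSAGE embeddings and verify one property was written per node."""
    result = model.predict_write(G, writeProperty=write_property)
    written = result["nodePropertiesWritten"]
    expected = G.node_count()
    if written != expected:
        raise RuntimeError(
            f"expected {expected} node properties, but {written} were written"
        )
    return written
```

For the road graph above, write_embeddings(model, G, "embedding") would return the number of nodes in G if every node received an embedding.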