Model objects from the model catalog

Models in the GDS Model Catalog are represented as Model objects in the Python client, similar to how projected graphs are represented as Graph objects. Model objects are typically constructed by training a pipeline or a GraphSAGE model, in which case a Model object referencing the trained model is returned.

Once created, the Model objects can be passed as arguments to methods in the Python client, such as the model catalog operations. Additionally, the Model objects have convenience methods allowing for inspection of the models represented without explicitly involving the model catalog.

In the examples below we assume that we have an instantiated GraphDataScience object called gds. Read more about this in Getting started.

1. Constructing a model object

The primary way to construct a model object is by training a model. There are two types of models: pipeline models and GraphSAGE models. To train a pipeline model, a pipeline must first be created and configured. Read more about how to operate pipelines, including examples of using pipeline models, in Machine learning pipelines. In this section, we exemplify creating and using a GraphSAGE model object.

First, we introduce a small road-network graph:

    gds.run_cypher(
        """
        CREATE
          (a:City {name: "New York City", settled: 1624}),
          (b:City {name: "Philadelphia", settled: 1682}),
          (c:City:Capital {name: "Washington D.C.", settled: 1790}),
          (d:City {name: "Baltimore", settled: 1729}),
          (e:City {name: "Atlantic City", settled: 1854}),
          (f:City {name: "Boston", settled: 1822}),

          (a)-[:ROAD {cost: 50}]->(b),
          (a)-[:ROAD {cost: 50}]->(c),
          (a)-[:ROAD {cost: 100}]->(d),
          (b)-[:ROAD {cost: 40}]->(d),
          (c)-[:ROAD {cost: 40}]->(d),
          (c)-[:ROAD {cost: 80}]->(e),
          (d)-[:ROAD {cost: 30}]->(e),
          (d)-[:ROAD {cost: 80}]->(f),
          (e)-[:ROAD {cost: 40}]->(f)
        """
    )

    G, project_result = gds.graph.project(
        "road_graph",  # name of the projection in the graph catalog
        {"City": {"properties": ["settled"]}},
        {"ROAD": {"properties": ["cost"]}}
    )

    assert G.relationship_count() == 9
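For readers without a running database, the same road network can be written down as a plain Python edge list. This standard-library-only sketch is not part of the GDS client; it simply confirms the counts the projection should report (six City nodes and nine ROAD relationships), using the Cypher variable names a to f as keys, and derives the out-degree of each city from the relationship directions.

```python
# Plain-Python copy of the road network created above (no GDS involved).
cities = {
    "a": {"name": "New York City", "settled": 1624},
    "b": {"name": "Philadelphia", "settled": 1682},
    "c": {"name": "Washington D.C.", "settled": 1790},
    "d": {"name": "Baltimore", "settled": 1729},
    "e": {"name": "Atlantic City", "settled": 1854},
    "f": {"name": "Boston", "settled": 1822},
}

# Directed ROAD relationships as (source, target, cost) triples.
roads = [
    ("a", "b", 50), ("a", "c", 50), ("a", "d", 100),
    ("b", "d", 40), ("c", "d", 40), ("c", "e", 80),
    ("d", "e", 30), ("d", "f", 80), ("e", "f", 40),
]

# Every relationship must connect two known cities.
assert all(src in cities and dst in cities for src, dst, _ in roads)

node_count = len(cities)
relationship_count = len(roads)

# Out-degree per city: how many ROAD relationships start there.
out_degree = {key: 0 for key in cities}
for src, _, _ in roads:
    out_degree[src] += 1

print(node_count, relationship_count)  # 6 9
print(out_degree)
```

Since the projection keeps the natural relationship orientation, the nine triples correspond one-to-one to the relationships counted by G.relationship_count().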

Now we can use the graph G to train a GraphSAGE model.

    model, train_result = gds.beta.graphSage.train(
        G,
        modelName="city-representation",
        featureProperties=["settled"],
        randomSeed=42,
    )

    assert train_result["modelInfo"]["metrics"]["ranEpochs"] == 1

Here, model is the model object, and train_result is a pandas Series containing metadata from the underlying procedure call.

Model objects are returned in the same way when training machine learning pipelines.

To get a model object representing a model that has already been trained and is present in the model catalog, call the client-side-only get method and pass it the model name:

    model = gds.model.get("city-representation")

    assert model.name() == "city-representation"

The get method does not use a tier prefix because it is not associated with any tier. It exists only in the client and has no corresponding Cypher procedure.
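The idea of a client-side handle can be illustrated without a server. In the hypothetical sketch below, which is not how the graphdatascience package is implemented, a plain dictionary plays the role of the model catalog, the illustrative get function only checks that the name is present, and the returned ModelHandle stores nothing but the name, reading everything else from the catalog on each call so it cannot go stale.

```python
from typing import Any, Dict

# Hypothetical stand-in for the server-side model catalog: name -> metadata.
catalog: Dict[str, Dict[str, Any]] = {
    "city-representation": {"modelType": "graphSage"},
}


class ModelHandle:
    """Hypothetical client-side reference to a catalog model, keyed by name only."""

    def __init__(self, name: str, model_catalog: Dict[str, Dict[str, Any]]) -> None:
        self._name = name
        self._catalog = model_catalog

    def name(self) -> str:
        return self._name

    def exists(self) -> bool:
        # Re-checks the catalog on every call instead of caching a snapshot.
        return self._name in self._catalog

    def type(self) -> str:
        # Metadata is read from the catalog on demand, not stored in the handle.
        return self._catalog[self._name]["modelType"]


def get(model_catalog: Dict[str, Dict[str, Any]], model_name: str) -> ModelHandle:
    # Client-side only: no dedicated procedure is called; this sketch just
    # checks that the name is known and wraps it in a handle.
    if model_name not in model_catalog:
        raise ValueError(f"No model named '{model_name}' in the catalog")
    return ModelHandle(model_name, model_catalog)


model = get(catalog, "city-representation")
print(model.name(), model.type())  # city-representation graphSage

# Removing the entry from the catalog is visible through the existing handle.
del catalog["city-representation"]
print(model.exists())  # False
```

Keeping only the name in the handle is what lets the same object be passed to catalog operations and still report up-to-date information afterwards.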

2. Inspecting a model object

There are convenience methods on all model objects that let us extract information about the represented model.

Table 1. Model object methods

Name            Arguments                       Return type            Description
name            -                               str                    The name of the model as it appears in the model catalog.
type            -                               str                    The type of model it is, e.g. "graphSage".
train_config    -                               Series                 The configuration used for training the model.
graph_schema    -                               Series                 The schema of the graph on which the model was trained.
loaded          -                               bool                   True if the model is loaded in the in-memory model catalog, False otherwise.
stored          -                               bool                   True if the model is stored on disk, False otherwise.
creation_time   -                               neo4j.time.Datetime    Time when the model was created.
shared          -                               bool                   True if the model is shared between users, False otherwise.
exists          -                               bool                   True if the model exists in the GDS Model Catalog, False otherwise.
drop            failIfMissing: Optional[bool]   Series                 Removes the model from the GDS Model Catalog.

For example, to get the training configuration of the model object model created above, we would do the following:

    train_config = model.train_config()

    assert train_config["concurrency"] == 4

3. Using a model object

The primary way to use model objects is for prediction. How to do so is described below for GraphSAGE, and on the Machine learning pipelines page for pipeline models.

Additionally, model objects can be used as input to GDS Model Catalog operations. For instance, supposing we have our model object model created above, we could:

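    # Store the model on disk (GDS Enterprise Edition)
    _ = gds.alpha.model.store(model)

    # Remove the model from the in-memory model catalog
    gds.beta.model.drop(model)  # same as model.drop()

    # Load the model again for further use
    gds.alpha.model.load(model.name())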
3.1. GraphSAGE

As exemplified above in Constructing a model object, training a GraphSAGE model with the Python client is analogous to its Cypher counterpart.

Once trained, the GraphSAGE model object has the following methods in addition to those listed above.

Table 2. GraphSAGE model methods

Name             Arguments                    Return type    Description
predict_mutate   G: Graph, config: **kwargs   Series         Predict embeddings for nodes of the input graph and mutate the graph with the predictions.
predict_stream   G: Graph, config: **kwargs   DataFrame      Predict embeddings for nodes of the input graph and stream the results.
predict_write    G: Graph, config: **kwargs   Series         Predict embeddings for nodes of the input graph and write the results back to the database.
metrics          -                            Series         Returns values for the metrics computed when training.
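The difference between the three prediction modes in Table 2 is where the embeddings end up. The hypothetical sketch below mimics that with plain dictionaries: stream hands rows back to the caller, mutate adds a node property to the in-memory projection only, and write persists the property to a (simulated) database and reports how many properties were written. The embed function is a deterministic placeholder standing in for a trained model, not GraphSAGE, and all function names here are illustrative rather than the client's API.

```python
from typing import Dict, List, Tuple

Embedding = List[float]


def embed(settled: int) -> Embedding:
    # Placeholder for a trained model (NOT GraphSAGE): a fixed 2-d vector.
    return [settled / 1000.0, (settled % 100) / 100.0]


def predict_stream(graph: Dict[int, dict]) -> List[Tuple[int, Embedding]]:
    # Stream mode: results are handed back to the caller; nothing is stored.
    return [(node_id, embed(props["settled"])) for node_id, props in graph.items()]


def predict_mutate(graph: Dict[int, dict], mutate_property: str) -> None:
    # Mutate mode: results become a new node property of the in-memory projection.
    for node_id, embedding in predict_stream(graph):
        graph[node_id][mutate_property] = embedding


def predict_write(
    graph: Dict[int, dict], database: Dict[int, dict], write_property: str
) -> Dict[str, int]:
    # Write mode: results are persisted to the (simulated) database instead,
    # and the number of node properties written is reported back.
    written = 0
    for node_id, embedding in predict_stream(graph):
        database.setdefault(node_id, {})[write_property] = embedding
        written += 1
    return {"nodePropertiesWritten": written}


# Three of the cities from the example graph, keyed by hypothetical node ids.
projection: Dict[int, dict] = {0: {"settled": 1624}, 1: {"settled": 1682}, 2: {"settled": 1790}}
database: Dict[int, dict] = {}

rows = predict_stream(projection)  # projection and database untouched
predict_mutate(projection, "embedding")  # projection gains "embedding"
result = predict_write(projection, database, "embedding")  # database gains "embedding"
print(result)  # {'nodePropertiesWritten': 3}
```

In this sketch, as in the write example that follows, a successful write reports one written property per node of the input graph.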

Given the GraphSAGE model object model trained above, we could, for example, do the following:

    # Make sure our training actually converged
    metrics = model.metrics()
    assert metrics["didConverge"]

    # Predict on `G` and write embedding node properties back to the database
    predict_result = model.predict_write(G, writeProperty="embedding")
    assert predict_result["nodePropertiesWritten"] == G.node_count()