The model object

Models in the GDS Model Catalog are represented as Model objects in the Python client, much as projected graphs are represented as graph objects. A Model object is typically obtained by training a pipeline or a GraphSAGE model: the training call returns a reference to the trained model in the form of a Model object.

Once created, the Model objects can be passed as arguments to methods in the Python client, such as the model catalog operations. Additionally, the Model objects have convenience methods allowing for inspection of the models represented without explicitly involving the model catalog.

In the examples below we assume that we have an instantiated GraphDataScience object called gds. Read more about this in Getting started.

1. Constructing a model object

There are several ways of constructing a model object. One of the simplest is to train a GraphSAGE model. Supposing we have a graph G that has an integer node property "price", we could do the following:

model, res = gds.beta.graphSage.train(G, modelName="my-model", featureProperties=["price"])

where model is the model object, and res is a pandas Series containing metadata from the underlying procedure call.

Similarly, we can also get model objects from training machine learning pipelines.

To get a model object that represents a model that has already been trained and is present in the model catalog, one can call the client-side only get method and pass it a name:

model = gds.model.get("my-model")

The get method does not use any tier prefix because it is not associated with any tier. It exists only in the client and has no corresponding Cypher procedure.
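When it is not certain that a model is present in the catalog, it can help to check before calling get. The following is a minimal sketch under stated assumptions: get_model_if_present is a hypothetical helper name, and it assumes that gds.beta.model.exists returns a pandas Series with a boolean "exists" entry, mirroring the corresponding Cypher procedure. Verify this against the client version you are using.

```python
from typing import Any, Optional


def get_model_if_present(gds: Any, model_name: str) -> Optional[Any]:
    """Return a model object for `model_name`, or None if it is not in the catalog.

    Hypothetical helper, not part of the client.
    Assumption: `gds.beta.model.exists` returns a Series with a boolean
    "exists" entry.
    """
    if not gds.beta.model.exists(model_name)["exists"]:
        return None
    # `gds.model.get` is the client-side only method described above.
    return gds.model.get(model_name)


# Usage, assuming an instantiated GraphDataScience object `gds`:
# model = get_model_if_present(gds, "my-model")
```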

2. Inspecting a model object

There are convenience methods on all model objects that let us extract information about the represented model.

Table 1. Model object methods

| Name          | Arguments                     | Return type         | Description                                                                  |
|---------------|-------------------------------|---------------------|------------------------------------------------------------------------------|
| name          | -                             | str                 | The name of the model as it appears in the model catalog.                    |
| type          | -                             | str                 | The type of model it is, e.g. "graphSage".                                   |
| train_config  | -                             | Series              | The configuration used for training the model.                               |
| graph_schema  | -                             | Series              | The schema of the graph on which the model was trained.                      |
| loaded        | -                             | bool                | True if the model is loaded in the in-memory model catalog, False otherwise. |
| stored        | -                             | bool                | True if the model is stored on disk, False otherwise.                        |
| creation_time | -                             | neo4j.time.Datetime | Time when the model was created.                                             |
| shared        | -                             | bool                | True if the model is shared between users, False otherwise.                  |
| exists        | -                             | bool                | True if the model exists in the GDS Model Catalog, False otherwise.          |
| drop          | failIfMissing: Optional[bool] | Series              | Removes the model from the GDS Model Catalog.                                |

For example, to get the train configuration of a model object model, we would do the following:

train_config = model.train_config()
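Several of the inspection methods can also be combined. The following is a minimal sketch: summarize_model is a hypothetical helper name, not part of the client, and it relies only on the methods listed in Table 1.

```python
from typing import Any, Dict


def summarize_model(model: Any) -> Dict[str, Any]:
    """Collect basic catalog information about a model object.

    Hypothetical helper, not part of the client. It only calls the
    argument-free inspection methods listed in Table 1.
    """
    return {
        "name": model.name(),
        "type": model.type(),
        "loaded": model.loaded(),
        "stored": model.stored(),
        "shared": model.shared(),
    }


# Usage, assuming `model` is a model object obtained as described above:
# info = summarize_model(model)
# print(info["name"], info["type"])
```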

3. Using a model object

The primary way to use model objects is for prediction. How to do so for GraphSAGE is described below, and on the Machine learning pipelines page for pipelines.

Additionally, model objects can be used as input to GDS Model Catalog operations. For instance, supposing we have a model object model, we could:

# Store the model on disk (GDS Enterprise Edition)
_ = gds.alpha.model.store(model)

gds.beta.model.drop(model)  # same as model.drop()
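The inspection methods and the catalog operations can be used together. As a minimal sketch, the hypothetical helper store_if_needed below (not part of the client) stores a model on disk only when the model object reports that it is not stored yet. Storing models requires GDS Enterprise Edition.

```python
from typing import Any


def store_if_needed(gds: Any, model: Any) -> bool:
    """Store `model` on disk unless it is already stored.

    Hypothetical helper, not part of the client.
    Returns True if a store call was issued, False otherwise.
    """
    if model.stored():
        # Already persisted on disk; nothing to do.
        return False
    gds.alpha.model.store(model)
    return True


# Usage, assuming `gds` and `model` as in the examples above:
# was_stored_now = store_if_needed(gds, model)
```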

3.1. GraphSAGE

As exemplified above in Constructing a model object, training a GraphSAGE model with the Python client is analogous to its Cypher counterpart.

Once trained, in addition to the methods above, the GraphSAGE model object will have the following methods.

Table 2. GraphSAGE model methods

| Name           | Arguments                    | Return type | Description                                                                                 |
|----------------|------------------------------|-------------|---------------------------------------------------------------------------------------------|
| predict_mutate | G: Graph, config: **kwargs   | Series      | Predict embeddings for nodes of the input graph and mutate graph with predictions.          |
| predict_stream | G: Graph, config: **kwargs   | DataFrame   | Predict embeddings for nodes of the input graph and stream the results.                     |
| predict_write  | G: Graph, config: **kwargs   | Series      | Predict embeddings for nodes of the input graph and write the results back to the database. |
| metrics        | -                            | Series      | Returns values for the metrics computed when training.                                      |

Suppose then that we have a trained GraphSAGE model gs_model and a graph H for which we would like to derive node embeddings. Then we could do the following:

# Make sure our training actually converged
metrics = gs_model.metrics()
assert metrics["didConverge"]

# Predict on `H` and write embedding node properties back to the database
result = gs_model.predict_write(H, writeProperty="embedding")
assert result["nodePropertiesWritten"] == H.node_count()
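If the embeddings are needed on the client side rather than in the database, predict_stream can be used instead. The following is a minimal sketch: embeddings_by_node is a hypothetical helper name, and it assumes that the streamed DataFrame has "nodeId" and "embedding" columns, as the corresponding Cypher stream procedure yields.

```python
from typing import Any, Dict, List


def embeddings_by_node(gs_model: Any, G: Any, **config: Any) -> Dict[int, List[float]]:
    """Stream GraphSAGE embeddings for `G` and index them by node id.

    Hypothetical helper, not part of the client.
    Assumption: the result of `predict_stream` has "nodeId" and
    "embedding" columns.
    """
    df = gs_model.predict_stream(G, **config)
    return {
        int(node_id): list(embedding)
        for node_id, embedding in zip(df["nodeId"], df["embedding"])
    }


# Usage, assuming a trained GraphSAGE model `gs_model` and a graph `H`:
# embeddings = embeddings_by_node(gs_model, H)
```

Note that this materializes all embeddings in client memory, so for large graphs predict_write or predict_mutate may be the better choice.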