The model object
Models in the GDS Model Catalog are represented as Model objects in the Python client, similar to how graphs are represented as graph objects.
Model objects are typically constructed by training a pipeline or a GraphSAGE model, in which case a reference to the trained model is returned in the form of a Model object.
Once created, Model objects can be passed as arguments to methods in the Python client, such as the model catalog operations.
Additionally, Model objects have convenience methods for inspecting the models they represent without explicitly involving the model catalog.
In the examples below we assume that we have an instantiated GraphDataScience object called gds.
Read more about this in Getting started.
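For readers who do not yet have a gds object, a minimal sketch of creating one might look like the following. The helper name connect, the connection URI and the credentials are all placeholder assumptions and not part of the client. The import is deferred so that the helper can be defined without the graphdatascience package installed.

```python
def connect(uri="bolt://localhost:7687", user="neo4j", password="password"):
    """Return a GraphDataScience client for the given Neo4j instance.

    The default URI and credentials are placeholders (assumptions); replace
    them with the values for your own database.
    """
    # Imported lazily so that defining this helper does not require the package.
    from graphdatascience import GraphDataScience

    return GraphDataScience(uri, auth=(user, password))
```

Calling gds = connect() with suitable arguments would then give the gds object that the remaining examples assume.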
1. Constructing a model object
There are several ways of constructing a model object.
One of the simplest is to train a GraphSAGE model.
Supposing we have a graph G that has an integer node property "price", we could do the following:
model, res = gds.beta.graphSage.train(G, modelName="my-model", featureProperties=["price"])
where model is the model object, and res is a pandas Series containing metadata from the underlying procedure call.
Similarly, we can also get model objects from training machine learning pipelines.
To get a model object that represents a model that has already been trained and is present in the model catalog, one can call the client-side-only get method and pass it the name of the model:
model = gds.model.get("my-model")
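The two construction paths above can be combined in one place. The following is only a sketch: train_and_refetch is a hypothetical helper name, and it assumes that gds and G behave as described in this section.

```python
def train_and_refetch(gds, G, model_name, feature_properties):
    """Train a GraphSAGE model, then fetch a second handle to it by name."""
    # Training returns the model object plus metadata from the procedure call.
    trained, res = gds.beta.graphSage.train(
        G, modelName=model_name, featureProperties=feature_properties
    )
    # gds.model.get is client-side only: it builds a model object for a model
    # that is already present in the model catalog.
    fetched = gds.model.get(model_name)
    return trained, fetched, res
```

Both returned model objects refer to the same model in the catalog, so either one can be passed on to later operations.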
2. Inspecting a model object
There are convenience methods on all model objects that let us extract information about the represented model.
| Name | Arguments | Return type | Description |
|---|---|---|---|
| | | | The name of the model as it appears in the model catalog. |
| | | | The type of model it is, e.g. "graphSage". |
| train_config | | | The configuration used for training the model. |
| | | | The schema of the graph on which the model was trained. |
| | | | Time when the model was created. |
| drop | | | Removes the model from the GDS Model Catalog. |
For example, to get the train configuration of a model object model, we would do the following:
train_config = model.train_config()
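When only a few settings are of interest, the returned configuration can be narrowed down. The sketch below assumes that train_config() returns a mapping-like object (a dict or a pandas Series would both work); pick_train_config is a hypothetical helper name, and which keys are actually present depends on the model.

```python
def pick_train_config(model, keys):
    """Return the selected entries of a model's train configuration as a dict."""
    config = model.train_config()
    # `key in config` checks dict keys and, for a pandas Series, index labels.
    return {key: config[key] for key in keys if key in config}
```

For instance, pick_train_config(model, ["featureProperties"]) would return just that entry if it exists, and an empty dict otherwise.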
3. Using a model object
The primary way to use model objects is for prediction. How to do so for GraphSAGE is described below, and on the Machine learning pipelines page for pipelines.
Additionally, model objects can be used as input to GDS Model Catalog operations.
For instance, supposing we have a model object model, we could:
# Store the model on disk (GDS Enterprise Edition)
_ = gds.alpha.model.store(model)
gds.beta.model.drop(model) # same as model.drop()
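Because a model stays in the catalog until it is dropped, it can be convenient to tie the drop to a scope. The following is a sketch of one way to do that using only the model.drop() method shown above; dropped_on_exit is a hypothetical helper and not part of the client.

```python
from contextlib import contextmanager


@contextmanager
def dropped_on_exit(model):
    """Yield the model object and drop it from the model catalog afterwards."""
    try:
        yield model
    finally:
        # Runs even if the body of the `with` block raises.
        model.drop()
```

It would be used as "with dropped_on_exit(model) as m:", with any work on the model placed inside the block.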
3.1. GraphSAGE
As exemplified above in Constructing a model object, training a GraphSAGE model with the Python client is analogous to its Cypher counterpart.
Once trained, in addition to the methods above, the GraphSAGE model object will have the following methods.
| Name | Arguments | Return type | Description |
|---|---|---|---|
| | | | Predict embeddings for nodes of the input graph and mutate graph with predictions. |
| | | | Predict embeddings for nodes of the input graph and stream the results. |
| predict_write | | | Predict embeddings for nodes of the input graph and write the results back to the database. |
| metrics | | | Returns values for the metrics computed when training. |
Suppose then that we have a trained GraphSAGE model gs_model and a graph H for which we would like to derive node embeddings. Then we could do the following:
# Make sure our training actually converged
metrics = gs_model.metrics()
assert metrics["didConverge"]
# Predict on `H` and write embedding node properties back to the database
result = gs_model.predict_write(H, writeProperty="embedding")
assert result["nodePropertiesWritten"] == H.node_count()
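Since assert statements are skipped when Python runs with the -O flag, the same checks could also be written with explicit exceptions. The sketch below wraps the steps above in a hypothetical helper, write_embeddings, and relies only on the calls already shown in this section.

```python
def write_embeddings(gs_model, H, write_property="embedding"):
    """Check convergence, write embeddings for H, and verify the write count."""
    metrics = gs_model.metrics()
    if not metrics["didConverge"]:
        raise RuntimeError("GraphSAGE training did not converge")

    # Predict on H and write the embeddings back to the database.
    result = gs_model.predict_write(H, writeProperty=write_property)

    written = result["nodePropertiesWritten"]
    expected = H.node_count()
    if written != expected:
        raise RuntimeError(
            f"wrote {written} node properties, expected {expected}"
        )
    return result
```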