GraphSAGE node classification training

Neo4j Graph Analytics for Snowflake is in Public Preview and is not intended for production use.

GraphSAGE is a graph neural network (GNN) architecture that can be used as a supervised algorithm to predict the class labels of nodes in a graph. This section explains how to use the GraphSAGE endpoint of Neo4j Graph Analytics for Snowflake to train a node classification model.

Endpoint

The endpoint name is graph.gs_nc_train, and it takes two positional arguments as input.

The first argument is a VARCHAR that specifies which compute pool to use. For this algorithm we strongly recommend using a GPU compute pool, unless the dataset is very small and the model shallow.

The second argument is a JSON configuration map. This map must contain the following two keys:

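| Name    | Type | Default | Optional | Description                                     |
|---------|------|---------|----------|-------------------------------------------------|
| project | Map  | n/a     | no       | Configuration for the input graph               |
| compute | Map  | n/a     | no       | Configuration for algorithm-specific parameters |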
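To show the overall shape of a call, here is a sketch that passes only the required keys, plus the optional defaultTablePrefix for brevity. It assumes the default app name Neo4j_Graph_Analytics; the database my_db, schema my_schema, tables paper and cites, the subject column and the model name are hypothetical placeholders. Compute keys are spelled as in the training example at the end of this page.

```sql
-- Hypothetical citation graph: papers (nodes) that cite other papers (relationships).
CALL Neo4j_Graph_Analytics.graph.gs_nc_train(
    'GPU_NV_S',                                   -- argument 1: compute pool (VARCHAR)
    {                                             -- argument 2: configuration map
        'project': {                              -- input graph
            'defaultTablePrefix': 'my_db.my_schema',
            'nodeTables': ['paper'],
            'relationshipTables': {
                -- NATURAL orientation (the default): paper -> paper,
                -- so every paper node can be the target of a relationship
                'cites': {'sourceTable': 'paper', 'targetTable': 'paper'}
            }
        },
        'compute': {                              -- required algorithm parameters only
            'modelname': 'nc-papers',             -- must be unique
            'targetLabel': 'paper',               -- node table whose nodes are classified
            'targetProperty': 'subject',          -- column holding the class labels
            'numEpochs': 10,
            'numSamples': [10, 10]                -- two layers, 10 sampled neighbors each
        }
    }
);
```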

Configuration

This section describes the configuration parameters accepted by the endpoint.

Input graph configuration

| Name               | Type   | Default | Optional | Description |
|--------------------|--------|---------|----------|-------------|
| nodeTables         | List   | n/a     | no       | A list of table names, each of which represents a node label in the input graph |
| relationshipTables | Map    | n/a     | no       | A map from table names representing relationship types to configuration maps for those relationship types in the input graph (details below) |
| defaultTablePrefix | String | n/a     | yes      | A default database and schema prefix to use for table names in the input graph. Should be of the format "<database>.<schema>" |

If defaultTablePrefix is not provided, every table name must be qualified with a database and schema name, that is, given as a string of the format "<database>.<schema>.<table>". If defaultTablePrefix is provided, table names may also be given as "<table>", in which case the prefix is prepended to them.

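All provided node tables and relationship tables must have unique names. It is not enough for the fully qualified names to differ: the bare table names (without database and schema) must also be unique. For example, db1.public.movie and db2.public.movie cannot be used in the same projection.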
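As an illustration of the prefix rule, the following project map fragment (to be placed inside the configuration map of a call) uses the IMDB tables from the example at the end of this page without defaultTablePrefix. Assuming that sourceTable and targetTable values follow the same qualification rule as the other table names, it should be equivalent to setting 'defaultTablePrefix': 'imdb.gml' and using the bare names.

```sql
-- 'project' fragment without defaultTablePrefix: every table name is written
-- as <database>.<schema>.<table>. The bare names (actor, director, movie,
-- acted_in, directed_in) are all distinct, as the uniqueness rule requires.
'project': {
    'nodeTables': ['imdb.gml.actor', 'imdb.gml.director', 'imdb.gml.movie'],
    'relationshipTables': {
        'imdb.gml.acted_in': {
            'sourceTable': 'imdb.gml.actor',
            'targetTable': 'imdb.gml.movie',
            'orientation': 'UNDIRECTED'
        },
        'imdb.gml.directed_in': {
            'sourceTable': 'imdb.gml.director',
            'targetTable': 'imdb.gml.movie',
            'orientation': 'UNDIRECTED'
        }
    }
}
```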

Node tables

The nodeTables list of the project map must contain an entry for each type of node in the input graph. In each such table, every row represents one node. Each table must contain at least one column: the node ID column, which must be named nodeid (case-insensitive). Each node ID must be unique within its table, and the nodeid column must be of type BIGINT or VARCHAR. Any additional columns in the table represent node properties, for example node features.
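Before training, it can help to check the uniqueness requirement directly in Snowflake. The query below is a sketch that uses the movie table from the example at the end of this page; substitute each of your own node tables.

```sql
-- Returns one row per duplicated node ID, with its number of occurrences.
-- An empty result means every nodeid in the table is unique.
SELECT nodeid,
       COUNT(*) AS occurrences
FROM imdb.gml.movie
GROUP BY nodeid
HAVING COUNT(*) > 1;
```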

Relationship tables

The relationshipTables map of the project map must contain an entry for each type of relationship in the input graph. Each key is the name of the table containing the relationships of one type, and each value is a configuration map for that relationship type with the following keys:

| Name        | Type   | Default   | Optional | Description |
|-------------|--------|-----------|----------|-------------|
| sourceTable | String | n/a       | no       | The name of the table that contains the source nodes of the relationships |
| targetTable | String | n/a       | no       | The name of the table that contains the target nodes of the relationships |
| orientation | String | "NATURAL" | yes      | How to interpret the orientation (direction) of the provided relationships. Possible values are "NATURAL", "REVERSE" and "UNDIRECTED" |

In each relationship table, every row represents one relationship. Two columns are required in each relationship table: sourcenodeid and targetnodeid (both case-insensitive). They specify the source and target node of each relationship, respectively, and their values should correspond to node IDs in the configured source and target tables. Both columns must be of type BIGINT or VARCHAR.
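Whether each relationship really points at existing nodes can be checked with anti-joins. This sketch uses the acted_in, actor and movie tables from the example at the end of this page; it only reports mismatches and makes no assumption about how the endpoint treats them.

```sql
-- acted_in rows whose source ID has no matching nodeid in the source table (actor).
SELECT r.sourcenodeid
FROM imdb.gml.acted_in AS r
LEFT JOIN imdb.gml.actor AS a
    ON r.sourcenodeid = a.nodeid
WHERE a.nodeid IS NULL;

-- acted_in rows whose target ID has no matching nodeid in the target table (movie).
SELECT r.targetnodeid
FROM imdb.gml.acted_in AS r
LEFT JOIN imdb.gml.movie AS m
    ON r.targetnodeid = m.nodeid
WHERE m.nodeid IS NULL;
```

Two empty results mean every relationship refers to nodes present in the configured tables.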

The orientation parameter specifies how to interpret the direction of the relationships. By default, relationships are interpreted as having the "NATURAL" orientation, meaning that they are assumed to be directed from the source node to the target node. If the orientation is set to "REVERSE", the relationships are interpreted as being directed from the target node to the source node. And if the orientation is set to "UNDIRECTED", the relationships are interpreted as being undirected, meaning that they are symmetric and can be traversed in either direction (independently of which node is the source and which is the target).
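Conceptually, the three orientations can be pictured as edge sets derived from the relationship table. The query below is only a mental model, not how the application stores relationships internally: NATURAL corresponds to the first SELECT alone, REVERSE to the second alone, and UNDIRECTED to both together.

```sql
-- Mental model of 'orientation': 'UNDIRECTED' for acted_in.
SELECT sourcenodeid AS from_id, targetnodeid AS to_id   -- NATURAL:  actor -> movie
FROM imdb.gml.acted_in
UNION ALL
SELECT targetnodeid AS from_id, sourcenodeid AS to_id   -- REVERSE:  movie -> actor
FROM imdb.gml.acted_in;
```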

Please note that in order for GraphSAGE to properly propagate updates of node embeddings, each type of node must be the target of at least one relationship type. The orientation parameter can be useful to add reverse direction relationships for types of nodes that are only the source of relationships (using the "REVERSE" or "UNDIRECTED" orientations).

Algorithm configuration

This section describes the parameters that are specific to the algorithm itself. These parameters must be provided inside the compute map of the top-level configuration map.

The following algorithm parameters can be configured:

| Name                | Type                | Default                                    | Optional | Description |
|---------------------|---------------------|--------------------------------------------|----------|-------------|
| targetLabel         | String              | n/a                                        | no       | The node label (i.e. type) whose nodes the model is trained to classify |
| targetProperty      | String              | n/a                                        | no       | The node property to predict, represented by a column in the input node table of the specified targetLabel |
| modelname           | String              | n/a                                        | no       | The name of the model to train (must be unique) |
| numEpochs           | Integer             | n/a                                        | no       | The number of epochs to train the model |
| numSamples          | List of Integer     | n/a                                        | no       | The number of neighbors to sample for each layer. The length of the list also determines the number of layers |
| hiddenChannels      | Integer             | 256                                        | yes      | The node embedding dimension of the model layers' outputs |
| activation          | String              | "relu"                                     | yes      | The activation function to use. Valid values are "relu" and "sigmoid" |
| aggregator          | String              | "mean"                                     | yes      | The neighborhood embedding aggregator to use. Valid values are "mean" and "max" |
| learningRate        | Float               | 0.001                                      | yes      | The learning rate for the optimizer |
| dropout             | Float               | 0.1                                        | yes      | The dropout probability for each layer. Must be a value >= 0.0 and < 1.0 |
| layerNormalization  | Boolean             | true                                       | yes      | Whether to apply layer normalization between the model layers |
| epochsPerCheckpoint | Integer             | max(numEpochs / 10, 1)                     | yes      | The number of epochs between saving model checkpoints |
| randomSeed          | Integer             | A random integer                           | yes      | A number used to seed all randomness of the computation |
| splitRatios         | Map                 | {"TRAIN": 0.6, "TEST": 0.2, "VALID": 0.2}  | yes      | A map of ratios for splitting the target nodes of the input graph into training, test and validation sets. The keys must be "TRAIN", "TEST" and "VALID", and the values must sum to 1.0 |
| epochsPerVal        | Integer             | 0                                          | yes      | The number of epochs between evaluations of the model on the validation set. If set to 0, the model is not evaluated on the validation set |
| trainBatchSize      | Integer             | Automatically inferred                     | yes      | The number of target nodes to train on in each batch. If not provided, the algorithm infers the largest batch size that fits in the available memory |
| evalBatchSize       | Integer             | trainBatchSize                             | yes      | The batch size to use for evaluation |
| classWeights        | Boolean or Map      | false                                      | yes      | Whether to use class weights to balance the training data. If set to true, class weights are calculated from the distribution of the target labels in the training set. If set to a map, the map must contain a class weight for each target class label |
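The fragment below sketches a compute map that also sets some optional parameters, reusing the IMDB example from the end of this page. The keys targetLabel, targetProperty and classWeights are spelled as in that example; the camelCase spellings splitRatios, epochsPerVal and randomSeed are an assumption that the same convention applies, and the model name nc-imdb-val is hypothetical.

```sql
-- 'compute' fragment with optional parameters (sketch).
'compute': {
    'modelname': 'nc-imdb-val',                                 -- hypothetical; must be unique
    'targetLabel': 'movie',
    'targetProperty': 'genre',
    'numEpochs': 50,                                            -- default epochsPerCheckpoint: max(50 / 10, 1) = 5
    'numSamples': [20, 20],                                     -- two layers, 20 sampled neighbors each
    'splitRatios': {'TRAIN': 0.8, 'TEST': 0.1, 'VALID': 0.1},   -- 0.8 + 0.1 + 0.1 = 1.0
    'epochsPerVal': 5,                                          -- validate every 5 epochs (0 disables it)
    'randomSeed': 42,                                           -- fixed seed for repeatable runs
    'classWeights': true                                        -- weights derived from the training labels
}
```

Fixing randomSeed is mainly useful when comparing runs that differ in a single parameter, since otherwise the split and neighbor sampling also change between runs.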

Example

For our example we use an IMDB dataset with actors, directors and movies. All of these nodes have keywords associated with them, which we use as node features. They are connected by relationships: actors act in movies and directors direct movies. The goal is to predict the genre of each movie.

We have a database called imdb with a schema gml that contains the following tables:

  • actor with columns nodeid and plot_keywords

  • movie with columns nodeid, plot_keywords and genre

  • director with columns nodeid and plot_keywords

  • acted_in with columns sourcenodeid and targetnodeid that represent actor and movie node IDs

  • directed_in with columns sourcenodeid and targetnodeid that represent director and movie node IDs

The plot_keywords columns contain keywords associated with the nodes, encoded as vectors of floats. The genre column contains the target class labels for the movie nodes, which we want to predict.
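Because the example trains with class weights, it can be useful to look at the genre distribution first. The query below is a sketch that derives inverse-frequency ("balanced") weights, total / (number of classes × class count), and packs them into an object that could be passed as a classWeights map. This weighting scheme is our own choice for illustration and is not necessarily what classWeights = true computes. It also counts all movie nodes rather than only the training split, and it assumes that class labels are given as string keys.

```sql
WITH counts AS (
    -- number of movies per genre (the target classes)
    SELECT genre,
           COUNT(*) AS class_count
    FROM imdb.gml.movie
    WHERE genre IS NOT NULL
    GROUP BY genre
),
weights AS (
    -- SUM(...) OVER () is the total number of labeled movies,
    -- COUNT(*) OVER () is the number of distinct genres
    SELECT genre,
           class_count,
           (SUM(class_count) OVER () / (COUNT(*) OVER () * class_count))::FLOAT AS weight
    FROM weights_source
)
SELECT OBJECT_AGG(genre::VARCHAR, weight::VARIANT) AS class_weights
FROM weights;
```

Rare genres receive weights above 1 and frequent genres weights below 1; replacing the final SELECT with SELECT * FROM weights shows the per-genre counts.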

You can upload this dataset to your Snowflake account by following the instructions in the GitHub repository neo4j-product-examples/snowflake-graph-analytics.

The training query

In the following query we train a GraphSAGE model for node classification on the dataset. We train for 10 epochs with two GraphSAGE layers (one per entry in numSamples, each sampling 20 neighbors) and use class weights to balance the class distribution.

Before running the query, the application, your consumer role and your environment must be granted the required privileges. See the Getting started page for details.

We also assume that the application name is the default Neo4j_Graph_Analytics. If you chose a different app name during installation, replace it with your app name.

CALL Neo4j_Graph_Analytics.graph.gs_nc_train('GPU_NV_S', {
    'project': {
        'defaultTablePrefix': 'imdb.gml',
        'nodeTables': ['actor', 'director', 'movie'],
        'relationshipTables': {
            'acted_in': {
                'sourceTable': 'actor',
                'targetTable': 'movie',
                'orientation': 'UNDIRECTED'
            },
            'directed_in': {
                'sourceTable': 'director',
                'targetTable': 'movie',
                'orientation': 'UNDIRECTED'
            }
        }
    },
    'compute': {
        'modelname': 'nc-imdb',
        'numEpochs': 10,
        'numSamples': [20, 20],
        'targetLabel': 'movie',
        'targetProperty': 'genre',
        'classWeights': true
    }
});

The above query should produce a result similar to the one below. The numerical results may vary.

| JOB_ID | JOB_START | JOB_END | JOB_RESULT |
|--------|-----------|---------|------------|
| job_63b8083fc8ef463ab38cd95d2ac345ea | 2025-04-29 12:06:28.791 | 2025-04-29 12:07:10.318 | { "metrics": { "test_acc": 0.7441860437393188, "test_f1_macro": 0.7236689925193787, "test_f1_micro": 0.7441860437393188, "train_acc": 0.9911160469055176, "train_f1_macro": 0.9900508522987366, "train_f1_micro": 0.9911160469055176 } } |
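To work with the metrics in SQL, the result of the CALL can be read back with Snowflake's RESULT_SCAN. This sketch assumes it runs right after the training call in the same session and that JOB_RESULT is returned as a VARIANT holding the JSON shown above; if it comes back as a string, wrap the column in PARSE_JSON first.

```sql
-- Extract individual metrics from the previous query's JOB_RESULT column.
SELECT job_id,
       job_result:metrics:test_acc::FLOAT       AS test_acc,
       job_result:metrics:test_f1_macro::FLOAT  AS test_f1_macro,
       job_result:metrics:train_acc::FLOAT      AS train_acc
FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()));
```

A large gap between train_acc and test_acc, as in the sample output, can be a sign of overfitting. Setting a validation split and epochsPerVal makes it easier to monitor this during training.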