Knowledge graph embeddings: Training in PyG, prediction with GDS

Open In Colab

This Jupyter notebook is hosted in the Neo4j Graph Data Science Client GitHub repository.

The notebook demonstrates how to use the graphdatascience and PyTorch Geometric (PyG) Python libraries to:

  1. Import the FB15k-237 dataset directly into GDS

  2. Train a TransE model with PyG

  3. Make predictions on the data in the database using GDS Knowledge Graph Embeddings functionality

1. Prerequisites

To run this notebook, you’ll need a Neo4j server with GDS version 2.5 or later installed.

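Additionally, the following Python libraries are required: graphdatascience, torch (PyTorch), and torch_geometric (PyG). They are installed in the first cell of the setup step below.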

2. Setup

We’ll begin by importing our dependencies and establishing a GDS client connection to the database.

%pip install graphdatascience torch torch_geometric
import collections
import os

import pandas as pd
import torch
import torch.optim as optim
from torch_geometric.data import Data, download_url
from torch_geometric.nn import TransE
from tqdm import tqdm

from graphdatascience import GraphDataScience
# Connection settings come from the environment, with local defaults.
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_DB = os.getenv("NEO4J_DB", "neo4j")

# Credentials are only used when both the user and the password are set
# (and non-empty); otherwise the client connects without authentication.
NEO4J_AUTH = (
    (os.getenv("NEO4J_USER"), os.getenv("NEO4J_PASSWORD"))
    if os.getenv("NEO4J_USER") and os.getenv("NEO4J_PASSWORD")
    else None
)

gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)
from graphdatascience import ServerVersion

# This notebook requires GDS 2.5.0 or later.
# An explicit check is used instead of `assert` so that it is not stripped
# when Python runs with -O and so that the failure explains what is wrong.
if not (gds.server_version() >= ServerVersion(2, 5, 0)):
    raise RuntimeError(
        f"This notebook requires GDS 2.5.0 or later, but the server reports {gds.server_version()}"
    )

3. Downloading and Storing the FB15k-237 Dataset in the Database

Download the FB15k-237 dataset and extract the required files: train.txt, valid.txt, and test.txt.

import os
import zipfile

# Location of the FB15k-237 archive and the local directory it is stored in.
url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
raw_dir = "./data_from_zip"
download_url(url, raw_dir)

# Only the three split files are needed; inside the archive they live
# under the "Release/" folder.
raw_file_names = ["train.txt", "valid.txt", "test.txt"]
archive_path = f"{raw_dir}/{os.path.basename(url)}"
with zipfile.ZipFile(archive_path, "r") as zip_ref:
    for member in (f"Release/{name}" for name in raw_file_names):
        zip_ref.extract(member, path=raw_dir)

# Directory that now contains the extracted split files.
data_dir = f"{raw_dir}/Release"

Set a constraint for unique id entries to speed up data uploads.

# The uniqueness constraint also creates an index on Entity.id, which the
# MERGE/MATCH lookups in the upload step rely on for speed.
# IF NOT EXISTS keeps this cell re-runnable: without it the statement fails
# once the constraint has been created.
gds.run_cypher("CREATE CONSTRAINT entity_id IF NOT EXISTS FOR (e:Entity) REQUIRE e.id IS UNIQUE")

Creating Entity Nodes: Create a node with the label Entity. This node should have properties id and text. - Syntax: (:Entity {id: int, text: str})

Creating Relationships for Training with PyG: Based on the training stage, create relationships of type TRAIN, TEST, or VALID. Each of these relationships should have a rel_id property. - Example Syntax: [:TRAIN {rel_id: int}]

Creating Relationships for Prediction with GDS: For the prediction stage, create relationships of a specific type denoted as REL_i. Each of these relationships should have rel_id and text properties. - Example Syntax: [:REL_7 {rel_id: int, text: str}]

# Relationship type written to the database for each split file:
# "train.txt" -> "TRAIN", "valid.txt" -> "VALID", "test.txt" -> "TEST".
rel_types = {f"{split}.txt": split.upper() for split in ("train", "valid", "test")}

# Lookup tables filled in by process():
#   rel_dict             relation text       -> integer relation id
#   rel_id_to_text_dict  integer relation id -> relation text
#   rel_type_dict        integer relation id -> list of {"source", "target"} pairs
rel_dict = {}
rel_id_to_text_dict = {}
rel_type_dict = collections.defaultdict(list)


def process():
    """Load the FB15k-237 split files and write them to the database.

    For every file in ``raw_file_names`` the tab-separated
    ``source<TAB>relation<TAB>target`` triples are read from ``data_dir``.
    Entities and relations are numbered in order of first appearance; the
    relation numbering is recorded in the module-level ``rel_dict`` and
    ``rel_id_to_text_dict``.

    Two kinds of relationships are written:

    * one ``TRAIN`` / ``VALID`` / ``TEST`` relationship per triple (type
      taken from ``rel_types``), carrying the ``rel_id`` property;
    * one ``REL_<rel_id>`` relationship per triple across all splits,
      carrying ``rel_id`` and ``text``.
    """
    # Start from an empty grouping so that re-running this cell does not
    # append every triple a second time.
    rel_type_dict.clear()

    node_dict_ = {}
    for file_name in raw_file_names:
        file_name_path = data_dir + "/" + file_name

        list_of_dicts = []
        with open(file_name_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.rstrip("\n")
                # Skip blank lines instead of assuming that exactly the last
                # element after splitting on newlines is empty.
                if not line:
                    continue
                src, rel, dst = line.split("\t")

                # Ids are assigned in order of first appearance (src before dst).
                source = node_dict_.setdefault(src, len(node_dict_))
                target = node_dict_.setdefault(dst, len(node_dict_))
                if rel not in rel_dict:
                    rel_dict[rel] = len(rel_dict)
                    rel_id_to_text_dict[rel_dict[rel]] = rel
                edge_type = rel_dict[rel]

                rel_type_dict[edge_type].append(
                    {
                        "source": source,
                        "target": target,
                    }
                )
                list_of_dicts.append(
                    {
                        "source": source,
                        "source_text": src,
                        "target": target,
                        "target_text": dst,
                        "rel_id": edge_type,
                    }
                )

        rel_type = rel_types[file_name]
        print(f"Writing {len(list_of_dicts)} entities of {rel_type}")
        # The relationship type cannot be a query parameter in Cypher, so it
        # is interpolated; it only ever comes from the fixed rel_types table.
        gds.run_cypher(
            f"""
            UNWIND $ll as l
            MERGE (n:Entity {{id:l.source, text:l.source_text}})
            MERGE (m:Entity {{id:l.target, text:l.target_text}})
            MERGE (n)-[:{rel_type} {{rel_id:l.rel_id}}]->(m)
            """,
            params={"ll": list_of_dicts},
        )

    print("Writing relationships as different relationship types")
    for rel_id, rels in tqdm(rel_type_dict.items()):
        REL_TYPE = f"REL_{rel_id}"
        gds.run_cypher(
            f"""
            UNWIND $ll AS l MATCH (n:Entity {{id:l.source}}), (m:Entity {{id:l.target}})
            MERGE (n)-[:{REL_TYPE} {{rel_id:$rel_id, text:$text}}]->(m)
            """,
            params={"ll": rels, "rel_id": rel_id, "text": rel_id_to_text_dict[rel_id]},
        )


process()

Project all of the data into an in-memory graph to obtain the mapping between the id property and the database's internal nodeId.

# Project every Entity node with its `id` property ...
node_projection = {"Entity": {"properties": "id"}}
# ... and the three split relationship types, each carrying `rel_id`.
relationship_projection = [
    {split_type: {"orientation": "NATURAL", "properties": "rel_id"}}
    for split_type in ("TRAIN", "TEST", "VALID")
]

ttv_G, result = gds.graph.project(
    "fb15k-graph-ttv",
    node_projection,
    relationship_projection,
)

# Stream the `id` property back, one column per property.
node_properties = gds.graph.nodeProperties.stream(
    ttv_G,
    ["id"],
    separate_property_columns=True,
)

# Two-way lookup between the streamed nodeId column and the `id` property.
nodeId_to_id = {
    node_id: entity_id for node_id, entity_id in zip(node_properties.nodeId, node_properties.id)
}
id_to_nodeId = {
    entity_id: node_id for entity_id, node_id in zip(node_properties.id, node_properties.nodeId)
}

4. Training the TransE Model with PyG

Retrieve data from the database, convert it into torch tensors, and format it into a Data structure suitable for training with PyG.

def create_data_from_graph(relationship_type):
    """Build a PyG ``Data`` object for one relationship type of ``ttv_G``.

    Streams the ``rel_id`` relationship property for ``relationship_type``,
    translates the streamed source/target node ids through ``nodeId_to_id``
    and packs the result into long tensors: ``edge_index`` (heads in row 0,
    tails in row 1) and ``edge_type``. ``num_nodes`` is set to the size of
    ``nodeId_to_id``. The object is displayed and returned.
    """
    streamed = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)

    heads = [nodeId_to_id[node_id] for node_id in streamed.sourceNodeId]
    tails = [nodeId_to_id[node_id] for node_id in streamed.targetNodeId]
    relation_ids = streamed.propertyValue.astype(int)

    data = Data(
        edge_index=torch.tensor([heads, tails], dtype=torch.long),
        edge_type=torch.tensor(relation_ids, dtype=torch.long),
        num_nodes=len(nodeId_to_id),
    )
    display(data)
    return data


# Built in the order TRAIN, TEST, VALID.
train_tensor_data, test_tensor_data, val_tensor_data = (
    create_data_from_graph(split_type) for split_type in ("TRAIN", "TEST", "VALID")
)

Drop the projected graph to save memory.

# The id mapping and the three tensor datasets are already held in Python,
# so the projection is no longer needed.
gds.graph.drop(ttv_G)

The training process of the TransE model follows the corresponding PyG example.

def train_model_with_pyg():
    """Train a PyG TransE model on the TRAIN split and evaluate it.

    Reads the module-level ``train_tensor_data``, ``val_tensor_data`` and
    ``test_tensor_data``. The model is trained for ``epoch_count`` epochs,
    evaluated on the validation split every ``eval_every`` epochs, saved to
    ``./model_<epoch_count>.pt`` and finally evaluated on the test split.

    Returns:
        The trained ``TransE`` model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = TransE(
        num_nodes=train_tensor_data.num_nodes,
        num_relations=train_tensor_data.num_edge_types,
        hidden_channels=50,
    ).to(device)

    # The index tensors must live on the same device as the model,
    # otherwise a CUDA run fails with a device mismatch.
    loader = model.loader(
        head_index=train_tensor_data.edge_index[0].to(device),
        rel_type=train_tensor_data.edge_type.to(device),
        tail_index=train_tensor_data.edge_index[1].to(device),
        batch_size=1000,
        shuffle=True,
    )

    optimizer = optim.Adam(model.parameters(), lr=0.01)

    def train():
        """Run one epoch and return the mean loss per training triple."""
        model.train()
        total_loss = total_examples = 0
        for head_index, rel_type, tail_index in loader:
            optimizer.zero_grad()
            loss = model.loss(head_index, rel_type, tail_index)
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * head_index.numel()
            total_examples += head_index.numel()
        return total_loss / total_examples

    @torch.no_grad()
    def test(data):
        """Return (mean rank, MRR, hits@10) of the model on ``data``."""
        model.eval()
        return model.test(
            head_index=data.edge_index[0].to(device),
            rel_type=data.edge_type.to(device),
            tail_index=data.edge_index[1].to(device),
            batch_size=1000,
            k=10,
        )

    # Consider increasing the number of epochs
    epoch_count = 5
    # Validation only runs once epoch_count reaches this interval.
    eval_every = 75
    # range(1, epoch_count + 1) so that exactly epoch_count epochs are run.
    for epoch in range(1, epoch_count + 1):
        loss = train()
        print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")
        if epoch % eval_every == 0:
            # model.test returns three values; unpacking only two raised a
            # ValueError as soon as this branch was reached.
            rank, _mrr, hits = test(val_tensor_data)
            print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, Val Hits@10: {hits:.4f}")

    torch.save(model, f"./model_{epoch_count}.pt")

    mean_rank, mrr, hits_at_k = test(test_tensor_data)
    print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")

    return model


model = train_model_with_pyg()
# The model can be loaded if it was trained before; the file name carries
# the epoch_count it was trained with, e.g.
# model = torch.load("./model_5.pt")

Extract the node embeddings from the trained model and write them to the database.

# Write the embeddings in batches: one UNWIND query per batch replaces one
# database round trip per node.
# The model was created with num_nodes == len(nodeId_to_id), so row i of
# node_emb is the embedding of the Entity with id i.
embedding_rows = [
    {"id": entity_id, "emb": embedding}
    for entity_id, embedding in enumerate(model.node_emb.weight.detach().cpu().tolist())
]
EMBEDDING_WRITE_BATCH_SIZE = 1000
for batch_start in tqdm(range(0, len(embedding_rows), EMBEDDING_WRITE_BATCH_SIZE)):
    gds.run_cypher(
        "UNWIND $rows AS row MATCH (n:Entity {id: row.id}) SET n.emb = row.emb",
        params={"rows": embedding_rows[batch_start : batch_start + EMBEDDING_WRITE_BATCH_SIZE]},
    )

5. Predict Using GDS Knowledge Graph Edge Embeddings Functionality

Select a relationship type for which to make predictions.

# Relation text as it appears in the dataset files.
relationship_to_predict = "/film/film/genre"
# Integer id assigned to that relation while loading the data.
rel_id_to_predict = rel_dict[relationship_to_predict]
# Name of the per-relation relationship type in the database.
rel_label_to_predict = "REL_" + str(rel_id_to_predict)

Project the graph with all nodes and existing relationships of the selected type.

# All Entity nodes with their id and embedding, plus only the relationships
# of the type selected for prediction.
entity_projection = {"Entity": {"properties": ["id", "emb"]}}
G_test, result = gds.graph.project("graph_to_predict_", entity_projection, rel_label_to_predict)


def print_graph_info(G):
    """Print the node count, node labels, relationship types and
    relationship count of the projected graph ``G``, one per line."""
    graph_name = G.name()
    summary = (
        ("node count", G.node_count()),
        ("node labels", G.node_labels()),
        ("relationship types", G.relationship_types()),
        ("relationship count", G.relationship_count()),
    )
    for caption, value in summary:
        print(f"Graph '{graph_name}' {caption}: {value}")


# Show a summary of the projection that will be used for prediction.
print_graph_info(G_test)

Retrieve the embedding for the selected relationship from the PyG model. Then, create a GDS TransE model using the graph, node embeddings property, and the embedding for the relationship to be predicted.

# The relation vector comes from the model's relation embedding table.
# Indexing `node_emb` with a relation id would instead return the embedding
# of whichever entity happens to share that integer id.
target_emb = model.rel_emb.weight[rel_id_to_predict].tolist()
# GDS TransE model: node embeddings are read from the projected `emb`
# property, the relation embedding is supplied per relationship type.
transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb})

# Look up the database node ids of the source entities to predict for.
source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"]
source_ids_df = gds.run_cypher(
    "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId",
    params={"node_text_list": source_node_list},
)

Now we can use the model to make predictions.

# Top-3 predictions per source node, considering every Entity as a target.
stream_config = {
    "source_node_filter": source_ids_df.nodeId,
    "target_node_filter": "Entity",
    "relationship_type": rel_label_to_predict,
    "top_k": 3,
    "concurrency": 4,
}
result = transe_model.predict_stream(**stream_config)
print(result)

Augment the predicted result with node identifiers and their text values.

# Every node id that occurs in the predictions, as source or as target.
# pd.unique returns a NumPy array; it is converted to a plain list of ints
# so that it is sent as an ordinary Cypher list parameter and does not
# depend on the driver being able to serialize NumPy values.
ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId])).tolist()

# Fetch the text and the `id` property for those nodes.
ids_to_text = gds.run_cypher(
    "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id",
    params={"ids": ids_in_result},
)

nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag))
nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id))

# Insert the tag/id columns directly after the node id column they describe.
result.insert(1, "sourceTag", result.sourceNodeId.map(lambda x: nodeId_to_text_res[x]))
result.insert(2, "sourceId", result.sourceNodeId.map(lambda x: nodeId_to_id_res[x]))
result.insert(4, "targetTag", result.targetNodeId.map(lambda x: nodeId_to_text_res[x]))
result.insert(5, "targetId", result.targetNodeId.map(lambda x: nodeId_to_id_res[x]))

print(result)

6. Using Write Mode

Write mode allows you to write results directly to the database as a new relationship type. This approach helps to avoid mapping from nodeId to id.

# Predictions are written back as relationships of this type, with the
# score stored in the `transe_score` property.
write_relationship_type = f"PREDICTED_{rel_label_to_predict}"
write_config = {
    "source_node_filter": source_ids_df.nodeId,
    "target_node_filter": "Entity",
    "relationship_type": rel_label_to_predict,
    "write_relationship_type": write_relationship_type,
    "write_property": "transe_score",
    "top_k": 3,
    "concurrency": 4,
}
result_write = transe_model.predict_write(**write_config)

Extract the result from the database.

# Read the written predictions back together with the ids and texts of
# their endpoints.
written_predictions_query = (
    f"MATCH (n)-[r:{write_relationship_type}]->(m) "
    "RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score"
)
gds.run_cypher(written_predictions_query)
# Remove the projection used for prediction.
gds.graph.drop(G_test)