Create embeddings with open source libraries

The Python library SentenceTransformers provides pre-trained models to generate embeddings for text and images, and allows you to play with embeddings without needing an account on OpenAI or other proprietary services.

This page assumes you have already imported the recommendations dataset and set up your environment, and shows how to generate and store embeddings for Movie nodes based on their title and plot.

Embeddings are always generated outside of Neo4j, but stored in the Neo4j database.

Setup environment

As a last setup step, install the sentence-transformers package.

pip install sentence-transformers

Create embeddings for movies

The example below fetches all Movie nodes from the database, generates an embedding from each movie's title and plot, and stores it in a new embedding property on that node.

from sentence_transformers import SentenceTransformer
import neo4j


# Connection settings: replace each placeholder before running the script.
DB_NAME = '<database-name>'  # examples: 'recommendations-5.26', 'neo4j'
URI = '<database-uri>'
AUTH = '<username>', '<password>'


def main():
    """Embed each Movie's title and plot and store the vector on the node.

    Movies are streamed from a single read query. Their embeddings are
    buffered and written back in batches of ``batch_size`` through
    ``import_batch``. Movies missing a title or a plot are skipped. At the
    end, prints how many Movie nodes carry an ``embedding`` property and
    the embedding size.
    """
    with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:  # (1)
        driver.verify_connectivity()

        model = SentenceTransformer('all-MiniLM-L6-v2')  # vector size 384  (2)

        batch_size = 100
        batch_n = 1
        movies_with_embeddings = []
        with driver.session(database=DB_NAME) as session:
            # Fetch `Movie` nodes
            result = session.run('MATCH (m:Movie) RETURN m.plot AS plot, m.title AS title')
            for record in result:
                title = record.get('title')
                plot = record.get('plot')

                # Only movies that have both a title and a plot get an embedding
                if title is None or plot is None:
                    continue

                movies_with_embeddings.append({
                    'title': title,
                    'plot': plot,
                    # Build the text explicitly so no source indentation leaks into it
                    'embedding': model.encode(f'Title: {title}\nPlot: {plot}'),  # (3)
                })

                # Import when a batch of movies has embeddings ready; flush buffer
                if len(movies_with_embeddings) == batch_size:  # (4)
                    import_batch(driver, movies_with_embeddings, batch_n)
                    movies_with_embeddings = []
                    batch_n += 1

            # Flush the last, partial batch (nothing to do if the buffer is empty)
            if movies_with_embeddings:
                import_batch(driver, movies_with_embeddings, batch_n)

        # Import complete, show counters
        records, _, _ = driver.execute_query('''
        MATCH (m:Movie WHERE m.embedding IS NOT NULL)
        RETURN count(*) AS countMoviesWithEmbeddings, size(m.embedding) AS embeddingSize
        ''', database_=DB_NAME)
        # size(m.embedding) is a grouping key, so the query returns no rows
        # at all when no Movie has an embedding; fall back to zero/None.
        summary = records[0] if records else {}
        print(f"""
    Embeddings generated and attached to nodes.
    Movie nodes with embeddings: {summary.get('countMoviesWithEmbeddings', 0)}.
    Embedding size: {summary.get('embeddingSize')}.
        """)


def import_batch(driver, nodes_with_embeddings, batch_n):
    """Write one batch of embeddings onto the matching Movie nodes.

    :param driver: an open ``neo4j`` driver.
    :param nodes_with_embeddings: list of dicts with ``title``, ``plot`` and
        ``embedding`` keys. Every Movie node with that title and plot
        receives the embedding.
    :param batch_n: batch number, used only in the progress message.

    An empty batch is a no-op: no query is sent and nothing is printed.
    """
    if not nodes_with_embeddings:
        return
    # Add embeddings to Movie nodes
    driver.execute_query(  # (5)
        '''
        UNWIND $movies AS movie
        MATCH (m:Movie {title: movie.title, plot: movie.plot})
        CALL db.create.setNodeVectorProperty(m, 'embedding', movie.embedding)
        ''',
        movies=nodes_with_embeddings,
        database_=DB_NAME,
    )
    print(f'Processed batch {batch_n}.')


if __name__ == '__main__':
    main()

# Example output (abridged: the script also prints a "Processed batch N."
# line per batch and a header line before these counters):
'''
Movie nodes with embeddings: 9083.
Embedding size: 384.
'''
1 The driver object is the interface to interact with your Neo4j instance. For more information, see Build applications with Neo4j and Python.
2 The model all-MiniLM-L6-v2 maps text into vectors of size 384 (i.e. lists of 384 numbers). You should always use the same model to generate embeddings for a dataset: pick one and stick to it for your whole project.
3 The .encode() method generates an embedding for the given string (title and plot together, in this case).
4 Embeddings are collected in a buffer, and each full batch is submitted to the database in a single query. This avoids holding the whole dataset in memory and reduces the risk of timeouts (especially relevant for larger datasets).
5 The import query sets a new embedding property on each node m, with the embedding vector movie.embedding as value. The Cypher procedure db.create.setNodeVectorProperty stores vector properties more efficiently than if they were stored as lists. To set vector properties on relationships, use db.create.setRelationshipVectorProperty.

With Enterprise Edition, you can avoid calling db.create.setNodeVectorProperty: instead, wrap the embeddings in the driver's Vector type and set them as properties with the Cypher SET clause.

from sentence_transformers import SentenceTransformer
import neo4j
from neo4j.vector import Vector


# Connection settings: replace each placeholder before running the script.
DB_NAME = '<database-name>'  # examples: 'recommendations-5.26', 'neo4j'
URI = '<database-uri>'
AUTH = '<username>', '<password>'


def main():
    """Embed each Movie's title and plot and store it as a driver Vector.

    Movies are streamed from a single read query. Each embedding is wrapped
    in ``neo4j.vector.Vector``, buffered, and written back in batches of
    ``batch_size`` through ``import_batch``. Movies missing a title or a
    plot are skipped. At the end, prints how many Movie nodes carry an
    ``embedding`` property and the embedding size.
    """
    with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
        driver.verify_connectivity()

        model = SentenceTransformer('all-MiniLM-L6-v2')  # vector size 384

        batch_size = 100
        batch_n = 1
        movies_with_embeddings = []
        with driver.session(database=DB_NAME) as session:
            # Fetch `Movie` nodes
            result = session.run('MATCH (m:Movie) RETURN m.plot AS plot, m.title AS title')
            for record in result:
                title = record.get('title')
                plot = record.get('plot')

                # Only movies that have both a title and a plot get an embedding
                if title is None or plot is None:
                    continue

                movies_with_embeddings.append({
                    'title': title,
                    'plot': plot,
                    # Build the text explicitly so no source indentation leaks into it
                    'embedding': Vector(model.encode(f'Title: {title}\nPlot: {plot}')),
                })

                # Import when a batch of movies has embeddings ready; flush buffer
                if len(movies_with_embeddings) == batch_size:
                    import_batch(driver, movies_with_embeddings, batch_n)
                    movies_with_embeddings = []
                    batch_n += 1

            # Flush the last, partial batch (nothing to do if the buffer is empty)
            if movies_with_embeddings:
                import_batch(driver, movies_with_embeddings, batch_n)

        # Import complete, show counters
        records, _, _ = driver.execute_query('''
        MATCH (m:Movie WHERE m.embedding IS NOT NULL)
        RETURN count(*) AS countMoviesWithEmbeddings, size(m.embedding) AS embeddingSize
        ''', database_=DB_NAME)
        # size(m.embedding) is a grouping key, so the query returns no rows
        # at all when no Movie has an embedding; fall back to zero/None.
        summary = records[0] if records else {}
        print(f"""
    Embeddings generated and attached to nodes.
    Movie nodes with embeddings: {summary.get('countMoviesWithEmbeddings', 0)}.
    Embedding size: {summary.get('embeddingSize')}.
        """)


def import_batch(driver, nodes_with_embeddings, batch_n):
    """Write one batch of embeddings onto the matching Movie nodes via SET.

    :param driver: an open ``neo4j`` driver.
    :param nodes_with_embeddings: list of dicts with ``title``, ``plot`` and
        ``embedding`` keys. Every Movie node with that title and plot gets
        its ``embedding`` property set to the given value.
    :param batch_n: batch number, used only in the progress message.

    An empty batch is a no-op: no query is sent and nothing is printed.
    """
    if not nodes_with_embeddings:
        return
    # Add embeddings to Movie nodes
    driver.execute_query(
        '''
        UNWIND $movies AS movie
        MATCH (m:Movie {title: movie.title, plot: movie.plot})
        SET m.embedding = movie.embedding
        ''',
        movies=nodes_with_embeddings,
        database_=DB_NAME,
    )
    print(f'Processed batch {batch_n}.')


if __name__ == '__main__':
    main()

# Example output (abridged: the script also prints a "Processed batch N."
# line per batch and a header line before these counters):
'''
Movie nodes with embeddings: 9083.
Embedding size: 384.
'''

Once embeddings are in the database, you can use them to compare how similar one movie is to another.