# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional, Tuple
from pydantic import BaseModel, ValidationError, model_validator, validate_call
from typing_extensions import Self
from neo4j_graphrag.exceptions import SchemaValidationError
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.experimental.pipeline.types import (
EntityInputType,
RelationInputType,
)
[docs]
class SchemaProperty(BaseModel):
    """
    Describes a single property that a node or a relationship may carry.

    A property has a name, one of the Neo4j property types listed below,
    and an optional free-text description.
    """

    # Name of the property.
    name: str

    # Allowed property types, see
    # https://neo4j.com/docs/cypher-manual/current/values-and-types/property-structural-constructed/#property-types
    type: Literal[
        "BOOLEAN",
        "DATE",
        "DURATION",
        "FLOAT",
        "INTEGER",
        "LIST",
        "LOCAL_DATETIME",
        "LOCAL_TIME",
        "POINT",
        "STRING",
        "ZONED_DATETIME",
        "ZONED_TIME",
    ]

    # Optional human-readable description; empty when not provided.
    description: str = ""
[docs]
class SchemaEntity(BaseModel):
    """
    Describes a kind of node that may appear in the graph.

    An entity has a label, an optional description and an optional list of
    properties.
    """

    label: str
    description: str = ""
    properties: List[SchemaProperty] = []

    @classmethod
    def from_text_or_dict(cls, input: EntityInputType) -> Self:
        """
        Build an entity from any of the accepted input forms.

        Args:
            input (EntityInputType): An existing SchemaEntity (returned as is),
                a string (used as the label), or any other value, which is
                handed to ``model_validate``.

        Returns:
            The resulting entity.
        """
        if isinstance(input, SchemaEntity):
            # Already an entity: hand the very same object back.
            entity = input
        elif isinstance(input, str):
            # A bare string is interpreted as the label.
            entity = cls(label=input)
        else:
            entity = cls.model_validate(input)
        return entity
[docs]
class SchemaRelation(BaseModel):
    """
    Describes a kind of relationship that may connect nodes in the graph.

    A relation has a label, an optional description and an optional list of
    properties.
    """

    label: str
    description: str = ""
    properties: List[SchemaProperty] = []

    @classmethod
    def from_text_or_dict(cls, input: RelationInputType) -> Self:
        """
        Build a relation from any of the accepted input forms.

        Args:
            input (RelationInputType): An existing SchemaRelation (returned as
                is), a string (used as the label), or any other value, which
                is handed to ``model_validate``.

        Returns:
            The resulting relation.
        """
        if isinstance(input, SchemaRelation):
            # Already a relation: hand the very same object back.
            relation = input
        elif isinstance(input, str):
            # A bare string is interpreted as the label.
            relation = cls(label=input)
        else:
            relation = cls.model_validate(input)
        return relation
[docs]
class SchemaConfig(DataModel):
    """
    Represents possible relationships between entities and relations in the graph.
    """

    # Entity definitions; the keys are the names that potential_schema
    # triples are checked against.
    entities: Dict[str, Dict[str, Any]]
    # Relation definitions; the keys are the names that potential_schema
    # triples are checked against. May be None.
    relations: Optional[Dict[str, Dict[str, Any]]]
    # (entity, relation, entity) triples. May be None.
    potential_schema: Optional[List[Tuple[str, str, str]]]

    @model_validator(mode="before")
    @classmethod
    def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Check that every triple of the potential schema only references
        declared entities and relations.

        Runs before field validation, on the raw input data.

        Args:
            data (Dict[str, Any]): Raw input used to build the model.

        Returns:
            Dict[str, Any]: The input data, unchanged.

        Raises:
            SchemaValidationError: If a potential schema is given without any
                relations, or if one of its triples references an entity or
                a relation that is not defined.
        """
        # `or` (rather than a .get() default) also covers keys that are present
        # but explicitly set to None, which the Optional fields allow.
        entities = (data.get("entities") or {}).keys()
        relations = (data.get("relations") or {}).keys()
        potential_schema = data.get("potential_schema") or []

        if potential_schema:
            if not relations:
                raise SchemaValidationError(
                    "Relations must also be provided when using a potential schema."
                )
            for entity1, relation, entity2 in potential_schema:
                if entity1 not in entities:
                    raise SchemaValidationError(
                        f"Entity '{entity1}' is not defined in the provided entities."
                    )
                if relation not in relations:
                    raise SchemaValidationError(
                        f"Relation '{relation}' is not defined in the provided relations."
                    )
                if entity2 not in entities:
                    raise SchemaValidationError(
                        f"Entity '{entity2}' is not defined in the provided entities."
                    )

        return data
[docs]
class SchemaBuilder(Component):
    """
    A builder class for constructing SchemaConfig objects from given entities,
    relations, and their interrelationships defined in a potential schema.

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.schema import (
            SchemaBuilder,
            SchemaEntity,
            SchemaProperty,
            SchemaRelation,
        )
        from neo4j_graphrag.experimental.pipeline import Pipeline

        entities = [
            SchemaEntity(
                label="PERSON",
                description="An individual human being.",
                properties=[
                    SchemaProperty(
                        name="name", type="STRING", description="The name of the person"
                    )
                ],
            ),
            SchemaEntity(
                label="ORGANIZATION",
                description="A structured group of people with a common purpose.",
                properties=[
                    SchemaProperty(
                        name="name", type="STRING", description="The name of the organization"
                    )
                ],
            ),
        ]
        relations = [
            SchemaRelation(
                label="EMPLOYED_BY", description="Indicates employment relationship."
            ),
        ]
        potential_schema = [
            ("PERSON", "EMPLOYED_BY", "ORGANIZATION"),
        ]
        pipe = Pipeline()
        schema_builder = SchemaBuilder()
        pipe.add_component(schema_builder, "schema_builder")
        pipe_inputs = {
            "schema": {
                "entities": entities,
                "relations": relations,
                "potential_schema": potential_schema,
            },
            ...
        }
        pipe.run(pipe_inputs)
    """

    @staticmethod
    def create_schema_model(
        entities: List[SchemaEntity],
        relations: Optional[List[SchemaRelation]] = None,
        potential_schema: Optional[List[Tuple[str, str, str]]] = None,
    ) -> SchemaConfig:
        """
        Creates a SchemaConfig object from lists of Entity and Relation objects
        and a list of triples defining potential relationships.

        Args:
            entities (List[SchemaEntity]): List of Entity objects.
            relations (Optional[List[SchemaRelation]]): List of Relation objects.
                When None or empty, the schema is built with no relations.
            potential_schema (Optional[List[Tuple[str, str, str]]]): List of
                (entity label, relation label, entity label) triples.

        Returns:
            SchemaConfig: A configured schema object.

        Raises:
            SchemaValidationError: If building the SchemaConfig fails
                validation; the original error is kept as the cause.
        """
        # Index the dumped models by label; these keys are what the
        # potential schema triples are validated against.
        entity_dict = {entity.label: entity.model_dump() for entity in entities}
        relation_dict = (
            {relation.label: relation.model_dump() for relation in relations}
            if relations
            else {}
        )

        try:
            return SchemaConfig(
                entities=entity_dict,
                relations=relation_dict,
                potential_schema=potential_schema,
            )
        except (ValidationError, SchemaValidationError) as e:
            # Surface a single error type to callers while keeping the
            # original exception attached as the explicit cause.
            raise SchemaValidationError(e) from e

    @validate_call
    async def run(
        self,
        entities: List[SchemaEntity],
        relations: Optional[List[SchemaRelation]] = None,
        potential_schema: Optional[List[Tuple[str, str, str]]] = None,
    ) -> SchemaConfig:
        """
        Asynchronously constructs and returns a SchemaConfig object.

        Args:
            entities (List[SchemaEntity]): List of Entity objects.
            relations (Optional[List[SchemaRelation]]): List of Relation objects.
            potential_schema (Optional[List[Tuple[str, str, str]]]): List of
                (entity label, relation label, entity label) triples.

        Returns:
            SchemaConfig: A configured schema object, constructed asynchronously.

        Raises:
            SchemaValidationError: If building the SchemaConfig fails validation.
        """
        return self.create_schema_model(entities, relations, potential_schema)