# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline config wrapper (router based on 'template_' key)
and pipeline runner.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import (
Annotated,
Any,
Union,
)
from pydantic import (
BaseModel,
Discriminator,
Field,
Tag,
)
from pydantic.v1.utils import deep_update
from typing_extensions import Self
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader
from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
AbstractPipelineConfig,
PipelineConfig,
)
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder import (
SimpleKGPipelineConfig,
)
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition
logger = logging.getLogger(__name__)
def _get_discriminator_value(model: Any) -> PipelineType:
template_ = None
if "template_" in model:
template_ = model["template_"]
if hasattr(model, "template_"):
template_ = model.template_
return PipelineType(template_) or PipelineType.NONE
class PipelineConfigWrapper(BaseModel):
"""The pipeline config wrapper will parse the right pipeline config based on the `template_` field."""
config: Union[
Annotated[PipelineConfig, Tag(PipelineType.NONE)],
Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)],
] = Field(discriminator=Discriminator(_get_discriminator_value))
def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition:
return self.config.parse(resolved_data)
def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
return self.config.get_run_params(user_input)
[docs]
class PipelineRunner:
"""Pipeline runner builds a pipeline from different objects and exposes a run method to run pipeline
Pipeline can be built from:
- A PipelineDefinition (`__init__` method)
- A PipelineConfig (`from_config` method)
- A config file (`from_config_file` method)
"""
def __init__(
self,
pipeline_definition: PipelineDefinition,
config: AbstractPipelineConfig | None = None,
do_cleaning: bool = False,
) -> None:
self.config = config
self.pipeline = Pipeline.from_definition(pipeline_definition)
self.run_params = pipeline_definition.get_run_params()
self.do_cleaning = do_cleaning
@classmethod
def from_config(
cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False
) -> Self:
wrapper = PipelineConfigWrapper.model_validate({"config": config})
return cls(wrapper.parse(), config=wrapper.config, do_cleaning=do_cleaning)
@classmethod
def from_config_file(cls, file_path: Union[str, Path]) -> Self:
if not isinstance(file_path, str):
file_path = str(file_path)
data = ConfigReader().read(file_path)
return cls.from_config(data, do_cleaning=True)
async def run(self, user_input: dict[str, Any]) -> PipelineResult:
# pipeline_conditional_run_params = self.
if self.config:
run_param = deep_update(
self.run_params, self.config.get_run_params(user_input)
)
else:
run_param = deep_update(self.run_params, user_input)
logger.info(
f"PIPELINE_RUNNER: starting pipeline {self.pipeline} with run_params={run_param}"
)
result = await self.pipeline.run(data=run_param)
if self.do_cleaning:
await self.close()
return result
async def close(self) -> None:
logger.debug("PIPELINE_RUNNER: cleaning up (closing instantiated drivers...)")
if self.config:
await self.config.close()