Source code for noob.yaml

"""
Mixin for handling configs stored in yaml
Should be split off into another package :)
"""

import re
import shutil
from importlib.metadata import version
from io import StringIO
from itertools import chain
from pathlib import Path
from typing import Any, ClassVar, Literal, Self, Union, overload

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    GetCoreSchemaHandler,
    ValidationError,
    field_validator,
)
from pydantic_core import core_schema
from ruamel.yaml import YAML, CommentedMap, CommentToken, RoundTripRepresenter, ScalarNode

from noob.types import AbsoluteIdentifier, ConfigID, ConfigSource, valid_config_id

yaml = YAML()


class YamlRepresenter(RoundTripRepresenter):
    """Dumper that can represent extra types like Paths"""

    def represent_path(self, data: Path) -> ScalarNode:
        """Represent a path as a string"""
        return self.represent_scalar("tag:yaml.org,2002:str", str(data))


YamlRepresenter.add_representer(type(Path()), YamlRepresenter.represent_path)
yaml.Representer = YamlRepresenter
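
# A hedged sketch (not part of the module) of what registering the Path representer
# buys us: with YamlRepresenter installed on the module-level ``yaml`` object, dumping
# a mapping that contains a Path emits a plain string scalar instead of raising a
# RepresenterError. The key and path below are illustrative only.
#
#     import sys
#     yaml.dump({"log_dir": Path("/tmp/noob")}, sys.stdout)
#     # log_dir: /tmp/noob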


class YAMLMixin:
    """
    Mixin class that provides :meth:`.from_yaml` and :meth:`.to_yaml` classmethods
    """

    def __init__(self, yaml_source: CommentedMap | None = None, **kwargs: Any):
        super().__init__(**kwargs)
        self._yaml_source = yaml_source

    @classmethod
    def from_yaml(cls: type[Self], file_path: str | Path) -> Self:
        """Instantiate this class by passing the contents of a yaml file as kwargs"""
        with open(file_path) as file:
            config_data = yaml.load(file)
        return cls(yaml_source=config_data, **config_data)

    def to_yaml(self, path: Path | None = None, **kwargs: Any) -> str:
        """
        Dump the contents of this class to a yaml file,
        returning the contents of the dumped string
        """
        data_str = self.to_yamls(**kwargs)
        if path:
            with open(path, "w") as file:
                file.write(data_str)
        return data_str

    def to_yamls(self, **kwargs: Any) -> str:
        """
        Dump the contents of this class to a yaml string

        Args:
            **kwargs: passed to :meth:`.BaseModel.model_dump`
        """
        data = self._dump_data(**kwargs)
        if hasattr(self, "_yaml_source") and self._yaml_source is not None:
            self._yaml_source.update(data)
            data = self._yaml_source
        string_stream = StringIO()
        yaml.dump(data, string_stream)
        output_str = string_stream.getvalue()
        string_stream.close()
        return output_str

    def _dump_data(self, **kwargs: Any) -> dict:
        data = self.model_dump(**kwargs) if isinstance(self, BaseModel) else self.__dict__
        return data
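
# Rough usage sketch (illustrative, not part of this module): a pydantic model that
# mixes in YAMLMixin can round-trip a config file while keeping its comments, since
# the loaded CommentedMap is stored on the instance and updated on dump. The Example
# class and "example.yaml" path are hypothetical.
#
#     class Example(YAMLMixin, BaseModel):
#         name: str
#         count: int = 0
#
#     ex = Example.from_yaml("example.yaml")  # keys become constructor kwargs
#     ex.count += 1
#     ex.to_yaml("example.yaml")              # comments in the source file survive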


class ConfigYAMLMixin(BaseModel, YAMLMixin):
    """
    Yaml Mixin class that always puts a header consisting of

    * `noob_id` - unique identifier for this config
    * `noob_model` - fully-qualified module path to model class
    * `noob_version` - version of noob when this model was created

    at the top of the file.
    """

    model_config = ConfigDict(validate_default=True)

    noob_id: ConfigID | None = None
    noob_model: AbsoluteIdentifier = Field(None, validate_default=True)
    noob_version: str = version("noob")

    HEADER_FIELDS: ClassVar[tuple[str, ...]] = ("noob_id", "noob_model", "noob_version")
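
    # Hedged illustration of the header (values are made up): a file written by a
    # ConfigYAMLMixin subclass starts with the three header keys, e.g.
    #
    #     noob_id: my-config
    #     noob_model: noob.example.ExampleConfig
    #     noob_version: 0.1.0
    #
    # followed by the remaining model fields. noob_version is whatever
    # importlib.metadata.version("noob") reports when the model is created.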

    @classmethod
    def from_yaml(cls: type[Self], file_path: str | Path) -> Self:
        """Instantiate this class by passing the contents of a yaml file as kwargs"""
        file_path = Path(file_path)
        with open(file_path) as file:
            config_data = yaml.load(file)

        # fill in any missing fields in the source file needed for a header
        config_data = cls._complete_header(config_data, file_path)

        try:
            instance = cls(**config_data)
            instance._yaml_source = config_data
        except ValidationError:
            if (backup_path := file_path.with_suffix(".yaml.bak")).exists():
                from noob.logging import init_logger

                init_logger("config").debug(
                    f"Model instantiation failed, restoring modified backup from {backup_path}..."
                )
                shutil.copy(backup_path, file_path)
            raise

        return instance

    @classmethod
    def from_id(cls: type[Self], id: ConfigID) -> Self:
        """
        Instantiate a model from a config `id` specified in one of the .yaml configs
        in either the user :attr:`.Config.config_dir` or the packaged ``config`` dir.

        .. note::

            this method does not yet validate that the config matches the model loading it

        """
        globs = [src.rglob("*.y*ml") for src in cls.config_sources()]
        for config_file in chain(*globs):
            try:
                file_id = yaml_peek("noob_id", config_file)
            except KeyError:
                continue
            if file_id == id:
                from noob.logging import init_logger

                init_logger("config").debug(
                    "Model for %s found at %s", cls._model_name(), config_file
                )
                return cls.from_yaml(config_file)
        raise KeyError(f"No config with id {id} found in {cls.config_sources()}")
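
    # Sketch of intended use (class and id names are hypothetical): given a file in any
    # of the config_sources() directories whose header contains ``noob_id: my-config``,
    #
    #     cfg = ExampleConfig.from_id("my-config")
    #
    # scans *.yaml / *.yml files with yaml_peek and loads the first match via from_yaml.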

    @classmethod
    def from_any(cls: type[Self], source: ConfigSource | Self) -> Self:
        """
        Try and instantiate a config model from any supported constructor.

        Args:
            source (:class:`.ConfigID`, :class:`.Path`, :class:`.PathLike[str]`):
                Either

                * the ``id`` of a config file in the user configs directory or builtin
                * a relative ``Path`` to a config file, relative to the current working directory
                * a relative ``Path`` to a config file, relative to the user config directory
                * an absolute ``Path`` to a config file
                * an instance of the class to be constructed (returned unchanged)

        """
        if isinstance(source, cls):
            return source
        elif isinstance(source, str) and valid_config_id(source):
            return cls.from_id(source)
        elif isinstance(source, Path | str):
            from noob.config import config

            source = Path(source)
            if source.suffix in (".yaml", ".yml"):
                if source.exists():
                    # either relative to cwd or absolute
                    return cls.from_yaml(source)
                elif (
                    not source.is_absolute()
                    and (user_source := config.config_dir / source).exists()
                ):
                    return cls.from_yaml(user_source)

        raise ValueError(
            f"Instance of config model {cls.__name__} could not be instantiated from "
            f"{source} - id or file not found, or type not supported"
        )
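
    # Sketch of the dispatch above (paths and ids are illustrative only):
    #
    #     ExampleConfig.from_any("my-config")                # resolved via from_id
    #     ExampleConfig.from_any("configs/example.yaml")     # relative to cwd or config_dir
    #     ExampleConfig.from_any(Path("/abs/example.yaml"))  # absolute path
    #     ExampleConfig.from_any(existing_instance)          # returned unchanged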

    @field_validator("noob_model", mode="before")
    @classmethod
    def fill_noob_model(cls, v: str | None) -> AbsoluteIdentifier:
        """Get name of instantiating model, if not provided"""
        if v is None:
            v = cls._model_name()
        return v

    @classmethod
    def config_sources(cls: type[Self]) -> list[Path]:
        """
        Directories to search for config files, in order of priority
        such that earlier sources are preferred over later sources.
        """
        from noob.config import Config, get_entrypoint_sources, get_extra_sources

        return [Config().config_dir, *get_extra_sources(), *get_entrypoint_sources()]

    def _dump_data(self, **kwargs: Any) -> dict:
        """Ensure that header is prepended to model data"""
        return {**self._yaml_header(self), **super()._dump_data(**kwargs)}

    @classmethod
    def _model_name(cls) -> AbsoluteIdentifier:
        return f"{cls.__module__}.{cls.__name__}"

    @classmethod
    def _yaml_header(cls, instance: Self | dict) -> dict:
        if isinstance(instance, dict):
            model_id = instance.get("noob_id", None)
            noob_model = instance.get("noob_model", cls._model_name())
            noob_version = instance.get("noob_version", version("noob"))
        else:
            model_id = getattr(instance, "noob_id", None)
            noob_model = getattr(instance, "noob_model", cls._model_name())
            noob_version = getattr(instance, "noob_version", version("noob"))

        if model_id is None:
            # if missing an id, try and recover with the model default cautiously,
            # so we throw the exception during validation and not here, for clarity.
            model_id = getattr(cls.model_fields.get("noob_id", None), "default", None)
            if type(model_id).__name__ == "PydanticUndefinedType":
                model_id = None

        return {
            "noob_id": model_id,
            "noob_model": noob_model,
            "noob_version": noob_version,
        }

    @classmethod
    def _complete_header(
        cls: type[Self], data: CommentedMap, file_path: str | Path
    ) -> CommentedMap:
        """fill in any missing fields in the source file needed for a header"""
        file_path = Path(file_path)
        missing_fields = set(cls.HEADER_FIELDS) - set(data.keys())
        keys = tuple(data.keys())
        out_of_order = len(keys) >= 3 and keys[0:3] != cls.HEADER_FIELDS

        if missing_fields or out_of_order:
            if missing_fields:
                msg = (
                    f"Missing required header fields {missing_fields} in config model "
                    f"{str(file_path)}. Updating file (preserving backup)..."
                )
            else:
                msg = (
                    f"Header keys were present, but either not at the start of {str(file_path)} "
                    "or out of order. Updating file (preserving backup)..."
                )
            from noob.logging import init_logger

            logger = init_logger(cls.__name__)
            logger.warning(msg)
            logger.debug(data)

            header = cls._yaml_header(data)
            comment: None | list[CommentToken] = None
            for i, (key, value) in enumerate(header.items()):
                if key in data:
                    # pop it, preserving comments that start on following lines
                    # to re-inject after the header block, if present
                    if key in data.ca.items and data.ca.items[key][2].value.startswith("\n"):
                        if comment is None:
                            comment = data.ca.items.pop(key)
                        else:
                            comment[2].value += data.ca.items.pop(key)[2].value
                    del data[key]
                data.insert(i, key, value)
            if comment:
                # insert newline comments after noob_version,
                # which is the last key in the header block
                data.ca.items["noob_version"] = comment

            # data = {**header, **data}
            shutil.copy(file_path, file_path.with_suffix(".yaml.bak"))
            with open(file_path, "w") as yfile:
                yaml.dump(data, yfile)

        return data

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Add before_validator to allow instantiation from id
        """

        def _from_id(value: Union[str, "ConfigYAMLMixin"]) -> "ConfigYAMLMixin":
            if isinstance(value, str):
                return cls.from_id(value)
            else:
                return value

        return core_schema.no_info_before_validator_function(
            _from_id,
            handler(source_type),
            # TODO: add this when updating pydantic floor to 2.10
            # json_schema_input_schema=core_schema.union_schema(
            #     [handler(source_type), handler(ConfigID)]
            # ),
        )
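
# Because of __get_pydantic_core_schema__ above, a parent model can accept a bare
# config id wherever a ConfigYAMLMixin subclass is expected. A hedged sketch with
# hypothetical names:
#
#     class Pipeline(BaseModel):
#         source: ExampleConfig
#
#     Pipeline(source="my-config")  # the string is resolved via ExampleConfig.from_id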


@overload
def yaml_peek(
    key: str, path: str | Path, root: bool = True, first: Literal[True] = True
) -> str: ...


@overload
def yaml_peek(
    key: str, path: str | Path, root: bool = True, first: Literal[False] = False
) -> list[str]: ...


@overload
def yaml_peek(
    key: str, path: str | Path, root: bool = True, first: bool = True
) -> str | list[str]: ...


def yaml_peek(
    key: str, path: str | Path, root: bool = True, first: bool = True
) -> str | list[str]:
    """
    Peek into a yaml file without parsing the whole file to retrieve the value of a single key.

    This function is _not_ designed for robustness to the yaml spec, it is for simple
    ``key: value`` pairs, not fancy things like multiline strings, tagged values, etc.
    If you want it to be, then i'm afraid you'll have to make a PR about it.

    Returns a string no matter what the yaml type is, so you have to do your own casting
    if you want something else.

    Args:
        key (str): The key to peek for
        path (:class:`pathlib.Path`, str): The yaml file to peek into
        root (bool): Only find keys at the root of the document (default ``True``),
            otherwise find keys at any level of nesting.
        first (bool): Only return the first appearance of the key (default ``True``).
            Otherwise return a list of values.

    Returns:
        str | list[str]: the value of the first match if ``first`` is ``True``,
        otherwise a list of values for all matches.
    """
    if root:
        pattern = re.compile(
            rf"^(?P<key>{key}):\s*\"*\'*(?P<value>\S.*?)\"*\'*$",
            flags=re.MULTILINE,
        )
    else:
        pattern = re.compile(
            rf"^\s*(?P<key>{key}):\s*\"*\'*(?P<value>\S.*?)\"*\'*$",
            flags=re.MULTILINE,
        )

    res: re.Match[str] | None = None
    if first:
        with open(path) as yfile:
            for line in yfile:
                res = pattern.match(line)
                if res:
                    break
        if res is not None:
            return res.groupdict()["value"]
    else:
        with open(path) as yfile:
            text = yfile.read()
        matches = [match.groupdict()["value"] for match in pattern.finditer(text)]
        if matches:
            return matches

    raise KeyError(f"Key {key} not found in {path}")
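
# Example of peeking a header key without a full parse (the file path is illustrative):
#
#     noob_id = yaml_peek("noob_id", "configs/example.yaml")
#     all_ids = yaml_peek("noob_id", "configs/example.yaml", first=False)  # list of matches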