Source code for pyaml.configuration.fileloader

# PyAML config file loader
import collections.abc
import io
import json
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, Union

import yaml
from yaml import CLoader
from yaml.constructor import ConstructorError
from yaml.loader import SafeLoader

from pyaml.configuration.factory import Factory

from .. import PyAMLException

if TYPE_CHECKING:
    from pyaml.accelerator import Accelerator


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Suffixes that make a string value count as a reference to a nested
# configuration file (see hasToLoad).
accepted_suffixes = [".yaml", ".yml", ".json"]
# Prefix marking a dict value as a plain file path: such a value is resolved
# against the root folder instead of being loaded and expanded.
FILE_PREFIX = "file:"

# Mutable holder for the configuration root folder. The default is the
# resolved working directory, evaluated once when this module is imported.
ROOT = {"path": Path.cwd().resolve()}


def set_root_folder(path: Union[str, Path]):
    """
    Register the folder used as root for configuration files.

    The given value is converted to a ``Path`` and stored under the
    ``"path"`` entry of the module-level ``ROOT`` holder.
    """
    ROOT.update(path=Path(path))
def get_root_folder() -> Path:
    """
    Return the folder currently registered as root for configuration files.
    """
    root_folder = ROOT["path"]
    return root_folder
def get_path(p: Union[str, Path]) -> Path:
    """
    Resolve a configuration file path against the root folder.

    Parameters
    ----------
    p : str or Path
        Path to a configuration file.

    Returns
    -------
    Path
        ``p`` itself (as a ``Path``) when it is absolute, otherwise ``p``
        appended to the root folder returned by ``get_root_folder()``.
    """
    if os.path.isabs(p):
        # Always hand back a Path, even for a str input: callers use Path
        # methods (e.g. ``samefile``) on the result.
        return Path(p)
    return get_root_folder() / p
class PyAMLConfigCyclingException(PyAMLException):
    """
    Raised when a configuration file is included from within itself,
    directly or through a chain of nested files.

    The offending file name is kept in ``error_filename``.
    """

    def __init__(self, error_filename: str, path_stack: list[Path]):
        self.error_filename = error_filename
        # Only the file names (not the full paths) are reported.
        visited_names = []
        for visited_path in path_stack:
            visited_names.append(visited_path.name)
        message = f"Circular file inclusion of {error_filename}. File list before reaching it: {visited_names}"
        super().__init__(message)
def load(filename: str, paths_stack: list = None, use_fast_loader: bool = False) -> Union[dict, list]:
    """
    Load a configuration file and recursively expand the files it references.

    The loader is chosen from the file name suffix: ``.yaml`` / ``.yml``
    files go through ``YAMLLoader`` and ``.json`` files through
    ``JSONLoader``. Any other suffix raises a ``PyAMLException``.
    """
    if filename.endswith((".yaml", ".yml")):
        loader_class = YAMLLoader
    elif filename.endswith(".json"):
        loader_class = JSONLoader
    else:
        raise PyAMLException(f"{filename} File format not supported (only .yaml .yml or .json)")
    file_loader = loader_class(filename, paths_stack, use_fast_loader)
    return file_loader.load()
# Expand condition
def hasToLoad(value):
    """
    Tell whether ``value`` is a string ending with one of the accepted
    configuration file suffixes.
    """
    if not isinstance(value, str):
        return False
    return value.endswith(tuple(accepted_suffixes))
# Loader base class (nested files expansion)
class Loader:
    """
    Base class of the file loaders: handles nested file expansion.

    A loader resolves its file path, checks that the file is not already
    being loaded further up the inclusion chain, and offers ``expand*``
    helpers that replace file references found in a loaded structure by the
    content of the referenced files. Subclasses implement ``load()``.
    """

    def __init__(
        self,
        filename: str,
        parent_path_stack: Union[list[Path], None],
        use_fast_loader: bool = False,
    ):
        """
        Parameters
        ----------
        filename : str
            File to load, absolute or relative to the root folder.
        parent_path_stack : list[Path] or None
            Paths of the files currently being loaded that led to this one.
        use_fast_loader : bool
            Flag forwarded to ``load()`` for nested files.

        Raises
        ------
        PyAMLConfigCyclingException
            If ``filename`` designates the same file as one of the parents.
        """
        # Coerce to Path: get_path may return its argument unchanged, and the
        # cycle check below needs Path.samefile().
        self.path: Path = Path(get_path(filename))
        # Set here so expand_dict/expand_list can rely on the attribute even
        # when a subclass does not assign it.
        self.use_fast_loader: bool = use_fast_loader
        self.files_stack: list[Path] = []
        if parent_path_stack:
            if any(self.path.samefile(parent_path) for parent_path in parent_path_stack):
                raise PyAMLConfigCyclingException(filename, parent_path_stack)
            self.files_stack.extend(parent_path_stack)
        self.files_stack.append(self.path)

    def _resolve_reference(self, value: str):
        """Turn a file reference into its replacement value.

        A value starting with ``FILE_PREFIX`` becomes the referenced path
        (relative to the root folder) as a string; any other value is loaded
        as a nested configuration file.
        """
        if value.startswith(FILE_PREFIX):
            # remove prefix
            stripped_value = value[len(FILE_PREFIX) :]
            return str(get_root_folder() / Path(stripped_value))
        return load(value, self.files_stack, self.use_fast_loader)

    # Recursively expand a dict
    def expand_dict(self, d: dict):
        """
        Expand in place the file references found in the values of ``d``.

        Raises
        ------
        PyAMLException
            On circular inclusion; the message carries the file/line/column
            of the faulty field when ``d`` holds location information.
        """
        for key, value in d.items():
            try:
                if hasToLoad(value):
                    # Re-assigning an existing key keeps the dict size
                    # unchanged, so this is safe while iterating.
                    d[key] = self._resolve_reference(value)
                else:
                    self.expand(value)
            except PyAMLConfigCyclingException as pyaml_ex:
                location = d.pop("__location__", None)
                field_locations = d.pop("__fieldlocations__", None)
                location_str = ""
                if location:
                    file, line, col = location
                    # Prefer the position of the field itself over the
                    # position of the enclosing mapping.
                    if field_locations and key in field_locations:
                        file, line, col = field_locations[key]
                    location_str = f" in {file} at line {line}, column {col}"
                raise PyAMLException(
                    f"Circular file inclusion of {pyaml_ex.error_filename}{location_str}"
                ) from pyaml_ex

    # Recursively expand a list
    def expand_list(self, l: list):
        """
        Expand in place the file references found in the items of ``l``.

        A reference that loads as a list is spliced into ``l`` at the
        position of the reference; anything else replaces the item.
        """
        idx = 0
        while idx < len(l):
            value = l[idx]
            if hasToLoad(value):
                # Same resolution rules as expand_dict (file prefix and
                # loader choice).
                obj = self._resolve_reference(value)
                if isinstance(obj, list):
                    l[idx : idx + 1] = obj
                    # Skip the spliced items: the nested loader already
                    # expanded them.
                    idx += len(obj)
                else:
                    l[idx] = obj
                    idx += 1
            else:
                self.expand(value)
                idx += 1

    # Recursively expand an object
    def expand(self, obj: Union[dict, list]):
        """
        Expand ``obj`` in place when it is a dict or a list and return it;
        any other object is returned untouched.
        """
        if isinstance(obj, dict):
            self.expand_dict(obj)
        elif isinstance(obj, list):
            self.expand_list(obj)
        return obj

    # Load a file
    def load(self) -> Union[dict, list]:
        """Load the file; must be implemented by subclasses."""
        raise PyAMLException(str(self.path) + ": load() method not implemented")
class SafeLineLoader(SafeLoader):
    """
    ``SafeLoader`` that records where each mapping and each of its fields
    was defined.

    Every constructed mapping gets two extra entries:

    - ``__location__``: ``(filename, line, column)`` of the mapping,
    - ``__fieldlocations__``: dict giving ``(filename, line, column)`` of
      each field value.

    Lines and columns are 1-based.
    """

    def __init__(self, stream):
        super().__init__(stream)
        # Only text file objects carry a usable name here; other streams get
        # an empty file name.
        self.filename = stream.name if isinstance(stream, io.TextIOWrapper) else ""

    def construct_mapping(self, node, deep=False):
        """
        Build the dict for ``node`` and attach location information to it.

        Raises
        ------
        ConstructorError
            If ``node`` is not a mapping node or a key is not hashable.
        """
        if not isinstance(node, yaml.MappingNode):
            raise ConstructorError(
                None,
                None,
                "expected a mapping node, but found %s" % node.id,
                node.start_mark,
            )
        # Resolve YAML merge keys ("<<") as SafeLoader does; this override
        # does not go through the parent implementation that handles them.
        self.flatten_mapping(node)

        mapping = {}
        field_mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            if not isinstance(key, collections.abc.Hashable):
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found unhashable key",
                    key_node.start_mark,
                )
            value = self.construct_object(value_node, deep=deep)
            mapping[key] = value
            # Marks are 0-based; report 1-based positions.
            field_mapping[key] = (
                self.filename,
                value_node.start_mark.line + 1,
                value_node.start_mark.column + 1,
            )

        # Add location information inside the dict
        mapping["__location__"] = (
            self.filename,
            node.start_mark.line + 1,
            node.start_mark.column + 1,
        )
        mapping["__fieldlocations__"] = field_mapping
        return mapping
# YAML loader
class YAMLLoader(Loader):
    """
    Loader for ``.yaml`` / ``.yml`` configuration files.

    Uses ``SafeLineLoader`` by default and ``CLoader`` when
    ``use_fast_loader`` is true.
    """

    def __init__(self, filename: str, parent_paths_stack: list, use_fast_loader: bool):
        super().__init__(filename, parent_paths_stack)
        # NOTE(review): CLoader is PyYAML's C-accelerated *full* loader, not
        # a safe one. If configuration files can come from an untrusted
        # source, CSafeLoader is the safe C-accelerated alternative — confirm
        # before switching.
        self._loader = SafeLineLoader if not use_fast_loader else CLoader
        self.use_fast_loader = use_fast_loader

    def load(self) -> Union[dict, list]:
        """
        Parse the YAML file and expand the nested files it references.

        Raises
        ------
        PyAMLException
            If the YAML content cannot be parsed; the message is prefixed
            with the file path.
        """
        logger.debug("Loading YAML file '%s'", self.path)
        # Decode as UTF-8 explicitly so the result does not depend on the
        # platform default encoding.
        with open(self.path, encoding="utf-8") as file:
            try:
                return self.expand(yaml.load(file, Loader=self._loader))
            except yaml.YAMLError as e:
                raise PyAMLException(str(self.path) + ": " + str(e)) from e
# JSON loader
class JSONLoader(Loader):
    """
    Loader for ``.json`` configuration files.
    """

    def __init__(self, filename: str, parent_paths_stack: list, use_fast_loader: bool):
        super().__init__(filename, parent_paths_stack)
        # NOTE(review): the requested flag is deliberately not stored, so
        # files nested below a JSON file are never loaded with the fast
        # loader — confirm this is intended.
        self.use_fast_loader = False

    def load(self) -> Union[dict, list]:
        """
        Parse the JSON file and expand the nested files it references.

        Raises
        ------
        PyAMLException
            If the JSON content cannot be decoded; the message is prefixed
            with the file path.
        """
        logger.debug("Loading JSON file '%s'", self.path)
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(self.path, encoding="utf-8") as file:
            try:
                return self.expand(json.load(file))
            except json.JSONDecodeError as e:
                raise PyAMLException(str(self.path) + ": " + str(e)) from e