Source code for pyaml.arrays.element_array

import fnmatch
import importlib
from typing import Sequence

import numpy as np

from ..bpm.bpm import BPM
from ..common.element import Element
from ..common.exception import PyAMLException
from ..magnet.cfm_magnet import CombinedFunctionMagnet
from ..magnet.magnet import Magnet
from ..magnet.serialized_magnet import SerializedMagnets


class ElementArray(list[Element]):
    """
    Class that implements access to an element array

    Parameters
    ----------
    array_name : str
        Array name
    elements : list[Element]
        Element list, all elements must be attached to the same instance of
        either a Simulator or a ControlSystem.
    use_aggregator : bool
        Use aggregator to increase performance by using parallel access to
        underlying devices.

    Example
    -------

    An array can be retrieved from the configuration as in the following
    example:

    .. code-block:: python

        >>> sr = Accelerator.load("acc.yaml")
        >>> elements = sr.design.get_elements("QuadForTune")

    """

    def __init__(self, array_name: str, elements: list[Element], use_aggregator=True):
        super().__init__(elements)
        self.__name = array_name
        self.__peer = None
        self.__use_aggregator = use_aggregator
        if len(self) > 0:
            # The first element defines the peer every other element must share.
            self.__peer = self[0]._peer
            if self.__peer is None or any(m._peer != self.__peer for m in self):
                raise PyAMLException(
                    f"{self.__class__.__name__} {self.get_name()}: "
                    "All elements must be attached to the same instance "
                    "of either a Simulator or a ControlSystem"
                )

    def get_peer(self):
        """
        Returns the peer (:py:class:`~pyaml.lattice.simulator.Simulator` or
        :py:class:`~pyaml.control.controlsystem.ControlSystem`) of an element
        list. The peer is ``None`` for an empty array.
        """
        return self.__peer

    def get_name(self) -> str:
        """
        Returns the array name
        """
        return self.__name

    def names(self) -> list[str]:
        """
        Returns the element names
        """
        return [e.get_name() for e in self]

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def __create_array(self, array_name: str, element_type: type, elements: list):
        """Build the array class matching ``element_type``.

        ``element_type`` set to ``None`` is treated as :class:`Element`.
        Raises :class:`PyAMLException` when the type is not an ``Element``
        subclass.
        """
        if element_type is None:
            element_type = Element

        # (element base class, module, array class name). The first match wins
        # and the order is the one of the original if/elif chain.
        # NOTE(review): if CombinedFunctionMagnet or SerializedMagnets derive
        # from Magnet, the Magnet entry shadows them -- confirm the hierarchy.
        specialized = (
            (Magnet, "pyaml.arrays.magnet_array", "MagnetArray"),
            (BPM, "pyaml.arrays.bpm_array", "BPMArray"),
            (
                CombinedFunctionMagnet,
                "pyaml.arrays.cfm_magnet_array",
                "CombinedFunctionMagnetArray",
            ),
            (
                SerializedMagnets,
                "pyaml.arrays.serialized_magnet_array",
                "SerializedMagnetsArray",
            ),
        )
        for base, module_name, class_name in specialized:
            if issubclass(element_type, base):
                # NOTE(review): presumably imported lazily to avoid a circular
                # import with the specialized array modules -- confirm.
                module = importlib.import_module(module_name)
                array_class = getattr(module, class_name)
                return array_class(array_name, elements, self.__use_aggregator)

        if issubclass(element_type, Element):
            return ElementArray(array_name, elements, self.__use_aggregator)
        raise PyAMLException(f"Unsupported sliced array for type {str(element_type)}")

    def __eval_field(self, attribute_name: str, element: Element) -> str:
        """Return ``element.get_<attribute_name>()`` as text.

        Returns an empty string when the element has no such getter or when
        the getter returns ``None``. Non-string results are converted with
        ``str()`` so they can safely be matched with :mod:`fnmatch`.
        """
        func = getattr(element, "get_" + attribute_name, None)
        if func is None:
            return ""
        value = func()
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    def __ensure_compatible_operand(self, other: object) -> "ElementArray":
        """Validate the operand used for set-like operations between arrays.

        Raises :class:`TypeError` when ``other`` is not an ``ElementArray`` and
        :class:`PyAMLException` when both arrays are non-empty and attached to
        different peers.
        """
        if not isinstance(other, ElementArray):
            raise TypeError(
                f"Unsupported operand type(s) for set operation: '{type(self).__name__}' and '{type(other).__name__}'"
            )
        if len(self) > 0 and len(other) > 0:
            mine, theirs = self.get_peer(), other.get_peer()
            if mine is not None and theirs is not None and mine != theirs:
                raise PyAMLException(f"{self.__class__.__name__}: cannot operate on arrays attached to different peers")
        return other

    @staticmethod
    def __common_element_type(elements: list) -> type:
        """Return the nearest common ancestor class of ``elements``.

        The search follows the MRO of the first element (most specific first)
        and keeps only classes that are also in the MRO of every other element.
        The first remaining class that is a subclass of :class:`Element` is
        returned; :class:`Element` is returned when there is none or when
        ``elements`` is empty.
        """
        if len(elements) == 0:
            return Element
        common = list(type(elements[0]).__mro__)
        for e in elements[1:]:
            mro_set = set(type(e).__mro__)
            common = [c for c in common if c in mro_set]
            if not common:
                break
        for c in common:
            if issubclass(c, Element):
                return c
        return Element

    def __auto_array(self, elements: list[Element]):
        """Create the most specific array type for the given element list.

        The target element type is the most specific common base class
        (nearest common ancestor) of all elements. This supports heterogeneous
        subclasses (e.g., several Magnet subclasses) while still returning a
        MagnetArray when appropriate. An empty input returns a plain ``[]``.
        """
        if len(elements) == 0:
            return []
        return self.__create_array("", self.__common_element_type(elements), elements)

    def __is_bool_mask(self, other: object) -> bool:
        """Return True if 'other' looks like a boolean mask (list or numpy array)."""
        # --- numpy boolean array ---
        if isinstance(other, np.ndarray):
            return bool(other.dtype == bool)
        # Strings are sequences too, but never masks.
        if isinstance(other, (str, bytes, bytearray)):
            return False
        # Avoid treating ElementArray (a list subclass) as a mask.
        if isinstance(other, ElementArray):
            return False
        # --- python sequence of bools ---
        if isinstance(other, Sequence):
            try:
                # numpy.bool_ is accepted so that list(np_array > 0) is a mask.
                return all(isinstance(v, (bool, np.bool_)) for v in other)
            except TypeError:
                return False
        return False

    def __select_by_mask(self, mask_like, keep: bool) -> list:
        """Return the elements whose mask entry equals ``keep``.

        Raises :class:`ValueError` when the mask length differs from the array
        length.
        """
        mask = list(mask_like)  # works for list/tuple and numpy arrays
        if len(mask) != len(self):
            raise ValueError(
                f"{self.__class__.__name__}: mask length ({len(mask)}) does not match array length ({len(self)})"
            )
        return [e for e, flag in zip(self, mask, strict=True) if bool(flag) is keep]

    # ------------------------------------------------------------------
    # Set-like operators
    # ------------------------------------------------------------------

    def __and__(self, other: object):
        """
        Intersection or boolean mask filtering.

        This operator has two distinct behaviors depending on the type of
        ``other``.

        1) Array intersection

        If ``other`` is an ElementArray, the result contains elements whose
        names are present in both arrays.

        **Example**

        .. code-block:: python

            >>> cell1 = sr.live.get_elements("C01")
            >>> sexts = sr.live.get_magnets("SEXT")
            >>> cell1_sext = cell1 & sexts

        2) Boolean mask filtering

        If ``other`` is a boolean mask (list[bool] or numpy.ndarray of bool),
        elements are kept where the mask is True.

        **Example**

        .. code-block:: python

            >>> mask = cell1.mask_by_type(Magnet)
            >>> magnets = cell1 & mask

        Returns
        -------
        Array
            The result is automatically typed according to the most specific
            common base class of the remaining elements which can be:
            :py:class:`.BPMArray` or :py:class:`.MagnetArray` or
            :py:class:`.CombinedFunctionMagnetArray` or
            :py:class:`.SerializedMagnetsArray` or :py:class:`.ElementArray`.
            An empty result is returned as ``[]``.

        Raises
        ------
        ValueError
            If a mask is given whose length differs from the array length.
        TypeError
            If ``other`` is neither a boolean mask nor an ElementArray.
        PyAMLException
            If both arrays are attached to different peers.
        """
        # --- mask filtering ---
        if self.__is_bool_mask(other):
            return self.__auto_array(self.__select_by_mask(other, keep=True))

        # --- array intersection ---
        other_arr = self.__ensure_compatible_operand(other)
        other_names = {e.get_name() for e in other_arr}
        return self.__auto_array([e for e in self if e.get_name() in other_names])

    def __rand__(self, other: object):
        # Support "array on the right" for array operands; for masks, we don't
        # enforce commutativity.
        if isinstance(other, ElementArray):
            return other.__and__(self)
        return NotImplemented

    def __sub__(self, other: object):
        """
        Difference or boolean mask removal.

        This operator has two behaviors depending on the type of ``other``.

        1) Array difference

        If ``other`` is an ElementArray, the result contains elements whose
        names are present in ``self`` but not in ``other``.

        **Example**

        .. code-block:: python

            >>> hvcorr = sr.live.get_magnets("HVCORR")
            >>> hcorr = sr.live.get_magnets("HCORR")
            >>> vcorr_only = hvcorr - hcorr

        2) Boolean mask removal

        If ``other`` is a boolean mask (list[bool] or numpy.ndarray of bool),
        elements are removed where the mask is True. This is the inverse of
        ``& mask``.

        **Example**

        .. code-block:: python

            >>> mask = cell1.mask_by_type(Magnet)
            >>> non_magnets = cell1 - mask

        Returns
        -------
        Array
            The result is automatically typed according to the most specific
            common base class of the remaining elements which can be:
            :py:class:`.BPMArray` or :py:class:`.MagnetArray` or
            :py:class:`.CombinedFunctionMagnetArray` or
            :py:class:`.SerializedMagnetsArray` or :py:class:`.ElementArray`.
            An empty result is returned as ``[]``.

        Raises
        ------
        ValueError
            If a mask is given whose length differs from the array length.
        TypeError
            If ``other`` is neither a boolean mask nor an ElementArray.
        PyAMLException
            If both arrays are attached to different peers.
        """
        # --- mask removal ---
        if self.__is_bool_mask(other):
            return self.__auto_array(self.__select_by_mask(other, keep=False))

        # --- array difference ---
        other_arr = self.__ensure_compatible_operand(other)
        other_names = {e.get_name() for e in other_arr}
        return self.__auto_array([e for e in self if e.get_name() not in other_names])

    def __or__(self, other: object):
        """
        Union between two ElementArray instances.

        Elements are combined using their names as identity. Order is stable:
        elements from ``self`` first, followed by elements from ``other`` that
        are not already present.

        Example
        -------

        .. code-block:: python

            >>> hcorr = sr.live.get_magnets("HCORR")
            >>> vcorr = sr.live.get_magnets("VCORR")
            >>> all_corr = hcorr | vcorr

        Returns
        -------
        Array
            The result is automatically typed according to the most specific
            common base class of the remaining elements which can be:
            :py:class:`.BPMArray` or :py:class:`.MagnetArray` or
            :py:class:`.CombinedFunctionMagnetArray` or
            :py:class:`.SerializedMagnetsArray` or :py:class:`.ElementArray`.
            An empty result is returned as ``[]``.

        Raises
        ------
        TypeError
            If ``other`` is not an ElementArray.
        PyAMLException
            If both arrays are attached to different peers.
        """
        other_arr = self.__ensure_compatible_operand(other)
        seen: set[str] = set()
        res: list[Element] = []
        # Walk self first so its elements keep priority and ordering.
        for source in (self, other_arr):
            for e in source:
                name = e.get_name()
                if name not in seen:
                    seen.add(name)
                    res.append(e)
        return self.__auto_array(res)

    def __ror__(self, other: object):
        if isinstance(other, ElementArray):
            return other.__or__(self)
        return NotImplemented

    def __add__(self, other: object):
        """
        Alias for the union operator ``|``.

        Example
        -------

        .. code-block:: python

            >>> all_corr = hcorr + vcorr

        Returns
        -------
        Array
            The result is automatically typed according to the most specific
            common base class of the remaining elements which can be:
            :py:class:`.BPMArray` or :py:class:`.MagnetArray` or
            :py:class:`.CombinedFunctionMagnetArray` or
            :py:class:`.SerializedMagnetsArray` or :py:class:`.ElementArray`.
        """
        return self.__or__(other)

    def __radd__(self, other: object):
        if isinstance(other, ElementArray):
            return other.__add__(self)
        return NotImplemented

    # ------------------------------------------------------------------
    # Type based filtering
    # ------------------------------------------------------------------

    def mask_by_type(self, element_type: type) -> list[bool]:
        """Return a boolean mask indicating which elements are instances of
        the given type.

        Parameters
        ----------
        element_type : type
            The class to test against (e.g., Magnet).

        Returns
        -------
        list[bool]
            A list of booleans where True indicates the element is an instance
            of the given type (including subclasses).

        Raises
        ------
        TypeError
            If ``element_type`` is not a class.
        """
        if not isinstance(element_type, type):
            raise TypeError(f"{self.__class__.__name__}: element_type must be a type")
        return [isinstance(e, element_type) for e in self]

    def of_type(self, element_type: type):
        """Return a new array containing only elements of the given type.

        The resulting array is automatically typed according to the most
        specific common base class of the filtered elements.

        Parameters
        ----------
        element_type : type
            The class to filter by (e.g., Magnet).

        Returns
        -------
        ElementArray or specialized array
            An auto-typed array containing only matching elements.
            Returns [] if no elements match.

        Raises
        ------
        TypeError
            If ``element_type`` is not a class.
        """
        if not isinstance(element_type, type):
            raise TypeError(f"{self.__class__.__name__}: element_type must be a type")
        filtered = [e for e in self if isinstance(e, element_type)]
        return self.__auto_array(filtered)

    def exclude_type(self, element_type):
        """Return a new array without the elements of the given type.

        This is ``self - self.mask_by_type(element_type)``.

        Parameters
        ----------
        element_type : type
            The class to exclude (e.g., Magnet), subclasses included.

        Returns
        -------
        ElementArray or specialized array
            An auto-typed array of the remaining elements.
            Returns [] if nothing remains.

        Raises
        ------
        TypeError
            If ``element_type`` is not a class.
        """
        mask = self.mask_by_type(element_type)
        return self - mask

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    def __getitem__(self, key):
        """Select elements by index, slice, name pattern or field pattern.

        - ``int`` (or any other key): plain list indexing.
        - ``slice``: returns a new array with the sliced elements.
        - ``str`` without ``:``: returns the elements whose name matches the
          :mod:`fnmatch` pattern.
        - ``str`` of the form ``"field:pattern"``: returns the elements for
          which ``get_<field>()`` matches the pattern. Elements without such
          a getter are matched against an empty string.

        Slice and string selections return an array typed according to the
        most specific common base class of the selected elements. An empty
        selection returns an empty :py:class:`.ElementArray`.
        """
        if isinstance(key, slice):
            selected = list(super().__getitem__(key))
        elif isinstance(key, str):
            fields = key.split(":")
            if len(fields) <= 1:
                # Selection by name
                selected = [e for e in self if fnmatch.fnmatch(e.get_name(), key)]
            else:
                # Selection by fields
                field, pattern = fields[0], fields[1]
                selected = [e for e in self if fnmatch.fnmatch(self.__eval_field(field, e), pattern)]
        else:
            # Default to super selection
            return super().__getitem__(key)
        return self.__create_array("", self.__common_element_type(selected), selected)