Source code for PAGOZA.pair_generation

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Protocol, TypeAlias, runtime_checkable

import numpy.typing as npt

from PAGOZA.utils.base_method import MethodWrapper
from PAGOZA.utils.registers import register_method

###### Base classes and protocols ######
# Alias used in the signatures throughout this module for array inputs/outputs.
# NOTE: despite its name, this is NumPy's ndarray type (any dtype), i.e.
# ``numpy.typing.NDArray[Any]`` -- it is narrower than ``numpy.typing.ArrayLike``,
# which also admits lists, tuples and scalars.
ArrayLike: TypeAlias = npt.NDArray[Any]


class PairGenerationMethod(Enum):
    """Closed set of pair-generation strategy names.

    Each member's value is the string key that a concrete generator in this
    module is registered under with ``register_method(..., group="pair_generator")``.
    """

    # Key of the generator registered as "gradient".
    GRADIENT = "gradient"
    # Key of the generator registered as "local_noise".
    LOCAL_NOISE = "local_noise"
@runtime_checkable
class PairGeneratorProtocol(Protocol):
    """Structural interface for objects that turn input data into pairs.

    Being ``runtime_checkable``, ``isinstance(obj, PairGeneratorProtocol)``
    only checks that ``obj`` has a ``generate_pairs`` attribute; the
    signature is not verified at runtime.
    """

    def generate_pairs(self, X: ArrayLike) -> ArrayLike:
        """Return the pairs derived from ``X``."""
class BasePairGenerator(ABC):
    """Template base class for concrete pair generators.

    Subclasses implement :meth:`_generate_pairs`. The public
    :meth:`generate_pairs` runs the shared input check first and then
    delegates to the subclass hook.
    """

    def generate_pairs(self, X: ArrayLike) -> ArrayLike:
        """Validate ``X`` and return the pairs produced by the subclass.

        Args:
            X (ArrayLike): Input data. It must expose an ``ndim`` attribute
                (otherwise ``AttributeError`` propagates) and must not be
                zero-dimensional.

        Returns:
            Whatever the subclass's ``_generate_pairs`` returns for ``X``.

        Raises:
            ValueError: If ``X`` is zero-dimensional.
        """
        self._validate_input(X)
        pairs = self._generate_pairs(X)
        return pairs

    @abstractmethod
    def _generate_pairs(self, X: ArrayLike) -> ArrayLike:
        """Subclass hook: produce pairs from already-validated ``X``."""
        raise NotImplementedError

    def _validate_input(self, X: ArrayLike) -> None:
        """Reject scalar (zero-dimensional) input.

        Raises:
            ValueError: If ``X.ndim`` equals 0.
        """
        is_scalar = X.ndim == 0
        if is_scalar:
            raise ValueError("X must be at least one-dimensional")
###### Public API ######
class PairGenerator(MethodWrapper[PairGeneratorProtocol]):
    """Public entry point that dispatches to a selected pair-generation strategy.

    ``GROUP`` and ``METHOD_NAME`` are consumed by ``MethodWrapper`` (defined
    elsewhere). Presumably they select the ``"pair_generator"`` registry
    group populated by ``register_method`` and the method name to invoke on
    the chosen implementation -- confirm against ``MethodWrapper``.
    """

    GROUP = "pair_generator"
    METHOD_NAME = "generate_pairs"

    def __init__(self, method: str | PairGenerationMethod = "gradient"):
        """Select the strategy to use.

        Args:
            method: Strategy name or ``PairGenerationMethod`` member;
                defaults to ``"gradient"``. Forwarded to ``MethodWrapper``
                together with the ``PairGenerationMethod`` enum class.
        """
        super().__init__(method, PairGenerationMethod)

    def generate_pairs(self, X: ArrayLike) -> ArrayLike:
        """Generate pairs from ``X`` by delegating to ``self.run``."""
        result = self.run(X)
        return result
###### Implementations of pair generators ######
@register_method(name="gradient", group="pair_generator")
class GradientPairGenerator(BasePairGenerator):
    """Pair generator registered under the ``"gradient"`` key.

    NOTE(review): this is still a placeholder. After the base-class
    validation, ``generate_pairs`` returns the input object unchanged.
    """

    def _generate_pairs(self, X: ArrayLike) -> ArrayLike:
        """Return the output of :meth:`_compute_gradients` for ``X``."""
        # TODO: build pairs from the gradients instead of returning them directly.
        return self._compute_gradients(X)

    def _compute_gradients(self, X: ArrayLike) -> ArrayLike:
        """Compute gradients of ``X``.

        Placeholder: currently returns ``X`` itself (no copy is made).
        """
        # TODO: implement the gradient computation.
        return X
@register_method(name="local_noise", group="pair_generator")
class LocalNoisePairGenerator(BasePairGenerator):
    """Pair generator registered under the ``"local_noise"`` key.

    NOTE(review): this is still a placeholder. After the base-class
    validation, ``generate_pairs`` returns the input object unchanged.
    """

    def _generate_pairs(self, X: ArrayLike) -> ArrayLike:
        """Generate pairs from ``X`` using local noise.

        Placeholder: currently returns ``X`` itself (no copy is made).
        """
        # TODO: implement local-noise pair generation.
        return X