from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple
import pandas as pd
from pandas import DataFrame
from kanon.utils import Sign
from kanon.utils.types.number_types import Real
# Public names exported by a star-import of this module: the symmetry
# strategy itself and the two errors it can raise when applied.
__all__ = ["Symmetry", "OutOfBoundsOriginError", "OverlappingSymmetryError"]
@dataclass
class Symmetry:
    """Symmetry strategy that can be applied on a `~pandas.DataFrame`.

    Rows are taken from a source interval of the DataFrame index, their keys
    are shifted (``periodic``) or reflected (``mirror``), their values are
    optionally negated and offset, and the result is appended to the
    DataFrame at one or several target keys.

    >>> df = DataFrame({"val": [5, 9, 2]}, index=[0, 1, 3])
    >>> sym = Symmetry("mirror")
    >>> df.pipe(sym)
       val
    0    5
    1    9
    3    2
    5    9
    6    5
    >>> sym = Symmetry("periodic", sign=-1, offset=2)
    >>> df.pipe(sym)
       val
    0    5
    1    9
    3    2
    4   -3
    5   -7
    7    0
    >>> sym = Symmetry("periodic", sign=-1, source=(0, 1), targets=[6, 10])
    >>> df.pipe(sym)
        val
    0     5
    1     9
    3     2
    6    -5
    7    -9
    10   -5
    11   -9

    :param symtype: Type of the symmetry, it can be of the same direction
        (`periodic`) or the opposite (`mirror`)
    :type symtype: Literal["periodic", "mirror"]
    :param offset: Offset to add to the symmetry values, defaults to 0
    :type offset: int, optional
    :param sign: Relative sign of the symmetry values from source values,
        defaults to 1
    :type sign: Sign, optional
    :param source: Tuple representing the lower and upper bound to take the
        values from, defaults to the whole DataFrame
    :type source: Tuple[Real, Real], optional
    :param targets: List of keys where the symmetry is pasted, defaults to
        right after the last source key
    :type targets: List[Real], optional
    :raises ValueError: At construction, if ``symtype`` is unknown or if the
        ``source`` lower bound is not strictly below its upper bound
    """

    symtype: Literal["periodic", "mirror"]
    offset: int = 0
    sign: Sign = 1
    source: Optional[Tuple[Real, Real]] = None
    targets: Optional[List[Real]] = None

    def __post_init__(self):
        """Validate the symmetry type and the ordering of the source bounds."""
        if self.symtype not in ("periodic", "mirror"):
            raise ValueError(
                f"symtype must be 'periodic' or 'mirror', got {self.symtype!r}"
            )
        if self.source and self.source[0] >= self.source[1]:
            raise ValueError(
                "source lower bound must be strictly below its upper bound, "
                f"got {self.source!r}"
            )

    def __call__(self, df: DataFrame):
        """Apply the symmetry on ``df`` and return the extended DataFrame.

        :param df: DataFrame to extend; its index is read positionally
            (first and last key) to check the source bounds
        :type df: DataFrame
        :raises OutOfBoundsOriginError: If ``source`` is set and does not lie
            within the first and last keys of ``df``
        :raises OverlappingSymmetryError: If ``targets`` is set and a pasted
            copy shares at least one key with the rows already present
        :return: ``df`` itself when it is empty, otherwise a new DataFrame
            holding the original rows plus the symmetric ones, sorted by index
        :rtype: DataFrame
        """
        if len(df) == 0:
            return df

        if self.source:
            lower, upper = self.source
            first_key, last_key = df.index[0], df.index[-1]
            # Same four comparisons as the historical chained form: the
            # source interval must lie inside [first_key, last_key].
            in_bounds = (
                lower < last_key >= upper and lower >= first_key < upper
            )
            if not in_bounds:
                raise OutOfBoundsOriginError(
                    f"source {self.source!r} is outside the DataFrame bounds "
                    f"[{first_key!r}, {last_key!r}]"
                )
            symdf = df.loc[lower:upper].copy()  # type: ignore
        else:
            symdf = df.copy()

        def transform(frame):
            # Vectorized equivalent of mapping ``sign * x + offset`` over
            # every cell; avoids ``DataFrame.applymap`` (deprecated since
            # pandas 2.1). Values are left untouched for the identity case.
            if self.sign == -1 or self.offset:
                return self.sign * frame + self.offset
            return frame

        # Keys of the source rows. The lambdas below read the bounds lazily
        # so that an empty source slice never triggers an index lookup.
        src_index = symdf.index

        if not self.targets:
            # NOTE(review): when ``source`` is set without ``targets`` the
            # copy is placed relative to the source keys (not the end of
            # ``df``) and no overlap check is performed in this branch.
            if self.symtype == "mirror":
                symdf.index = src_index.map(lambda x: 2 * src_index[-1] - x)
                # The last source row is the mirror axis: drop it so it is
                # not duplicated, then restore ascending key order.
                symdf = symdf.iloc[:-1].iloc[::-1]
            elif self.symtype == "periodic":
                symdf.index = src_index.map(
                    lambda x: 1 + x + src_index[-1] - src_index[0]
                )
            symdf = transform(symdf)
            df = pd.concat([df, symdf])
        else:
            for t in self.targets:
                tdf = symdf.copy()
                if self.symtype == "mirror":
                    # Reflected keys start at ``t``; reverse the rows so the
                    # keys are ascending again.
                    tdf.index = src_index.map(
                        lambda x: src_index[-1] - x + t
                    )
                    tdf = tdf.iloc[::-1]
                else:
                    tdf.index = src_index.map(lambda x: t + x - src_index[0])
                tdf = transform(tdf)
                # ``df`` accumulates earlier copies, so this also catches
                # two targets overlapping each other.
                if not df.index.intersection(tdf.index).empty:
                    raise OverlappingSymmetryError(
                        f"symmetry pasted at target {t!r} overlaps existing keys"
                    )
                df = pd.concat([df, tdf])

        return df.sort_index()
class OutOfBoundsOriginError(IndexError):
    """
    Raised when a symmetry is applied on a DataFrame with source bounds
    lying outside the DataFrame index bounds.
    """
class OverlappingSymmetryError(ValueError):
    """
    Raised when a symmetry is applied on a DataFrame and the pasted rows
    would share keys with rows that are already present.
    """