"""Utility functions
"""
import math
import os
from functools import singledispatch
from typing import Any
import numpy as np
import pandas as pd
__all__ = ["load_dataset"]
ANCHOR_DATETIME = np.datetime64("1970-01-01") # I remember the day well
array_like = Any
[docs]def load_dataset(name: str) -> pd.DataFrame:
"""Load dataset from shellplot library
Parameters
----------
name : str
Name of the dataset. Currently, available options are:
- `penguins`
Returns
-------
pd.DataFrame
Pandas dataframe of dataset
"""
module_path = os.path.dirname(__file__)
dataset_path = os.path.join(module_path, "datasets", f"{name}.csv")
return pd.read_csv(dataset_path)
def tolerance_round(x, tol=1e-3):
error = 1.0
decimals = 0
while error > tol:
if decimals == 0:
x_rounded = round(x)
else:
x_rounded = round(x, decimals)
fudge = 1e-9 # protect against zero div
error = (x - x_rounded) / (x + fudge)
decimals += 1
return x_rounded, decimals
def difference_round(val, round_func, max_difference):
for dec in range(10):
rounded = round_func(val, decimals=dec)
if abs(rounded - val) <= max_difference:
return rounded
def round_up(n, decimals=0):
return _round_to_decimals(n=n, decimals=decimals, round_func=math.ceil)
def round_down(n, decimals=0):
return _round_to_decimals(n=n, decimals=decimals, round_func=math.floor)
def _round_to_decimals(n, decimals, round_func):
if decimals == 0: # avoid float div for int rounded value
return round_func(n)
else:
multiplier = 10 ** decimals
return round_func(n * multiplier) / multiplier
def timedelta_round(x):
"""Given a numpy timedelta, find the largest time unit without changing value"""
units = ["Y", "M", "D", "h", "m", "s", "ms", "us", "ns"]
for unit in units:
x_rounded = x.astype(f"timedelta64[{unit}]")
if x_rounded == x: # TODO: apparently raises a # WARNING: ?
return unit
def remove_any_nan(x, y):
"""Given two np.ndarray, remove indeces where any is nan"""
is_any_nan = np.isnan(x) | np.isnan(y)
return x[~is_any_nan], y[~is_any_nan]
@singledispatch
def numpy_2d(x):
"""Reshape and transform various array-like inputs to 2d np arrays"""
@numpy_2d.register
def _(x: np.ndarray):
if len(x.shape) == 1:
return x[np.newaxis]
elif len(x.shape) == 2:
return x
else:
raise ValueError("Array dimensions need to be <= 2!")
@numpy_2d.register
def _(x: pd.DataFrame):
return x.to_numpy().transpose()
@numpy_2d.register(pd.Series)
@numpy_2d.register(pd.Index)
def _(x):
return x.to_numpy()[np.newaxis]
@numpy_2d.register
def _(x: list):
if isinstance(x[0], np.ndarray):
return numpy_1d(x)
elif isinstance(x[0], list):
return np.array([numpy_1d(x) for x in x])
else:
return np.array([numpy_1d((x))])
@singledispatch
def numpy_1d(x):
"""Reshape and transform various array-like inputs to 1d np arrays"""
@numpy_1d.register(np.ndarray)
def _(x):
return x
@numpy_1d.register(pd.Series)
@numpy_1d.register(pd.Index)
def _(x):
return x.to_numpy()
@numpy_1d.register(pd.DataFrame)
def _(x):
return x.to_numpy().squeeze()
@numpy_1d.register(list)
@numpy_1d.register(tuple)
def _(x):
return np.array(x)
@numpy_1d.register(str)
def _(x): # TODO: this should be any non-iterable
return np.array([x])
@singledispatch
def get_label(x):
"""Try to get names out of array-like inputs"""
pass
@get_label.register(pd.DataFrame)
def _(x):
return list(x)
@get_label.register(pd.Series)
def _(x):
return x.name
@singledispatch
def get_index(x):
"""Try to get index out of array-like inputs"""
@get_index.register(pd.Series)
@get_index.register(pd.DataFrame)
def _(x):
return np.array(x.index)
def is_datetime(x):
x = numpy_1d(x)
if x.dtype.kind in np.typecodes["Datetime"]:
return x.dtype
else:
return False
def to_numeric(x):
"""Convert np array to numeric values"""
x = numpy_1d(x)
if is_datetime(x):
return x.astype("datetime64[ns]") - ANCHOR_DATETIME
else:
return x
def to_datetime(x):
return x + ANCHOR_DATETIME