# ***************************************************************************
#  IIIIII NNN  NNN  Copyright (C) 2023 Innovative Networks, Inc.
#    II    NNN  N   All Rights Reserved. Any redistribution or reproduction
#    II    N NN N   of part or all of the content of this program in any form
#    II    N  NNN   without expressed written consent of the copyright holder
#  IIIIII NNN  NNN  is strictly prohibited.  Please contact admins@in-kc.com
#   Be Innovative.  for additional information.
# ***************************************************************************
#  inmon_quant.py
#  Author - Ian Perry <iperry@indigex.com>
#
#  Purpose:  Utility functionality for Python INmon scripts
#
#  Version History:
#       2024.01.29 - Initial Creation
# ***************************************************************************
# pylint: disable=missing-module-docstring
from __future__ import annotations
from functools import total_ordering
import re
import logging
import math

# Data-size prefixes in ascending order. A prefix's index is its power of the base used by
# parse_bytes/format_bytes (1024 by default), so 'M' is 1024**2 bytes; 'B' (index 0) is bytes.
BYTE_TABLE = ['B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y']

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

@total_ordering
class Quantity:
    """
    A numeric value with an optional unit of measurement.

    Data units (any unit ending in 'b' or 'B', e.g. 'kB', 'MiB', 'mb') are normalised to bytes
    on construction: the stored uom becomes 'B' and the value is scaled accordingly, so that
    quantities given in different data units can be compared and added.

    Raises:
        TypeError: If string, value or uom is of an unsupported type.
        ValueError: If string is combined with value/uom, if neither string nor value is given,
            if no number can be parsed, or if the uom is not supported.
    """

    # Units of measurement the constructor accepts (the icinga2 set). Built once per class
    # instead of once per construction.
    _UOMS = frozenset([
        'ns', 'us', 'ms', 's', 'm', 'h', 'd',
        '%',
        'B', 'kB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB',
        'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB',
        'b', 'kb', 'mb', 'gb', 'tb', 'pb', 'eb', 'zb', 'yb',
        'kib', 'mib', 'gib', 'tib', 'pib', 'eib', 'zib', 'yib',
        'nA', 'uA', 'mA', 'A', 'kA', 'MA', 'GA', 'TA', 'PA', 'EA', 'ZA', 'YA',
        'nO', 'uO', 'mO', 'O', 'kO', 'MO', 'GO', 'TO', 'PO', 'EO', 'ZO', 'YO',
        'nV', 'uV', 'mV', 'V', 'kV', 'MV', 'GV', 'TV', 'PV', 'EV', 'ZV', 'YV',
        'nW', 'uW', 'mW', 'W', 'kW', 'MW', 'GW', 'TW', 'PW', 'EW', 'ZW', 'YW',
        'nAs', 'uAs', 'mAs', 'As', 'kAs', 'MAs', 'GAs', 'TAs', 'PAs', 'EAs', 'ZAs', 'YAs',
        'nAm', 'uAm', 'mAm', 'Am', 'kAm', 'MAm', 'GAm', 'TAm', 'PAm', 'EAm', 'ZAm', 'YAm',
        'nAh', 'uAh', 'mAh', 'Ah', 'kAh', 'MAh', 'GAh', 'TAh', 'PAh', 'EAh', 'ZAh', 'YAh',
        'nWs', 'uWs', 'mWs', 'Ws', 'kWs', 'MWs', 'GWs', 'TWs', 'PWs', 'EWs', 'ZWs', 'YWs',
        'nWm', 'uWm', 'mWm', 'Wm', 'kWm', 'MWm', 'GWm', 'TWm', 'PWm', 'EWm', 'ZWm', 'YWm',
        'nWh', 'uWh', 'mWh', 'Wh', 'kWh', 'MWh', 'GWh', 'TWh', 'PWh', 'EWh', 'ZWh', 'YWh',
        'lm',
        'dBm',
        'ng', 'ug', 'mg', 'g', 'kg', 't',
        'C', 'F', 'K',
        'ml', 'l', 'hl',
        'c',
        'Hz', 'hz',
    ])

    @staticmethod
    def isdatatype(uom: str | None) -> bool:
        """
        Checks whether a unit of measurement denotes an amount of data.

        Args:
            uom (str|None): A unit of measurement.

        Returns:
            bool: True if uom is not None and ends in 'b' (bits) or 'B' (bytes).
        """
        return uom is not None and uom.endswith(('b', 'B'))

    @staticmethod
    def _to_bytes(value: int | float, uom: str) -> float:
        """
        Converts an amount expressed in a data unit to bytes.

        The first letter of a multi-letter unit selects the power of 1024 via BYTE_TABLE
        ('k'/'K' -> 1, 'M' -> 2, ...). Decimal ('kB') and binary ('KiB') prefixes are therefore
        both treated as powers of 1024, the same base parse_bytes/format_bytes default to.
        Units ending in a lowercase 'b' are bits and are divided by 8.
        """
        power = 0 if len(uom) == 1 else BYTE_TABLE.index(uom[0].upper())
        size = float(value) * 1024 ** power
        if uom.endswith('b'):
            size /= 8  # 8 bits per byte
        return size

    def __init__(
        self,
        string: str|None = None,
        value: int|float|str|None = None,
        uom: str|None = None
    ) -> None:
        """
        Builds a Quantity either from a string such as "5MB" / "-0.5 s" / "80%", or from a
        value plus an optional uom.

        Args:
            string (str|None): Number followed by an optional unit.
            value (int|float|str|None): Numeric value (strings are converted to int or float).
            uom (str|None): Unit of measurement for value.
        """
        if not isinstance(string, str) and string is not None:
            raise TypeError(f"string must be of type str|None, is type {type(string)}")
        if not isinstance(value, (int, float, str)) and value is not None:
            raise TypeError(f"value must be of type int|float|str|None, is type {type(value)}")
        if not isinstance(uom, str) and uom is not None:
            raise TypeError(f"uom must be of type str|None, is type {type(uom)}")

        if string is not None and (value is not None or uom is not None):
            raise ValueError("Only string or value and optional uom may be provided.")

        if string is not None:
            # Optional sign, a decimal number, then the unit letters; whitespace may separate
            # them. re.match anchors only the start, so text after the unit is ignored.
            match = re.match(
                r'\s*(?P<value>-?(?:\d+\.\d+|\.\d+|\d+))\s*(?P<uom>[a-zA-Z%]*)', string
            )
            if match is None:
                raise ValueError(f"Unable to find numeric value in input {string}")
            value = match.group('value')
            uom = match.group('uom') or None
        elif value is None:
            raise ValueError("Either string or value must be provided.")

        if isinstance(value, str):
            # Integral strings stay ints; anything else goes through float().
            value = int(value) if value.lstrip("-").isdigit() else float(value)

        if uom is not None and uom not in Quantity._UOMS:
            logger.debug("UOM: %s", uom)
            raise ValueError("Unit of measurement must be a uom supported by icinga2")

        if Quantity.isdatatype(uom):
            self.value = Quantity._to_bytes(value, uom)
            self.uom = 'B'
        else:
            self.value = value
            self.uom = uom

    def _key(self) -> tuple[int | float, str | None]:
        return (self.value, self.uom)

    def _check_compatible(self, other: object) -> None:
        """Raises TypeError unless other is a Quantity with the same unit of measurement."""
        if not isinstance(other, Quantity):
            raise TypeError(f"Quantity is not comparable with {type(other)}")
        if self.uom != other.uom:
            raise TypeError("Quantities with different units of measurement are not comparable.")

    def __eq__(self, other: object) -> bool:
        # NotImplemented (instead of raising) lets `q == None` or comparisons inside tuples
        # that mix Quantity and None evaluate to False, as Python's data model expects.
        if not isinstance(other, Quantity):
            return NotImplemented
        return self._key() == other._key()

    def __lt__(self, other: Quantity) -> bool:
        self._check_compatible(other)
        return self.value < other.value

    def __neg__(self) -> Quantity:
        return Quantity(value=-self.value, uom=self.uom)

    def __add__(self, other: Quantity) -> Quantity:
        self._check_compatible(other)
        return Quantity(value=self.value + other.value, uom=self.uom)

    def __sub__(self, other: Quantity) -> Quantity:
        # Same validation as __add__, so mismatched units can't silently take self's unit.
        self._check_compatible(other)
        return Quantity(value=self.value - other.value, uom=self.uom)

    def __str__(self) -> str:
        uom = "" if self.uom is None else self.uom
        return f"{self.value}{uom}"

    def format(self) -> str:
        """
        Formats the quantity for display.

        Returns:
            str: For data, the byte count scaled to its largest unit (e.g. "1.5KB");
                otherwise the same text as str().
        """
        if Quantity.isdatatype(self.uom):
            val, prefix = format_bytes(self.value)
            # format_bytes reports sub-unit amounts as 'B'; don't append a second 'B'.
            unit = prefix if prefix == 'B' else f"{prefix}B"
            return f"{val}{unit}"
        return str(self)

    def __hash__(self) -> int:
        return hash(self._key())

    def __repr__(self) -> str:
        return f"Quantity({self.value}, {self.uom})"

class Threshold:
    """
    Warning and critical bounds used to classify a checked Quantity.

    In normal mode a value violates a level when it is at or above that level's *_high bound,
    or at or below its optional *_low bound. With reverse=True the comparisons are mirrored:
    a violation is at or below *_high, or at or above the optional *_low, so the thresholds
    are approached from above (lower values are worse).
    """
    def __init__( # pylint: disable=too-many-arguments
        self,
        crit_high: Quantity,
        warn_high: Quantity,
        warn_low: Quantity | None = None,
        crit_low: Quantity | None = None,
        reverse: bool = False
    ) -> None:
        if not isinstance(crit_high, Quantity):
            raise TypeError(f"crit_high must be of type Quantity, is type {type(crit_high)}")
        if not isinstance(warn_high, Quantity):
            raise TypeError(f"warn_high must be of type Quantity, is type {type(warn_high)}")
        if warn_low is not None and not isinstance(warn_low, Quantity):
            raise TypeError(f"warn_low must be of type Quantity, is type {type(warn_low)}")
        if crit_low is not None and not isinstance(crit_low, Quantity):
            raise TypeError(f"crit_low must be of type Quantity, is type {type(crit_low)}")
        if not isinstance(reverse, bool):
            raise TypeError(f"reverse must be of type bool, is of type {type(reverse)}")
        self.crit_high = crit_high
        self.warn_high = warn_high
        self.warn_low = warn_low
        self.crit_low = crit_low
        self.reverse = reverse

    def __str__(
        self
    ) -> str:
        # Format each level as "high" or "low:high" (str.join can't take Quantity objects).
        cr = self.crit_high if self.crit_low is None else f"{self.crit_low}:{self.crit_high}"
        wn = self.warn_high if self.warn_low is None else f"{self.warn_low}:{self.warn_high}"
        return f"Critical {cr}, Warning {wn}{', reversed' if self.reverse else ''}"

    def __repr__(
        self
    ) -> str:
        return f"Threshold(crit_high={self.crit_high},warn_high={self.warn_high},warn_low={self.warn_low}" \
        + f",crit_low={self.crit_low},reversed={self.reverse})"

    def _key(self) -> tuple[Quantity, Quantity, Quantity | None, Quantity | None, bool]:
        return (self.crit_high, self.warn_high, self.warn_low, self.crit_low, self.reverse)

    def __hash__(self) -> int:
        return hash(self._key())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Threshold):
            return NotImplemented
        # Compare field by field, treating None bounds by identity so an optional bound that
        # is None on one side and set on the other never reaches Quantity equality.
        return all(
            (mine is theirs) if (mine is None or theirs is None) else (mine == theirs)
            for mine, theirs in zip(self._key(), other._key())
        )

    def _violates(self, check_val: Quantity, high: Quantity, low: Quantity | None) -> bool:
        """
        True if check_val is at/past `high`, or at/past the optional `low` on the other side.

        Past means "above" for `high` and "below" for `low` normally, and the mirror image of
        that when reversed.
        """
        if self.reverse:
            return check_val <= high or (low is not None and check_val >= low)
        return check_val >= high or (low is not None and check_val <= low)

    def critical(self, check_val: Quantity) -> bool:
        """
        Returns True if the value violates critical thresholds.

        Args:
            check_val (Quantity): Value to check against thresholds.

        Returns:
            bool: True if value violates critical thresholds, otherwise False.
        """
        return self._violates(check_val, self.crit_high, self.crit_low)

    # The reasoning behind why warning excludes a value if it's critical is that we believe that
    # a critical value precludes a warning value.
    def warning(self, check_val: Quantity) -> bool:
        """
        Returns True if the value violates warning thresholds BUT NOT critical thresholds.

        Args:
            check_val (Quantity): Value to check against thresholds.

        Returns:
            bool: True if value violates warning thresholds BUT NOT critical thresholds, otherwise False.
        """
        if self.critical(check_val):
            return False
        return self._violates(check_val, self.warn_high, self.warn_low)

    def ok(self, check_val: Quantity) -> bool:
        """
        Checks if a value is strictly inside the warning bounds.

        Args:
            check_val (Quantity): Value to check against thresholds.

        Returns:
            bool: Value is within operating thresholds.
        """
        if self.reverse:
            return check_val > self.warn_high \
                and (self.warn_low is None or check_val < self.warn_low)
        return check_val < self.warn_high \
            and (self.warn_low is None or check_val > self.warn_low)

    def unknown(self, check_val: Quantity) -> bool:
        """
        Checks if a value does not meet other check criteria

        Args:
            check_val (Quantity): Value to check against

        Returns:
            bool: returns True if not CRITICAL, WARNING, or OK, False otherwise.
        """
        return not (self.critical(check_val) or self.warning(check_val) or self.ok(check_val))

def parse_thresh_val(
    val: str
) -> tuple[str | None, str]:
    """
    Splits a threshold string into its low and high parts.

    Args:
        val (str): A threshold in the form "high" or "low:high". An empty low part
            (":high") means there is no low threshold.

    Raises:
        ValueError: If val contains more than one ':'.

    Returns:
        (str | None, str): (low, high); low is None when no low threshold was given.
    """
    low, sep, high = val.partition(":")
    if not sep:
        return (None, val)
    if ":" in high:
        # Previously the extra parts were silently dropped.
        raise ValueError(f"Threshold must be in the form high or low:high, is {val}")
    return (low or None, high)

def parse_thresholds(
    params: dict=None, #pylint: disable=redefined-outer-name
    warn: str=None,
    crit: str=None,
    default: tuple[str, str]=None,
    reverse: bool=False,
    ) -> Threshold:
    """
    Builds a Threshold from warning/critical strings, requiring both to be set somewhere.

    Each threshold string is "high" or "low:high". Sources are layered so that later ones
    override earlier ones: `default`, then `warn`/`crit`, then params['warn']/params['crit'].

    Args:
        params (dict): Parameters that may contain 'warn' and/or 'crit'.
        warn (str): String representation of warning in ## or ##:## format.
        crit (str): String representation of critical in ## or ##:## format.
        default ((str, str)): Fallback (warn, crit) strings in ## or ##:## format.
        reverse (bool): Build a reversed Threshold (lower values are worse); the critical
            value must then be below the warning value.

    Raises:
        ValueError: Warning must be provided.
        ValueError: Critical must be provided.
        ValueError: If the thresholds use different units of measurement.
        ValueError: If critical does not lie beyond warning (above it, or below it if reversed).

    Returns:
        Threshold:
            A Threshold object representing the thresholds
    """
    thresholds = [None, None, None, None] # Warn low, warn high, crit low, crit high.

    # Lowest precedence first, so each later source overwrites the earlier ones.
    if default is not None:
        if default[0] is not None:
            thresholds[0:2] = parse_thresh_val(str(default[0]))
        if default[1] is not None:
            thresholds[2:4] = parse_thresh_val(str(default[1]))

    if warn is not None:
        thresholds[0:2] = parse_thresh_val(str(warn))
    if crit is not None:
        thresholds[2:4] = parse_thresh_val(str(crit))

    if params is not None:
        if 'warn' in params:
            thresholds[0:2] = parse_thresh_val(str(params['warn']))
        if 'crit' in params:
            thresholds[2:4] = parse_thresh_val(str(params['crit']))

    logger.debug("thresholds before parsing to quantity: %s", thresholds)
    thresholds = [Quantity(string=x) if x is not None else None for x in thresholds]
    logger.debug("after parsing to quantity: %s", thresholds)

    # The high values are mandatory; check them before they are used as the unit reference.
    if thresholds[1] is None:
        raise ValueError("Warning must be specified in default, warn flag, or parameters.")
    if thresholds[3] is None:
        raise ValueError("Critical must be specified in default, crit flag, or parameters.")

    reference_uom = thresholds[1].uom
    if any(q is not None and q.uom != reference_uom for q in thresholds):
        raise ValueError(f"Unit of measurement must be the same for all values. {thresholds}")

    # Critical must lie beyond warning in the "worse" direction: above it normally, below it
    # when reversed (Threshold mirrors its comparisons), otherwise warning is unreachable.
    if reverse:
        if thresholds[3] >= thresholds[1]:
            raise ValueError("Critical value must be lower than warning value when reversed.")
    elif thresholds[1] >= thresholds[3]:
        raise ValueError("Critical value must be higher than warning value.")

    return Threshold(
        thresholds[3],
        thresholds[1],
        thresholds[0],
        thresholds[2],
        reverse
    )

def parse_stat(check_quant: Quantity, threshold: Threshold) -> int:
    """
    Maps a checked value to a numeric Icinga status using the given thresholds.

    The predicates are consulted in severity order (critical, warning, ok) and the first
    one that holds decides the status; a value matching none of them is UNKNOWN.

    Args:
        check_quant (Quantity): The value to classify.
        threshold (Threshold): Provides critical(), warning() and ok() predicates.

    Returns:
        int: 2 CRITICAL, 1 WARNING, 0 OK, 3 UNKNOWN.
    """
    status = 3
    if threshold.critical(check_quant):
        status = 2
    elif threshold.warning(check_quant):
        status = 1
    elif threshold.ok(check_quant):
        status = 0
    return status

def parse_bytes(
    value: str | float,
    to: str,
    numeric: bool = False,
    byte_size: int = 1024
) -> str | float:
    """
    Converts an amount of data from one unit to another.

    `value` is either a number of bytes or a string holding a number followed by an optional
    unit. A unit is a prefix letter from BYTE_TABLE (any case), optionally followed by 'B',
    'b', 'iB' or 'ib' (e.g. 'M', 'MB', 'MiB', 'Mb'), or just 'B'/'b'. A unit whose last
    letter is lowercase is read as bits, otherwise as bytes ('k', 'kb', 'Mib' are bits;
    'K', 'kB', 'MiB' are bytes). A missing unit means bytes. Adjacent prefixes differ by a
    factor of byte_size.

    Args:
        value (str|float|int): The amount to convert.
        to (str): Target unit, a single letter from BYTE_TABLE; uppercase for bytes,
            lowercase for bits.
        numeric (bool, optional): Return the bare number instead of a string. Defaults to False.
        byte_size (int, optional): Factor between adjacent prefixes. Defaults to 1024.

    Raises:
        TypeError: If an argument has the wrong type.
        ValueError: If `to` or the unit in `value` is not a recognised data unit.
        ValueError: If unable to find a number at the start of `value`.

    Returns:
        str|float: f"{amount}{to}" if numeric is false, else the raw amount.
    """
    if not isinstance(value, (str, float, int)):
        raise TypeError(f"value must be of type str|float|int, is of type {type(value)}")
    if not isinstance(to, str):
        raise TypeError(f"to must be of type str, is of type {type(to)}")
    if not isinstance(numeric, bool):
        raise TypeError(f"numeric must be of type bool, is of type {type(numeric)}")
    if not isinstance(byte_size, int):
        raise TypeError(f"byte_size must be of type int, is of type {type(byte_size)}")

    if to.upper() not in BYTE_TABLE:
        st = ', '.join(BYTE_TABLE)
        raise ValueError(f"to must be one of {st}, is {to}")

    if isinstance(value, str):
        match = re.match(
            r'\s*(?P<value>-?(?:\d+\.\d+|\.\d+|\d+))\s*(?P<uom>[A-Za-z]*)', value
        )
        logger.debug("parse_bytes(%r) -> %s", value, match)
        if match is None:
            raise ValueError(f"Unable to find valid byte-string in {value}")
        val = float(match.group('value'))
        start_unit = match.group('uom') or 'B'
    else:
        # Numbers are already bytes. Not routing them through str() avoids misreading
        # scientific notation such as 1e-05 as "1" with an exabit unit.
        val = float(value)
        start_unit = 'B'

    prefix, suffix = start_unit[0].upper(), start_unit[1:]
    if prefix not in BYTE_TABLE or suffix not in ('', 'B', 'b', 'iB', 'ib') \
            or (prefix == 'B' and suffix != ''):
        raise ValueError(f"Unknown byte unit {start_unit} in {value}")

    # byte_size raised to the difference in prefix powers rescales between the two prefixes.
    mult = byte_size ** (BYTE_TABLE.index(prefix) - BYTE_TABLE.index(to.upper()))

    # 8 bits per byte: bits -> bytes divides by 8, bytes -> bits multiplies by 8.
    from_bits = start_unit[-1].islower()
    to_bits = to[-1].islower()
    if from_bits and not to_bits:
        mult /= 8
    elif to_bits and not from_bits:
        mult *= 8

    res = val * mult
    if numeric:
        return res
    return f"{res}{to}"

def format_bytes(bytes_in: float, base: int = 1024) -> tuple[float, str]:
    """
    Expresses a number of bytes in the largest BYTE_TABLE unit that keeps the magnitude >= 1.

    Args:
        bytes_in (int|float): A number of bytes; may be zero, fractional or negative.
        base (int, optional): Factor between adjacent units. (Defaults to 1024)

    Raises:
        TypeError: If bytes_in is not int|float
        TypeError: If base is not int
        ValueError: If base is less than 2

    Returns:
        (float, str): The scaled amount rounded to 3 places and its BYTE_TABLE unit letter.
            Amounts below `base` stay in 'B'; amounts beyond the last unit stay in 'Y'.
    """
    if not isinstance(bytes_in, (int, float)):
        raise TypeError(f"bytes_in must be of type int|float, is of type {type(bytes_in)}")
    if not isinstance(base, int):
        raise TypeError(f"base must be of type int, is of type {type(base)}")
    if base < 2:
        raise ValueError(f"base must be at least 2, is {base}")

    # Step up one unit at a time instead of using math.log, which is undefined for 0 and
    # negatives, yields a negative (wrapping) index for fractions, and can land just below
    # an exact power (math.log(1000, 10) == 2.9999999999999996) and pick the smaller unit.
    magnitude = abs(bytes_in)
    power = 0
    while magnitude >= base and power < len(BYTE_TABLE) - 1:
        magnitude /= base
        power += 1
    return (round(bytes_in / base**power, 3), BYTE_TABLE[power])
