# ***************************************************************************
#  IIIIII NNN  NNN  Copyright (C) 2023 Innovative Networks, Inc.
#    II    NNN  N   All Rights Reserved. Any redistribution or reproduction
#    II    N NN N   of part or all of the content of this program in any form
#    II    N  NNN   without expressed written consent of the copyright holder
#  IIIIII NNN  NNN  is strictly prohibited.  Please contact admins@in-kc.com
#   Be Innovative.  for additional information.
# ***************************************************************************
#  inmon_utils.py
#  Author - Ian Perry <iperry@indigex.com>
#
#  Purpose:  Utility functionality for Python INmon scripts
#
#  Version History:
#       2023.04.05 - Modified to add INmon passive check API functionality.
# ***************************************************************************
# pylint: disable=missing-module-docstring
from __future__ import annotations
import sys
import importlib
import time
import re
import socket
import logging
import traceback
import subprocess
from types import TracebackType
import inmon_quant as quant

logger = logging.getLogger(__name__)

# requests is a third-party dependency; install it on first use so scripts can run on an
# interpreter that does not have it yet.
try:
    import requests
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "requests"])
    # The import system caches directory listings; refresh them so the freshly installed
    # package is found, then bind the module-level name that the rest of this file uses.
    importlib.invalidate_caches()
    import requests

def raise_exception_as_3(
    exc_type: type[BaseException], exc_value: BaseException, tb: TracebackType
) -> None:
    """
    Print an uncaught exception's traceback, then exit with status 3 (UNKNOWN in INmon).

    The signature matches ``sys.excepthook`` so this can be installed as the process-wide
    hook for uncaught exceptions.

    Args:
        exc_type: The exception class.
        exc_value: The exception instance.
        tb: The traceback object associated with the exception.

    Raises:
        SystemExit: Always, with exit code 3.
    """
    traceback.print_exception(exc_type, exc_value, tb)
    sys.exit(3)

# Install the exit-status-3 hook only for non-interactive runs. sys.ps1 (the interactive
# prompt) exists only when the interpreter is in interactive mode, so REPL sessions keep
# Python's default excepthook while scripts report uncaught exceptions as UNKNOWN (3).
if not hasattr(sys, "ps1"):
    sys.excepthook = raise_exception_as_3

def ensure_module(module, pip_name=None):
    """
    Import a module, installing it with pip first if it cannot be imported.

    Args:
        module (str): Importable name of the module.
        pip_name (str, optional): The pip name for the module, if different.
            Defaults to ``module``.

    Returns:
        module: The imported module object.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` fails.
        ImportError: If the module still cannot be imported after installing.
    """
    try:
        return importlib.import_module(module)
    except ImportError:
        # Only a missing/unimportable module warrants an install; other errors raised
        # while importing a module that exists are real bugs and propagate unchanged.
        pass

    subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name or module])
    # Path finders cache directory contents; without this a package installed after
    # interpreter start may not be visible to the import system.
    importlib.invalidate_caches()
    return importlib.import_module(module)

def _validate_str_pair(value) -> None:
    """
    Raise ValueError unless ``value`` is a tuple of exactly two non-empty strings.

    Only the element types are reported in the message, never the values, because the
    pair may contain a password.
    """
    if not isinstance(value, tuple):
        raise ValueError(f"Expected tuple, received {type(value)}")
    if len(value) != 2 or not all(isinstance(v, str) and v != "" for v in value):
        element_types = ", ".join(str(type(v)) for v in value)
        raise ValueError(
            f"Expected tuple of two non-empty str, received ({element_types})"
        )


def create_request_session(
    auth = None,
    cert = None
) -> "requests.Session":
    """
    Return a requests session with INmon headers and exactly one kind of credential.

    Args:
        auth (tuple[str, str], optional): ``(username, password)`` for HTTP basic auth.
        cert (str | tuple[str, str], optional): Client certificate, either a path to a
            combined cert/key file or a ``(cert_path, key_path)`` pair, as accepted by
            ``requests.Session.cert``.

    Exactly one of ``auth`` or ``cert`` must be given.

    Returns:
        requests.Session: Session with ``Accept: application/json`` and
        ``Accept-Encoding: gzip`` headers and TLS verification disabled.

    Raises:
        ValueError: If neither or both credentials are given, or the given one is malformed.
    """
    # Exactly one of the two must be provided.
    if (auth is None) == (cert is None):
        raise ValueError("Expected either auth or cert args")

    # Validate before creating the session so a bad argument doesn't leave it unclosed.
    if auth is not None:
        _validate_str_pair(auth)
    elif isinstance(cert, str):
        if cert == "":
            raise ValueError("Expected non-empty certificate path")
    else:
        _validate_str_pair(cert)

    session = requests.Session()
    if auth is not None:
        session.auth = auth
    else:
        session.cert = cert

    session.headers.update({"Accept-Encoding": "gzip", "Accept": "application/json"})
    # NOTE(review): TLS verification is disabled, presumably because INmon endpoints use
    # self-signed certificates; consider pointing ``verify`` at the CA bundle instead.
    session.verify = False
    return session

# One DNS label: 1-63 ASCII letters, digits or hyphens, not starting or ending with '-'.
# re.ASCII stops IGNORECASE from letting non-ASCII look-alikes (e.g. KELVIN SIGN) match [a-z].
_HOSTNAME_LABEL_RE = re.compile(r"(?!-)[a-z0-9-]{1,63}(?<!-)", re.IGNORECASE | re.ASCII)


def is_valid_hostname(hostname: str) -> bool:
    """
    Verifies that a string is a valid hostname.

    A single trailing dot (fully-qualified form) is allowed. The name must be at most 253
    characters, every dot-separated label must be 1-63 ASCII letters/digits/hyphens not
    starting or ending with a hyphen, and the last label must not be all digits (so a
    dotted IPv4 address returns False). An empty string returns False.
    """
    if not hostname:
        return False
    if hostname.endswith("."):
        # strip exactly one dot from the right
        hostname = hostname[:-1]
    if len(hostname) > 253:
        return False

    labels = hostname.split(".")

    # The TLD must not be all-numeric. fullmatch (rather than match with "$") so that a
    # trailing newline is not silently accepted.
    if re.fullmatch(r"[0-9]+", labels[-1]):
        return False

    return all(_HOSTNAME_LABEL_RE.fullmatch(label) for label in labels)

def is_valid_ip(ip: str) -> bool:
    """
    Verifies that a string is a valid dotted-quad IPv4 address, e.g. "192.168.0.1".

    Stricter than ``socket.inet_aton``, which also accepts shorthand forms such as
    "127.1" or "10". Non-string input returns False.
    """
    import ipaddress  # local import keeps this check self-contained

    # IPv4Address would also accept an int, so insist on a string explicitly.
    if not isinstance(ip, str):
        return False
    try:
        ipaddress.IPv4Address(ip)
    except ValueError:  # AddressValueError is a ValueError subclass
        return False
    return True

def dict_merge(dct, merge_dct):
    """Deep-merge ``merge_dct`` into ``dct`` in place.

    Behaves like :meth:`dict.update`, except that when both sides hold a dict under
    the same key, those two dicts are merged recursively (to any depth) instead of the
    right-hand dict replacing the left. Every other value from ``merge_dct`` overwrites
    or adds the key in ``dct``.
    :param dct: dict that receives the merged entries (mutated)
    :param merge_dct: dict whose entries are merged into ``dct``
    :return: None
    """
    for key, incoming in merge_dct.items():
        both_dicts = key in dct and isinstance(dct[key], dict) and isinstance(incoming, dict)
        if not both_dicts:
            dct[key] = incoming
            continue
        dict_merge(dct[key], incoming)

def download_file(url, out_file, timeout=60):
    """Downloads a file

    The response is streamed to disk in chunks rather than buffered in memory. An HTTP
    error status raises before ``out_file`` is opened, so an error page is never written
    over the destination.

    Args:
        url (string): URL of file to download
        out_file (string): Location to write file
        timeout (float | tuple | None, optional): ``requests`` timeout in seconds
            (connect/read). Defaults to 60; None waits indefinitely.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    with requests.get(url, allow_redirects=True, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(out_file, "wb") as f:
            for chunk in response.iter_content(chunk_size=64 * 1024):
                f.write(chunk)

def version_check(ver: str) -> bool:
    """
    Returns whether the running Python version meets or exceeds ``ver``.

    Versions are compared as tuples over the components given, so "3.8" is satisfied
    by 3.8.x and 3.12 alike, and "2.7" is satisfied by any 3.x.
    :param ver: version in format major, major.minor or major.minor.micro; any
        components beyond the third are ignored
    :return: True if ``sys.version_info`` >= ``ver``, else False
    :raises ValueError: if a component is not an integer
    """
    wanted = tuple(int(part) for part in ver.split(".")[:3])
    return tuple(sys.version_info[:len(wanted)]) >= wanted

def csv(arg):
    """Split a comma-separated argument string into trimmed items.

    Args:
        arg (string): text such as "a, b ,c"

    Returns:
        list: one string per comma-delimited piece with surrounding whitespace
        removed, e.g. ["a", "b", "c"]; empty pieces are kept as "".
    """
    pieces = arg.split(',')
    return list(map(str.strip, pieces))

def handle_exit(check: CheckResult | MultiActiveCheck) -> None:
    """
    Write a check's rendered result to stdout and terminate with its status code.

    Args:
        check (CheckResult | MultiActiveCheck): The check to report; its ``str()`` is
            printed and its ``return_val`` becomes the process exit code.
    """
    rendered = str(check)
    print(rendered)
    sys.exit(check.return_val)


class bidict(dict): # pylint: disable=invalid-name
    """
    A bidirectional dictionary implementation.

    Behaves like a ``dict`` mapping keys to values, and additionally maintains
    ``inverse``: a plain dict mapping each value to the list of keys that currently map
    to it (a list, because several keys may share one value). Values must be hashable.
    Inverse entries whose key list becomes empty are removed.

    Attributes:
    inverse (dict): value -> list of keys with that value, in insertion order.
    """

    # Sentinel so pop() can tell "no default given" apart from a default of None.
    _NO_DEFAULT = object()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inverse = {}
        for key, value in self.items():
            self.inverse.setdefault(value, []).append(key)

    def _forget(self, key, value):
        """Remove ``key`` from ``inverse[value]``, dropping the entry once it is empty."""
        keys = self.inverse[value]
        keys.remove(key)
        if not keys:
            del self.inverse[value]

    def __setitem__(self, key, value):
        if key in self:
            self._forget(key, self[key])
        super().__setitem__(key, value)
        self.inverse.setdefault(value, []).append(key)

    def __delitem__(self, key):
        value = self[key]  # raises KeyError for a missing key, like dict
        self._forget(key, value)
        super().__delitem__(key)

    # dict's own bulk/removal methods write the underlying storage directly and would
    # bypass __setitem__/__delitem__, so route them through the methods above.
    def update(self, *args, **kwargs):  # pylint: disable=arguments-differ
        """Like ``dict.update``, keeping ``inverse`` in sync."""
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def __ior__(self, other):
        self.update(other)
        return self

    def setdefault(self, key, default=None):
        """Like ``dict.setdefault``, keeping ``inverse`` in sync."""
        if key not in self:
            self[key] = default
        return self[key]

    def pop(self, key, default=_NO_DEFAULT):  # pylint: disable=arguments-differ
        """Like ``dict.pop``, keeping ``inverse`` in sync."""
        if key in self:
            value = self[key]
            del self[key]
            return value
        if default is bidict._NO_DEFAULT:
            raise KeyError(key)
        return default

    def popitem(self):
        """Like ``dict.popitem`` (last inserted pair), keeping ``inverse`` in sync."""
        key, value = super().popitem()
        self._forget(key, value)
        return key, value

    def clear(self):
        """Like ``dict.clear``, also emptying ``inverse``."""
        super().clear()
        self.inverse.clear()

class PerfData:
    """
    A representation of icinga performance data.

    ``str()`` renders ``'label'=measure;warn;crit;min;max`` with trailing empty fields
    trimmed. Equality and hashing both use the label case-insensitively, so a set of
    PerfData keeps at most one entry per label regardless of case.

    Attributes
    ----------
    label : str
        The item label, must be unique in a CheckData object.

    measure : Quantity
        The quantity being measured.

    thresholds : Threshold, Optional
        The thresholds that the value was checked against.

    minimum : float, Optional
        The minimum possible value that the data can return ("" when not given).

    maximum : float, Optional
        The maximum possible value that the data can return ("" when not given).
    """

    def __init__(
        self,
        label: str,
        measure: quant.Quantity,
        thresholds: quant.Threshold = None,
        minimum: float = None, # min and max are reserved
        maximum: float = None
    ) -> None:
        if not isinstance(label, str):
            raise TypeError(f"label must be of type str, is type {type(label)}")

        if not isinstance(measure, quant.Quantity):
            raise TypeError(f"measure must be of type Quantity, is type {type(measure)}")

        if thresholds is not None and not isinstance(thresholds, quant.Threshold):
            raise TypeError(f"thresholds must be of type Threshold, is type {type(thresholds)}")

        if minimum is not None and not isinstance(minimum, (int, float)):
            raise TypeError(f"minimum must be of type int or float, is type {type(minimum)}")

        if maximum is not None and not isinstance(maximum, (int, float)):
            raise TypeError(f"maximum must be of type int or float, is type {type(maximum)}")

        if minimum is not None and maximum is not None and minimum > maximum:
            raise ValueError("maximum must be >= minimum")

        self.label = label
        self.measure = measure
        self.thresholds = thresholds
        # "" (not None) so the value renders as an empty perfdata field in __str__.
        self.minimum = minimum if minimum is not None else ""
        self.maximum = maximum if maximum is not None else ""

    def __str__(self) -> str:
        t = f"'{self.label}'={self.measure};" \
            + f"{self.thresholds.warn_high.value if self.thresholds is not None else ''}"\
            + f";{self.thresholds.crit_high.value if self.thresholds is not None else ''};" \
            + f"{self.minimum};{self.maximum}"
        t = t.strip()
        # Drop trailing empty fields, e.g. "'x'=5;;;;" -> "'x'=5".
        t = t.rstrip(';')
        return t

    def __eq__(self, other) -> bool:
        if not isinstance(other, PerfData):
            return NotImplemented
        return self.label.lower() == other.label.lower()

    def __hash__(self) -> int:
        # Must agree with __eq__, which ignores case; otherwise equal objects could land
        # in different hash buckets and a set would keep both.
        return hash(self.label.lower())

    def p(self, val: str|None) -> str:
        """
        Prints the word None if the value is None.
        """
        return "None" if val is None else val

    def __repr__(self) -> str:
        # thresholds is optional, so don't dereference it when absent.
        warn = None if self.thresholds is None else self.thresholds.warn_high
        crit = None if self.thresholds is None else self.thresholds.crit_high
        return f"""
    PerfData
        Label '{self.label}'
        Value {self.measure.value} UOM {self.measure.uom}
        Warning {self.p(warn)}
        Critical {self.p(crit)}
        Minimum {"None" if self.minimum == "" else self.minimum}
        Maximum {"None" if self.maximum == "" else self.maximum}
    """

class CheckResult:
    """
    A representation of a check result

    Attributes:

    return_val : int
        A return value of 0-3 that represents a status in icinga.

    return_text : str
        The output text of the check, seen up front and should have the most important data.

    cont_output : str
        This is continued output that is not immediately shown but may contain helpful data.

    perf_data : Set[PerfData]
        A set of performance data. This must be added via the add_perfdata() method.
    """
    # Status name <-> exit code; ``inverse[code][0]`` is the status name for a code.
    return_values = bidict({'OK': 0, 'WARNING': 1, 'CRITICAL': 2, 'UNKNOWN': 3})

    def __init__(
        self,
        return_val: int,
        return_text: str,
        cont_output: str = None
    ) -> None:
        if not isinstance(return_val, int):
            raise TypeError(f"return_val must be of type int, is of type {type(return_val)}")
        if not 0 <= return_val <= 3:
            raise ValueError("return_val must be between 0 and 3 inclusive")

        if not isinstance(return_text, str):
            # NOTE: ValueError (rather than TypeError) is kept for callers that catch it.
            raise ValueError(f"return_text must be of type str, is of type {type(return_text)}")

        self.return_val = return_val
        self.return_text = return_text
        self.cont_output = cont_output if cont_output is not None else ""
        self.perf_data = set()

    def process(self):
        """
        Processes check for INmon purposes: prints the result and exits with return_val.
        """
        print(self)
        sys.exit(self.return_val)

    def add_perfdata(self, perfdata: PerfData) -> None:
        """
        Add performance data object to check

        Args:
            perfdata (PerfData): Performance data you want to add
        """
        self.perf_data.add(perfdata)

    def __str__(self) -> str:
        """
        Render ``[STATUS] - text``, then `` | perfdata`` and the continued output on the
        following line(s) when present. Does not modify the object, so repeated calls
        (including those made by __lt__) return the same text.
        """
        status = self.return_values.inverse[self.return_val][0]
        continued = f"\n{self.cont_output}" if self.cont_output else ""
        perf_string = " ".join(str(x) for x in self.perf_data).strip()

        text = f"[{status}] - {self.return_text}"
        if perf_string or continued:
            text = f"{text} |"
        if perf_string:
            text = f"{text} {perf_string}"
        return f"{text}{continued}"

    def __repr__(self) -> str:
        # Can't have in f-string prior to 3.12 due to escape character
        perf_string = '\n'.join([str(x) for x in self.perf_data])
        status = self.return_values.inverse[self.return_val][0]
        return f"""
    Return value {self.return_val} ({status})
    Output:
        {self.return_text}
    Continued Output:
        {self.cont_output}
    Performance Data:
        {perf_string}
    """

    def _comparison_key(self) -> tuple:
        """Everything that affects str(self), with perf data in set-order-independent form."""
        return (
            self.return_val,
            self.return_text,
            self.cont_output,
            tuple(sorted(str(p) for p in self.perf_data)),
        )

    def __eq__(self, other) -> bool:
        if not isinstance(other, CheckResult):
            return NotImplemented
        # Results that render differently stay distinct in sets (e.g. the same text with
        # different perf data), matching how MultiActiveCheck collects them.
        return self._comparison_key() == other._comparison_key()

    def __hash__(self) -> int:
        # Uses only fields that equal objects always share, and none that add_perfdata()
        # changes, so the hash stays stable while the object sits in a set.
        return hash((self.return_val, self.return_text))

    def __lt__(self, other) -> bool:
        return str(self) < str(other)

class MultiActiveCheck:
    """
    A type representing multiple checks at once for easy returns.

    Appended CheckResult objects are bucketed by status name in ``checks``. The overall
    ``return_val`` is the maximum ``return_val`` appended so far (0 while empty).
    ``str()`` renders one combined result:
    - a summary line of member texts (WARNING, CRITICAL and UNKNOWN texts are prepended,
      OK texts appended);
    - all member performance data;
    - ``[STATUS] - text`` overflow lines for texts that did not go on the summary line.
    """
    return_values = bidict({'OK': 0, 'WARNING': 1, 'CRITICAL': 2, 'UNKNOWN': 3})

    # Character budget for deciding whether a member's text still goes on the summary line.
    summary_limit = 90

    def __init__(
        self
    ):
        # Each one in a set helps for output formatting.
        self.checks = {
            "OK": set(),
            "WARNING": set(),
            "CRITICAL": set(),
            "UNKNOWN": set()
        }
        self.return_val = 0
        self.not_OK = False # pylint: disable=invalid-name  # not read within this module
        # Cursor used only by the legacy __next__ method.
        self.index = 0

    def process(self):
        """
        Processes check for INmon purposes: prints the result and exits with return_val.
        """
        print(self)
        sys.exit(self.return_val)

    def append(self, check: CheckResult | MultiActiveCheck) -> None:
        """
        Append a check onto the set of checks.

        A MultiActiveCheck is flattened: each of its member checks is appended
        individually. Objects of any other type are ignored.

        Args:
            check (CheckResult | MultiActiveCheck): Check you want to add.
        """
        if isinstance(check, CheckResult):
            # Full output is only created on string conversion.
            self.return_val = max(self.return_val, check.return_val)
            self.checks[self.return_values.inverse[check.return_val][0]].add(check)
        elif isinstance(check, MultiActiveCheck):
            for _, members in check.checks.items():
                for ea in members:
                    self.append(ea)

    def __iter__(self):
        """
        Yield every contained CheckResult: OK, then WARNING, CRITICAL and UNKNOWN,
        sorted within each status. Each call starts a fresh, independent iteration.
        """
        for group in self.checks.values():
            yield from sorted(group)

    def __next__(self):
        """
        Return the next contained check using the ``index`` cursor.

        Kept for backward compatibility; prefer ``for check in multi`` (``__iter__``),
        which does not share cursor state between loops.
        """
        ordered = list(self)
        if self.index < len(ordered):
            self.index += 1
            return ordered[self.index - 1]
        raise StopIteration

    def __str__(self) -> str:
        return_text = ""
        extra_text = ""
        perf_data = ""
        # Non-OK texts are prepended, so the group processed last (UNKNOWN) ends up at
        # the front of the summary; OK texts are appended after them.
        for status in ("WARNING", "CRITICAL", "UNKNOWN", "OK"):
            for check in sorted(self.checks[status]):
                if (len(return_text) + len(check.return_text)) < self.summary_limit:
                    if status == "OK":
                        return_text += f"{check.return_text}. "
                    else:
                        return_text = f"{check.return_text}. {return_text}"
                else:
                    # No room on the summary line: list it in the continued output.
                    label = self.return_values.inverse[check.return_val][0]
                    extra_text = f"{extra_text}\n[{label}] - {check.return_text}"
                members = " ".join([str(x) for x in check.perf_data])
                perf_data = f"{perf_data} {members}"

        if len(extra_text) > 1:
            return_text += "See add'l output for more."

        perf_data = re.sub(' +', ' ', perf_data)
        return_text = f"[{self.return_values.inverse[self.return_val][0]}] - {return_text}"
        return_text = return_text.strip()
        if len(perf_data.strip()) > 0 or len(extra_text.strip()) > 0:
            return_text = f"{return_text} |"
        if len(perf_data.strip()) > 0:
            return_text = f"{return_text} {perf_data.strip()}"
        if len(extra_text.strip()) > 0:
            return_text = f"{return_text}\n{extra_text.strip()}"
        return return_text

    def __repr__(self) -> str:
        return str(self)

class PassiveCheckResult:
    """
    A passive check result for one service, submitted to INmon's process-check-result API.

    Attributes:
        parent (str): Hostname or IP of the endpoint the result is POSTed to.
        host (str): Host name of the service.
        service (str): Service name.
        exit_status (int): Status 0-3 (OK, WARN, CRIT, UNKN).
        plugin_output (str): Output text; a ``[OK] - ``-style prefix matching
            exit_status is added when it doesn't already start with a bracketed status.
        performance_data (str | None): Optional performance data string.
        check_source (str | None): Optional check source name.
    """

    # Bracketed status accepted at the start of plugin_output, e.g. "[OK]" or "[WARNING]".
    _STATUS_PREFIX = re.compile(r"\[(?:OK|WARN(?:ING)?|CRIT(?:ICAL)?|UNKN(?:OWN)?)\]")
    # Prefix added to plugin_output when it lacks one, keyed by exit_status.
    _DEFAULT_PREFIX = {0: "[OK] - ", 1: "[WARN] - ", 2: "[CRIT] - ", 3: "[UNKN] - "}

    def __init__(
        self,
        parent: str,
        host: str,
        service: str,
        exit_status: int,
        plugin_output: str,
        performance_data: str = None,
        check_source: str = None,
    ) -> None:
        # Validation order determines which error a caller sees first; keep it stable.
        self._require_nonempty_str(parent, "parent")
        if not (is_valid_hostname(parent) or is_valid_ip(parent)):
            raise ValueError("parent must be a valid hostname or IP address")

        self._require_nonempty_str(host, "host")
        self._require_nonempty_str(service, "service")

        if not isinstance(exit_status, int):
            raise TypeError(
                f"exit_status must be of type int, is type {type(exit_status)}"
            )
        if not 0 <= exit_status <= 3:
            raise ValueError("exit_status must be between 0 and 3 inclusive")

        self._require_nonempty_str(plugin_output, "plugin_output")
        if self._STATUS_PREFIX.match(plugin_output) is None:
            plugin_output = self._DEFAULT_PREFIX[exit_status] + plugin_output

        if performance_data is not None:
            self._require_nonempty_str(performance_data, "performance_data")
        if check_source is not None:
            self._require_nonempty_str(check_source, "check_source")

        self.parent = parent
        self.host = host
        self.service = service
        self.exit_status = exit_status
        self.plugin_output = plugin_output
        self.performance_data = performance_data
        self.check_source = check_source

    @staticmethod
    def _require_nonempty_str(value, name: str) -> None:
        """Raise TypeError if ``value`` is not a str, ValueError if it is empty."""
        if not isinstance(value, str):
            raise TypeError(f"{name} must be of type str, is type {type(value)}")
        if value == "":
            raise ValueError(f"{name} must not be empty string")

    # Representation of object as both string and object
    def __str__(self) -> str:
        return (
            f"{self.host}!{self.service} exited with status {self.exit_status} "
            f"with output {self.plugin_output} to satellite {self.parent}, "
            f"performance data: {self.performance_data} and check_source {self.check_source}"
        )

    def __repr__(self) -> str:
        # Optional fields are passed by keyword so the repr is unambiguous (and eval-able)
        # even when only check_source is set.
        args = [
            repr(self.parent),
            repr(self.host),
            repr(self.service),
            repr(self.exit_status),
            repr(self.plugin_output),
        ]
        if self.performance_data is not None:
            args.append(f"performance_data={self.performance_data!r}")
        if self.check_source is not None:
            args.append(f"check_source={self.check_source!r}")
        return f"PassiveCheckResult({', '.join(args)})"

    def submit(
        self,
        session: requests.Session = None,
        auth=None,
        cert=None,
        timeout=30,
    ) -> requests.Response:
        """
        Submit this passive check result to INmon.

        POSTs to ``https://<parent>:5665/v1/actions/process-check-result`` and returns the
        response without checking its status code.

        Args:
            session (requests.Session, optional): Session to reuse. When submitting many
                results, create one separately and pass it in so the connection stays open;
                the caller is then responsible for closing it.
            auth (tuple[str, str], optional): Forwarded to ``create_request_session`` when
                ``session`` is None. Ignored otherwise.
            cert (str | tuple[str, str], optional): Forwarded to
                ``create_request_session`` when ``session`` is None. Ignored otherwise.
            timeout (float | tuple | None, optional): ``requests`` timeout in seconds.
                Defaults to 30; None waits indefinitely.

        Raises:
            ValueError: If ``session`` is None and not exactly one of auth/cert is given.
            TypeError: If ``session`` is not a ``requests.Session``.
        """
        # Sessions created here are closed here; caller-supplied sessions are left open.
        close_session = session is None
        if close_session:
            session = create_request_session(auth=auth, cert=cert)

        if not isinstance(session, requests.Session):
            raise TypeError(
                f"session must be of type requests.Session, is type {type(session)}"
            )

        # NOTE(review): host/service are interpolated verbatim into the filter expression;
        # names containing '"' would break it — consider the API's filter_vars instead.
        data = {
            "type": "Service",
            "filter": f'host.name=="{self.host}" && service.name=="{self.service}"',
            "exit_status": self.exit_status,
            "plugin_output": self.plugin_output,
        }

        # Optional fields were type/value-checked at construction.
        if self.performance_data is not None:
            data["performance_data"] = self.performance_data
        if self.check_source is not None:
            data["check_source"] = self.check_source

        try:
            return session.post(
                f"https://{self.parent}:5665/v1/actions/process-check-result",
                json=data,
                verify=False,
                timeout=timeout,
            )
        finally:
            # Close even when the POST raises, so a one-off submission never leaks a session.
            if close_session:
                session.close()
