#!/bin/env python3
#*******************************************************************************
# IIIIII NNN    N  (C) 2023 INDIGEX All Rights Reserved
#   II   NN NN  N  Any redistribution or reproduction of part or all of the
#   II   NN  NN N  content in any form is strictly prohibited.
# IIIIII NN    NN  Please contact admins@indigex.com for additional information.
#*******************************************************************************
#---------------------------------------------------------------------
# check_network
#
#
# Ian Perry [iperry]
# Innovative Networks, Inc.
# 
#
# Purpose:
# Check Network Statistics
# Report on TX & RX - Packets, Bytes, and Errors
#
# 2023-03-08 - Creation
#
#---------------------------------------------------------------------
try:
    from inmon_utils import *
except:
    import sys
    sys.exit(3)
import argparse
import subprocess
import os
import json
import time
import sys
import logging
import pprint

# Directory where the per-interface state files are kept between runs.
check_file_location = "/var/spool/inmon"

# Build the command-line interface.
parser = argparse.ArgumentParser(description="Collects interface statistics")

# Options are registered in the order they should appear in --help:
# device (required), warning levels, critical levels, and the debug switch.
_cli_options = (
    ('-d', {'dest': 'device', 'help': 'Interface device name', 'required': True}),
    ('-w', {'dest': 'warn', 'help': 'Warning trigger level (RX packets per sec/TX packets per sec/Err per sec)', 'default': '1000/2000/10'}),
    ('-c', {'dest': 'crit', 'help': 'Critical trigger level (RX packets per sec/TX packets per sec/Err per sec)', 'default': '1000/2000/15'}),
    ('-D', {'dest': 'debug', 'help': 'Enable debugging output', 'action': "store_true"}),
)
for _flag, _settings in _cli_options:
    parser.add_argument(_flag, **_settings)

args = parser.parse_args()

# Choose the verbosity: DEBUG when -D was given, INFO otherwise.
log_level = logging.DEBUG if args.debug else logging.INFO

# Configure logging at that level and obtain this script's named logger.
logging.basicConfig(level=log_level)
logger = logging.getLogger("check_network.py")

def _parse_levels(raw, usage):
    """Split a slash-separated threshold string and validate its fields.

    raw   -- option value such as "1000/2000/10"
    usage -- message printed when the value is malformed

    Returns the list of string fields. Exits with status 3 when fewer than
    three fields are present or when any of the first three fields is not an
    integer, so a bad option is reported as a usage problem rather than
    surfacing as an uncaught ValueError traceback.
    """
    fields = raw.split('/')
    if len(fields) < 3:
        print(usage)
        sys.exit(3)
    try:
        for field in fields[:3]:
            int(field)
    except ValueError:
        print(usage)
        sys.exit(3)
    return fields


warn = _parse_levels(
    args.warn,
    "Please provide a warning in the format RX packets per sec/TX packets per sec/Err per sec e.g. 1000/2000/10",
)
crit = _parse_levels(
    args.crit,
    "Please provide a critical in the format RX packets per sec/TX packets per sec/Err per sec e.g. 1000/2000/15",
)

# Field order in both option strings is RX / TX / errors.
thresholds = {
    name: {"warn": int(warn[position]), "crit": int(crit[position])}
    for position, name in enumerate(("rx", "tx", "err"))
}

# Get command results for particular interface.
# The command is passed as an argument list and run without a shell, so a
# device name containing shell metacharacters is handed to ip as a single
# argument instead of being interpreted by the shell.
logger.debug(f"Getting interface statistics for {args.device}")
command = ["/usr/bin/ip", "-s", "link", "show", args.device]
try:
    results = subprocess.run(command, capture_output=True)
except OSError as error:
    # Without a shell, a missing or non-executable binary raises instead of
    # yielding a non-zero return code; report it with the same exit status.
    print(f"Unable to run {command[0]}: {error}")
    sys.exit(3)

if results.returncode != 0:
    print(results.stderr.decode())
    sys.exit(3)

output = results.stdout.decode()
def _read_counters(lines, prefix):
    """Return the label -> value mapping for one direction of the ip output.

    lines  -- the command output split into lines
    prefix -- header marker to look for, 'RX:' or 'TX:'

    A line starting with the prefix carries the column labels and the line
    right after it carries the matching values. When several header lines
    match, the last one wins. Returns None when no header line with a
    following value line is found.
    """
    found = None
    for index, raw_line in enumerate(lines):
        header = raw_line.strip()
        if header.startswith(prefix) and index + 1 < len(lines):
            labels = header[len(prefix):].split()
            values = lines[index + 1].strip().split()
            found = dict(zip(labels, values))
    return found


output_lines = output.split("\n")
rx_dict = _read_counters(output_lines, 'RX:')
tx_dict = _read_counters(output_lines, 'TX:')

# The rate calculation later in the script needs these three counters, as
# integers, for both directions; fail here with a clear message otherwise.
for _direction, _counters in (("RX", rx_dict), ("TX", tx_dict)):
    if _counters is None or any(
        not _counters.get(_name, "").isdigit()
        for _name in ("bytes", "packets", "errors")
    ):
        print(f"Unable to parse {_direction} statistics for {args.device}")
        sys.exit(3)

current_stats = {"time": time.time(), "rx": rx_dict, "tx": tx_dict}

check_file = os.path.join(check_file_location, f"{args.device}.netstats.json")

# exist_ok avoids failing when the directory appears between a separate
# existence check and the create call (e.g. two checks starting together).
os.makedirs(check_file_location, exist_ok=True)

# Load the sample stored by the previous run, if there is a usable one.
previous_stats = None
baseline_problem = "File does not exist, writing..."
if os.path.exists(check_file):
    try:
        with open(check_file, 'r') as f:
            loaded = json.load(f)
    except ValueError:
        # json.JSONDecodeError is a ValueError: the file is empty or corrupt.
        loaded = None
    if isinstance(loaded, dict) and all(key in loaded for key in ("time", "rx", "tx")):
        previous_stats = loaded
    else:
        baseline_problem = "State file is unreadable, rewriting..."

# Store the current sample on every run so the next check measures the
# interval since this one, not since the first sample ever written.
with open(check_file, "w") as f:
    f.write(json.dumps(current_stats, indent=4))

if previous_stats is None:
    # Nothing to compare against yet; rates can be computed on the next run.
    print(baseline_problem)
    sys.exit(3)

# The first field of /proc/uptime is the number of seconds since boot.
with open('/proc/uptime', 'r') as f:
    uptime = float(f.readline().split()[0])

seconds_since_modification = current_stats["time"] - previous_stats["time"]

# If the stored sample is older than the current uptime, the machine has
# booted since it was taken, so it is replaced instead of being compared.
if seconds_since_modification > uptime:
    reboot_message = "Machine has rebooted since last check. Writing current device info and exiting..."
    logger.debug(reboot_message)
    with open(check_file, "w") as f:
        f.write(json.dumps(current_stats, indent=4))
    # Also print the reason on stdout so this exit path does not finish
    # with status 2 and no plugin output at all.
    print(reboot_message)
    sys.exit(2)

# Pick the counter wrap-around value from the interpreter's word size:
# sys.maxsize above 2**32 selects the 64-bit maximum, otherwise the 32-bit one.
counter_bits = 64 if sys.maxsize > 2**32 else 32
counter_rollover = 2**counter_bits - 1

# Whole seconds between the stored sample and the current one.
time_since_last_check = int(current_stats["time"]) - int(previous_stats["time"])


# Result accumulators: exit status, human-readable text, performance data.
return_stat, return_text, return_data = 0, "", ""

# Per-direction results, filled in by the calculation that follows.
stats = {direction: {} for direction in ("rx", "tx")}

# Log both samples at debug level.
for heading, sample in (("Current stats:", current_stats), ("Previous stats:", previous_stats)):
    logger.debug(heading)
    logger.debug(pprint.pformat(sample, indent=2, sort_dicts=False))

# Calculate programmatically.

# Two checks within the same second (or a clock step backwards) would give
# an interval of zero or less; clamp to one second to avoid dividing by zero
# or reporting negative rates.
elapsed_seconds = max(time_since_last_check, 1)

for x_type in ["rx", "tx"]:
    for label in ["bytes", "packets", "errors"]:

        # The counters are stored as strings; convert before comparing so
        # that e.g. "1000" is not considered smaller than "999".
        current_value = int(current_stats[x_type][label])
        previous_value = int(previous_stats[x_type][label])

        # Calculate counter since last check.
        if current_value < previous_value:
            logger.debug(f"The {x_type.upper()} {label.capitalize()} counter has rolled over.")
            delta = (counter_rollover - previous_value) + current_value
        else:
            delta = current_value - previous_value

        logger.debug(f"{x_type.upper()} {label.capitalize()} since last check: {delta}.")

        # Calculate counter per second (integer division keeps large
        # counter values exact; delta is never negative here).
        stats[x_type][label] = delta // elapsed_seconds
        logger.debug(f"{x_type.upper()} {label.capitalize()}/sec: {stats[x_type][label]}.")

        # Add to return data.
        return_data += f"{x_type.lower()}_{label.lower()}={stats[x_type][label]};; "

# Check the return values against warn/crit.

# Packet rates: each direction is compared against its own thresholds.
for direction in ("rx", "tx"):
    packet_rate = stats[direction]["packets"]
    limits = thresholds[direction]
    if packet_rate > limits["crit"]:
        return_stat = max(return_stat, 2)
        state = "CRITICAL"
    elif packet_rate > limits["warn"]:
        return_stat = max(return_stat, 1)
        state = "WARNING"
    else:
        state = "OK"
    return_text += f"{direction.upper()} Packets {state}, {packet_rate}/s; "

# Error rates: the worse of RX and TX is compared against the shared error
# thresholds. The warning branch tests the warn level (it previously
# repeated the crit comparison and so could never be taken).
worst_error_rate = max(stats["rx"]["errors"], stats["tx"]["errors"])
if worst_error_rate > thresholds["err"]["crit"]:
    return_stat = max(return_stat, 2)
    state = "CRITICAL"
elif worst_error_rate > thresholds["err"]["warn"]:
    return_stat = max(return_stat, 1)
    state = "WARNING"
else:
    state = "OK"
return_text += f"Errors {state}, RX {stats['rx']['errors']}/s, TX {stats['tx']['errors']}/s "

# Map the final status code to its label and print the plugin line in the
# form "[LABEL] - text | perfdata"; nothing is printed for any other code.
status_label = {2: "CRITICAL", 1: "WARNING", 0: "OK"}.get(return_stat)
if status_label is not None:
    print(f"[{status_label}] - {return_text} | {return_data}")

sys.exit(return_stat)
