#!/usr/bin/env python3
#*******************************************************************************
# IIIIII NNN    N  (C) 2023 INDIGEX All Rights Reserved
#   II   NN NN  N  Any redistribution or reproduction of part or all of the
#   II   NN  NN N  content in any form is strictly prohibited.
# IIIIII NN    NN  Please contact admins@indigex.com for additional information.
#*******************************************************************************
#---------------------------------------------------------------------
# check_iostat_quantity.py
#
#
# Ian Perry [iperry]
# Innovative Networks, Inc.
# 
#
# Purpose:
# Check Disk Reads/Writes
# Report on Reads and Writes
#
# 2023-08-02 - Creation
#
#---------------------------------------------------------------------
# inmon_utils provides helpers used further down (ensure_module, version_check).
try:
    from inmon_utils import *
except Exception as exc:
    # Print a status line before exiting so the monitoring system shows why
    # the check is UNKNOWN instead of an empty result.
    print(f"[UNKN] - Unable to import inmon_utils: {exc}")
    # sys is imported here because the regular import block comes later.
    import sys
    sys.exit(3)
import argparse
import os
import json
import time
import sys
import logging
import pprint

ensure_module("psutil")
import psutil

# Directory where each device's previous sample is stored between runs.
check_file_location = "/var/spool/inmon"

# Build the command-line interface and parse the arguments.
# Thresholds are "reads,writes" pairs expressed in operations per second.
parser = argparse.ArgumentParser(description="Collect disk read/write quantity statistics")

# The device is a disk name as reported by psutil (this check does not look at
# network interfaces).
parser.add_argument('-d', dest='device', help='Disk device name (e.g. sda)', required=True)
parser.add_argument('-w', dest='warn',   help='Warning trigger level (Reads per second,Writes per second)', default='1000,2000')
# NOTE(review): the critical default equals the warning default, and the
# critical comparison is evaluated first, so with the defaults a WARNING can
# never be reported. Confirm whether a higher critical default was intended.
parser.add_argument('-c', dest='crit',   help='Critical trigger level (Reads per second,Writes per second)', default='1000,2000')
parser.add_argument('-D', dest='debug',  help='Enable debugging output', action="store_true")

args = parser.parse_args()

# Choose the logging threshold: DEBUG when -D was given, INFO otherwise.
log_level = logging.DEBUG if args.debug else logging.INFO

# Configure the root logger, then obtain the logger named after this script.
logging.basicConfig(level=log_level)
logger = logging.getLogger("check_iostat_quantity.py")

# Split the "reads,writes" threshold pairs supplied on the command line.
# Any values after the second one are ignored.
warn = args.warn.split(',')
if len(warn) < 2:
    print("Please provide a warning in the format Reads per sec,Writes per sec e.g. 1000,2000")
    sys.exit(3)

crit = args.crit.split(',')
if len(crit) < 2:
    print("Please provide a critical level in the format Reads per sec,Writes per sec e.g. 1000,2000")
    sys.exit(3)

# Parse thresholds into a single variable for later programmatic use.
# A non-numeric value is reported as UNKNOWN (exit 3) rather than being left
# to surface as an uncaught ValueError traceback.
try:
    thresholds = {
        "reads": {
            "warn": int(warn[0]),
            "crit": int(crit[0]),
        },
        "writes": {
            "warn": int(warn[1]),
            "crit": int(crit[1]),
        },
    }
except ValueError:
    print(f"[UNKN] - Thresholds must be whole numbers, got warning '{args.warn}' and critical '{args.crit}'")
    sys.exit(3)

# There is no option to query a single disk, so fetch the counters for every
# disk; the result is a mapping keyed by device name.
logger.debug("Getting disk I/O statistics for all devices.")
data = psutil.disk_io_counters(perdisk=True)

# Return UNKNOWN if the requested disk doesn't exist, listing the valid names.
if args.device not in data:
    print(f"[UNKN] - {args.device} is not an available device. List of devices: {', '.join(data)}")
    sys.exit(3)

# Collect the sample in the same shape as the stored previous sample:
# wall-clock time plus the cumulative read and write operation counts.
device_counters = data[args.device]
current_stats = {
    "time": time.time(),
    "reads": device_counters.read_count,
    "writes": device_counters.write_count,
}

# Location of the actual check file (one per device).
check_file = os.path.join(check_file_location, f"{args.device}.diskq.json")

# Create the folder if it doesn't exist. exist_ok avoids a failure when
# another check creates it at the same moment; makedirs also creates parents.
os.makedirs(check_file_location, exist_ok=True)

# Load the previous sample. previous_stats stays None when there is no usable
# baseline (file missing, not valid JSON, or missing/non-numeric fields).
previous_stats = None
no_baseline_message = "File does not exist, writing..."
try:
    with open(check_file, 'r') as f:
        previous_stats = json.load(f)
except FileNotFoundError:
    pass
except ValueError:
    # json.JSONDecodeError is a ValueError subclass.
    no_baseline_message = f"[UNKN] - Previous data in {check_file} is unreadable, rewriting..."
else:
    # The stored sample must carry the same numeric fields as the current one.
    if not isinstance(previous_stats, dict) or not all(
        isinstance(previous_stats.get(key), (int, float)) for key in current_stats
    ):
        previous_stats = None
        no_baseline_message = f"[UNKN] - Previous data in {check_file} is incomplete, rewriting..."

# Always store the current sample so the NEXT run compares against this one.
# Without this the file would only ever hold the first sample and the rates
# would be averaged over the whole time since that first run.
with open(check_file, "w") as f:
    f.write(json.dumps(current_stats, indent=4))

# Without a baseline no rate can be calculated yet: report UNKNOWN.
if previous_stats is None:
    print(no_baseline_message)
    sys.exit(3)

# Read the system uptime in seconds (first field of /proc/uptime).
with open('/proc/uptime', 'r') as f:
    uptime = float(f.readline().split()[0])

seconds_since_modification = current_stats["time"] - previous_stats["time"]

# If the previous sample is older than the uptime, it was taken before the
# last boot, so it cannot be compared with the current counters.
if seconds_since_modification > uptime:
    logger.debug("Machine has rebooted since last check. Writing current device info and exiting...")
    with open(check_file, "w") as f:
        f.write(json.dumps(current_stats, indent=4))
    # This is a "no baseline" situation like the first run, so report UNKNOWN
    # with an explanatory line rather than a silent CRITICAL.
    print(f"[UNKN] - Machine has rebooted since last check of {args.device}, writing new baseline...")
    sys.exit(3)

# Pick the maximum counter value used for rollover handling, based on whether
# the interpreter is a 64-bit or 32-bit build.
counter_rollover = (2**64)-1 if sys.maxsize > 2**32 else (2**32)-1

# Seconds elapsed between the previous sample and this one. The float
# difference is used because truncating both timestamps to whole seconds
# first can yield 0 for two runs close together, which would make the later
# per-second division fail.
time_since_last_check = current_stats["time"] - previous_stats["time"]
if time_since_last_check <= 0:
    # Identical timestamps or a clock that moved backwards: no rate can be
    # calculated, so report UNKNOWN.
    print(f"[UNKN] - No time has elapsed since the previous check of {args.device}, unable to calculate rates.")
    sys.exit(3)

# Initialize the result accumulators.
return_stat = 0   # worst state seen so far (0 OK, 1 WARNING, 2 CRITICAL)
return_text = ""  # human-readable status fragments
return_data = ""  # performance data for graphing
stats = {
    "reads": 0,
    "writes": 0
}

# Dump both samples for debugging. pprint's sort_dicts keyword was added in
# Python 3.8, so it is only passed on interpreters that accept it.
pformat_options = {"indent": 2}
if sys.version_info >= (3, 8):
    pformat_options["sort_dicts"] = False
logger.debug("Current stats:")
logger.debug(pprint.pformat(current_stats, **pformat_options))
logger.debug("Previous stats:")
logger.debug(pprint.pformat(previous_stats, **pformat_options))

# Do this for both reads and writes.
for x_type in ["reads", "writes"]:
    current_count = int(current_stats[x_type])
    previous_count = int(previous_stats[x_type])

    if current_count < previous_count:
        # The counter went backwards: treat it as a rollover and add the part
        # before the wrap to the part after it.
        logger.debug(f"The {x_type.upper()} counter has rolled over.")
        stats[x_type] = (counter_rollover - previous_count) + current_count
    else:
        # Otherwise just the difference.
        stats[x_type] = current_count - previous_count

    # Convert the operation count into a whole number of operations per second.
    logger.debug(f"{x_type.upper()} since last check: {stats[x_type]}.")
    stats[x_type] = int(stats[x_type] / time_since_last_check)
    logger.debug(f"{x_type.upper()}/sec: {stats[x_type]}.")

    # Process data for graphing.
    return_data += f"{x_type.lower()}={stats[x_type]};; "

    # Classify the rate. Critical is tested first, so a rate that meets both
    # thresholds is reported as CRITICAL; the overall state is the worst seen.
    if stats[x_type] >= thresholds[x_type]["crit"]:
        return_stat = max(return_stat, 2)
        return_text += f"{x_type.upper()} CRITICAL, {stats[x_type]}/s; "
    elif stats[x_type] >= thresholds[x_type]["warn"]:
        return_stat = max(return_stat, 1)
        return_text += f"{x_type.upper()} WARNING, {stats[x_type]}/s; "
    else:
        return_text += f"{x_type.upper()} OK, {stats[x_type]}/s; "


# Print the status line for the final state, then exit with that state as the
# process return code. States without a label print nothing.
status_labels = {2: "CRITICAL", 1: "WARNING", 0: "OK"}
status_label = status_labels.get(return_stat)
if status_label is not None:
    print(f"[{status_label}] - {return_text} | {return_data}")

sys.exit(return_stat)
