#!/usr/bin/env python3
# Copyright (C) 2026 Spearhead Systems SRL
# Copyright (C) 2019 Checkmk GmbH

import argparse
import json
import sys

import requests
import urllib3

# Certificate verification can be turned off with --no-cert-check; silence the
# InsecureRequestWarning urllib3 would otherwise emit for such requests.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)


def main(argv=None):
    """Agent entry point.

    Prints canned demo data when ``--demo`` is given, otherwise queries the
    Graylog instance named on the command line.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    options = parse_arguments(argv)
    if not options.demo:
        handle_request(options)
        return None
    return print_demo()


# Meter-name markers and the short prefix used for their rates in the output
# section. The first matching marker wins.
_RATE_TYPES = (("incomingMessages", "im"), ("rawSize", "rs"))
# Moving-average windows copied from every matching meter.
_RATE_WINDOWS = ("m1_rate", "m5_rate", "m15_rate")


def handle_request(args):
    """Query Graylog and print the ``graylog_input_metrics`` section.

    Two API calls are merged on the input id:

    * ``/api/system/metrics``: the ``meters`` whose names contain
      ``incomingMessages`` or ``rawSize`` provide the per-input rates.
    * ``/api/cluster/inputstates``: provides state, title, type and port.

    The merged ``{input_id: {...}}`` dictionary is passed to
    ``handle_output``. Nothing is printed in any of these cases:

    * a request fails;
    * a response is not usable;
    * no inputs are configured.

    Errors are reported on stderr.

    :param args: parsed command-line namespace. ``proto``, ``hostname`` and
        ``port`` are used here; credentials are used by ``handle_response``.
    """
    url_base = f"{args.proto}://{args.hostname}:{args.port}/api"

    value = _fetch_json(url_base + "/system/metrics", args)
    if value is None:
        return
    inputs_response = _fetch_json(url_base + "/cluster/inputstates", args)
    if inputs_response is None:
        return

    metrics_data = value.get("meters") if isinstance(value, dict) else None
    # The inputstates response is a dict whose values are lists of input
    # states (presumably one list per cluster node). Only the first list is
    # used.
    inputs_data = (
        next(iter(inputs_response.values()), None)
        if isinstance(inputs_response, dict)
        else None
    )
    if not isinstance(metrics_data, dict) or not isinstance(inputs_data, list):
        return

    inputs_dict = _build_inputs(inputs_data, _collect_rates(metrics_data))
    if inputs_dict:
        handle_output(inputs_dict)


def _fetch_json(url, args):
    """Return the decoded JSON body of a GET on *url*, or None on failure."""
    response = handle_response(url, args)
    if response is None:
        # handle_response has already reported the error.
        return None
    if not response.ok:
        sys.stderr.write(f"Error: {url} returned HTTP {response.status_code}\n")
        return None
    try:
        return response.json()
    except ValueError as err:
        # requests' JSON decode errors subclass ValueError.
        sys.stderr.write(f"Error: invalid JSON from {url}: {err}\n")
        return None


def _collect_rates(meters):
    """Group incomingMessages/rawSize meter rates by input id.

    The input id is taken as the second-to-last dot-separated component of
    the meter name. Meters that match neither marker, or whose name has no
    dot, are skipped.

    :returns: ``{input_id: {"im_m1_rate": ..., "rs_m15_rate": ..., ...}}``
    """
    rates = {}
    for name, meter in meters.items():
        prefix = next((p for marker, p in _RATE_TYPES if marker in name), None)
        if prefix is None:
            continue
        parts = name.split(".")
        if len(parts) < 2:
            continue
        entry = rates.setdefault(parts[-2], {})
        for window in _RATE_WINDOWS:
            entry[f"{prefix}_{window}"] = meter.get(window, 0.0)
    return rates


def _build_inputs(inputs, rates):
    """Join input states with their rates, keyed by input id.

    Any rate that has no matching meter is reported as 0.0 instead of
    raising KeyError.
    """
    result = {}
    for input_state in inputs:
        message_input = input_state["message_input"]
        input_id = input_state["id"]
        input_rates = rates.get(input_id, {})
        entry = {
            "input_state": input_state["state"],
            "input_name": message_input["title"],
            "input_type": message_input["name"],
            # Some inputs have no "port" attribute; report None for those.
            "input_port": message_input["attributes"].get("port"),
        }
        for prefix in ("im", "rs"):
            for window in _RATE_WINDOWS:
                key = f"{prefix}_{window}"
                entry[key] = input_rates.get(key, 0.0)
        result[input_id] = entry
    return result


def handle_response(url, args, timeout=30):
    """Perform a GET request against the Graylog API.

    :param url: absolute URL to request.
    :param args: parsed command-line namespace. ``user``, ``password`` and
        ``no_cert_check`` are used.
    :param timeout: seconds to wait for connect/read. Without a timeout,
        requests may wait indefinitely.
    :returns: the ``requests.Response`` (whatever its status code), or
        ``None`` if the request itself failed. In that case the error is
        written to stderr.
    """
    # Only send basic auth when a user was given. A (None, None) tuple would
    # be sent as the literal credentials "None:None".
    auth = (args.user, args.password) if args.user is not None else None
    try:
        return requests.get(
            url,
            auth=auth,
            verify=not args.no_cert_check,
            timeout=timeout,
        )
    except requests.exceptions.RequestException as e:
        sys.stderr.write(f"Error: {e}\n")
        return None


def handle_output(value):
    """Print the section header followed by *value* serialized as one JSON line."""
    print("<<<graylog_input_metrics:sep(0)>>>", json.dumps(value), sep="\n")


def print_demo():
    """Print a canned ``graylog_input_metrics`` section with three sample inputs."""
    section_header = "<<<graylog_input_metrics:sep(0)>>>"
    section_body = '{"641e88d05d447a677efde199": {"input_state": "FAILED", "input_name": "kafka_cef_test", "input_type": "CEF Kafka", "input_port": null, "im_m1_rate": 0.0, "im_m5_rate": 0.0, "im_m15_rate": 0.0, "rs_m1_rate": 0.0, "rs_m5_rate": 0.0, "rs_m15_rate": 0.0}, "641e32885d447a677efd2dbf": {"input_state": "RUNNING", "input_name": "UDP-test", "input_type": "Syslog UDP", "input_port": 1514, "im_m1_rate": 1.0846244336700077, "im_m5_rate": 1.3700826278955827, "im_m15_rate": 1.254406787430692, "rs_m1_rate": 145.45579305762527, "rs_m5_rate": 180.6486220431909, "rs_m15_rate": 165.26666376319292}, "641e32795d447a677efd2d9e": {"input_state": "RUNNING", "input_name": "testTCP", "input_type": "Syslog TCP", "input_port": 1515, "im_m1_rate": 1.057872514816615, "im_m5_rate": 1.364957693749168, "im_m15_rate": 1.2528742858546844, "rs_m1_rate": 140.4719944116262, "rs_m5_rate": 178.57816158901215, "rs_m15_rate": 163.80530659055356}}'
    print(f"{section_header}\n{section_body}")


def parse_arguments(argv):
    """Parse the agent's command line.

    :param argv: argument list; ``None`` makes argparse use ``sys.argv[1:]``.
    :returns: namespace with ``user``, ``password``, ``proto``, ``port``,
        ``no_cert_check``, ``demo`` and ``hostname``.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    # (flags, keyword options) in the order they appear in --help.
    option_specs = (
        (("-u", "--user"), {"default": None, "help": "Username for graylog login"}),
        (("-s", "--password"), {"default": None, "help": "Password for graylog login"}),
        (
            ("-P", "--proto"),
            {
                "default": "https",
                "help": "Use 'http' or 'https' for connection to graylog (default=https)",
            },
        ),
        (
            ("-p", "--port"),
            {"default": 443, "type": int, "help": "Use alternative port (default: 443)"},
        ),
        (
            ("--no-cert-check",),
            {"action": "store_true", "help": "Disable SSL certificate validation"},
        ),
        (("-d", "--demo"), {"action": "store_true", "help": "Return demo data"}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    cli.add_argument(
        "hostname", metavar="HOSTNAME", help="Name of the graylog instance to query."
    )

    return cli.parse_args(argv)


if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None means 0).
    sys.exit(main())
