#!/usr/bin/env @PYTHON_SHEBANG@
#
# Print out statistics for all zil stats. This information is
# available through the zil kstat.
#
# CDDL HEADER START
#
# The contents of this file are subject to the terms of the
# Common Development and Distribution License, Version 1.0 only
# (the "License").  You may not use this file except in compliance
# with the License.
#
# You can obtain a copy of the license at usr/src/OPENSOLARIS.LICENSE
# or https://opensource.org/licenses/CDDL-1.0.
# See the License for the specific language governing permissions
# and limitations under the License.
#
# When distributing Covered Code, include this CDDL HEADER in each
# file and include the License file at usr/src/OPENSOLARIS.LICENSE.
# If applicable, add the following below this CDDL HEADER, with the
# fields enclosed by brackets "[]" replaced with your own identifying
# information: Portions Copyright [yyyy] [name of copyright owner]
#
# This script must remain compatible with Python 3.6+.
#

import sys
import subprocess
import time
import copy
import os
import re
import signal
from collections import defaultdict
import argparse
from argparse import RawTextHelpFormatter

# Column table, keyed by the short header name accepted by -f.  Each entry
# is [size, scale, kstat name]:
#   size       - field width in characters used when printing the column
#   scale      - divisor used when abbreviating a value with a unit suffix
#                (1000 for counts, 1024 for bytes); -1 prints the value
#                unmodified
#   kstat name - key under which the value is stored in the stats dictionary
cols = {
	# hdr:       [size,      scale,      kstat name]
	"time":      [8,         -1,         "time"],
	"pool":      [12,        -1,         "pool"],
	"ds":        [12,        -1,         "dataset_name"],
	"obj":       [12,        -1,         "objset"],
	"zcc":       [10,        1000,       "zil_commit_count"],
	"zcwc":      [10,        1000,       "zil_commit_writer_count"],
	"ziic":      [10,        1000,       "zil_itx_indirect_count"],
	"zic":       [10,        1000,       "zil_itx_count"],
	"ziib":      [10,        1024,       "zil_itx_indirect_bytes"],
	"zicc":      [10,        1000,       "zil_itx_copied_count"],
	"zicb":      [10,        1024,       "zil_itx_copied_bytes"],
	"zinc":      [10,        1000,       "zil_itx_needcopy_count"],
	"zinb":      [10,        1024,       "zil_itx_needcopy_bytes"],
	"zimnc":     [10,        1000,       "zil_itx_metaslab_normal_count"],
	"zimnb":     [10,        1024,       "zil_itx_metaslab_normal_bytes"],
	"zimsc":     [10,        1000,       "zil_itx_metaslab_slog_count"],
	"zimsb":     [10,        1024,       "zil_itx_metaslab_slog_bytes"],
}

# Columns printed for per-dataset output (-a, -p or -d), in display order.
hdr = [
	"time", "pool", "ds", "obj",
	"zcc", "zcwc", "ziic", "zic", "ziib",
	"zicc", "zicb", "zinc", "zinb",
	"zimnc", "zimnb", "zimsc", "zimsb",
]

# Columns printed for the global view: the same list without the columns
# that identify an individual dataset.
ghdr = [col for col in hdr if col not in ("pool", "ds", "obj")]

# Usage synopsis printed at the top of the -v field listing.  It mirrors the
# options registered with argparse in init(): -a, -p and -d are mutually
# exclusive.
cmd = ("Usage: zilstat [-hv] [-a | -p pool | -d dataset[,dataset...]] "
	"[-f fields] [-s separator] [-i interval]")

# Most recent sample: curr[pool][objset][stat name] -> value.
curr = {}
# Per-interval differences between two samples, same layout as curr.
diff = {}
# Raw statistics from the last kstat_update() call: kstat[objset][name].
kstat = {}
# Pool name -> list of hexadecimal objset ids resolved from the -d option.
ds_pairs = {}
# Pool selected with -p, or None.
pool_name = None
# Comma separated dataset list given with -d, or None.
dataset_name = None
# Seconds between samples; 0 prints a single sample and exits.
interval = 0
# String written after every column.
sep = "  "
# True while the global (pool independent) statistics are being shown.
gFlag = True
# Set once the datasets given with -d have been resolved into ds_pairs.
dsFlag = False

def prettynum(sz, scale, num=0):
	"""Format num right-aligned in a field of sz characters.

	A scale of -1 marks a column that is printed as-is using "%*s".
	Otherwise num is divided by scale until it is no larger than scale or
	five divisions have been made, and the matching unit suffix is
	appended. Scaled values below 10 keep one decimal place, larger ones
	are printed as integers.
	"""
	if scale == -1:
		return "%*s" % (sz, num)

	# Positive fractions are rounding noise; show them as zero.
	if 0 < num < 1:
		num = 0

	units = (' ', 'K', 'M', 'G', 'T', 'P', 'E', 'Z')
	steps = 0
	while steps < 5 and num > scale:
		num = num / scale
		steps += 1

	if steps == 0:
		return "%*d" % (sz, num)

	# One character of the field is reserved for the unit suffix.
	fmt = "%*.1f%s" if num < 10 else "%*d%s"
	return fmt % (sz - 1, num, units[steps])

def print_header():
	"""Write the title line for the columns listed in hdr to stdout.

	Each title is right-aligned to the width recorded in cols and followed
	by sep. When sampling at an interval, every column other than time,
	pool, ds and obj gets a "/s" suffix.
	"""
	plain = ('time', 'pool', 'ds', 'obj')
	for name in hdr:
		title = name if (interval <= 0 or name in plain) else name + "/s"
		sys.stdout.write("%*s%s" % (cols[name][0], title, sep))
	sys.stdout.write("\n")

def print_values(v):
	"""Write one row of values for the columns listed in hdr to stdout.

	v maps kstat names to values. When sampling at an interval, every
	column other than time, pool, ds and obj is integer-divided by the
	interval before being formatted with prettynum().
	"""
	plain = ('time', 'pool', 'ds', 'obj')
	for name in hdr:
		width, scale, key = cols[name]
		val = v[key]
		if interval > 0 and name not in plain:
			val = val // interval
		sys.stdout.write(prettynum(width, scale, val) + sep)
	sys.stdout.write("\n")

def print_dict(d):
	"""Print one row per objset of the pool -> objset -> stats mapping d."""
	for objsets in d.values():
		for stats in objsets.values():
			print_values(stats)

def detailed_usage():
	"""Write the usage line and every field definition to stderr, then exit.

	The listing pairs each short column name with its kstat name. The
	program always exits with status 0 afterwards.
	"""
	emit = sys.stderr.write
	emit("%s\n" % cmd)
	emit("Field definitions are as follows:\n")
	for name, spec in cols.items():
		emit("%11s : %s\n" % (name, spec[2]))
	emit("\n")
	sys.exit(0)

def init():
	"""Parse the command line and set up the module-level configuration.

	Updates the globals pool_name, dataset_name, interval, hdr, curr, gFlag
	and sep from the parsed options. Global statistics stay selected unless
	one of -a, -p or -d is given; those three options are mutually
	exclusive and argparse rejects any combination of them. With -v the
	field definitions are printed and the program exits. Unknown -f columns
	are reported on stderr and the program exits with status 1.
	"""
	global pool_name
	global dataset_name
	global interval
	global hdr
	global curr
	global gFlag
	global sep

	curr = dict()

	parser = argparse.ArgumentParser(
		description='Program to print zilstats',
		add_help=True,
		formatter_class=RawTextHelpFormatter,
		epilog="\nUsage Examples\n"
			"Note: Global zilstats is shown by default,"
			" if none of the a|p|d options is provided\n"
			"\tzilstat -a\n"
			"\tzilstat -v\n"
			"\tzilstat -p tank\n"
			"\tzilstat -d tank/d1,tank/d2,tank/zv1\n"
			"\tzilstat -i 1\n"
			"\tzilstat -s \"***\"\n"
			"\tzilstat -f zcwc,zimnb,zimsb\n"
	)

	parser.add_argument(
		"-v", "--verbose",
		action="store_true",
		help="List field headers and definitions"
	)

	pool_grp = parser.add_mutually_exclusive_group()

	pool_grp.add_argument(
		"-a", "--all",
		action="store_true",
		dest="all",
		help="Print all dataset stats"
	)

	pool_grp.add_argument(
		"-p", "--pool",
		type=str,
		help="Print stats for all datasets of a specified pool"
	)

	pool_grp.add_argument(
		"-d", "--dataset",
		type=str,
		help="Print given dataset(s) (Comma separated)"
	)

	parser.add_argument(
		"-f", "--columns",
		type=str,
		help="Specify specific fields to print (see -v)"
	)

	parser.add_argument(
		"-s", "--separator",
		type=str,
		help="Override default field separator with custom "
			 "character or string"
	)

	parser.add_argument(
		"-i", "--interval",
		type=int,
		dest="interval",
		help="Print stats between specified interval"
			 " (in seconds)"
	)

	parsed_args = parser.parse_args()

	if parsed_args.verbose:
		detailed_usage()

	if parsed_args.all:
		gFlag = False

	if parsed_args.interval:
		interval = parsed_args.interval

	if parsed_args.pool:
		pool_name = parsed_args.pool
		gFlag = False

	if parsed_args.dataset:
		dataset_name = parsed_args.dataset
		gFlag = False

	# Compare against None so that an explicitly requested empty separator
	# (-s "") is honoured instead of being silently ignored.
	if parsed_args.separator is not None:
		sep = parsed_args.separator

	if gFlag:
		hdr = ghdr

	if parsed_args.columns:
		hdr = parsed_args.columns.split(",")

		# The global view only offers the columns in ghdr; the
		# per-dataset views accept every column defined in cols.
		allowed = ghdr if gFlag else cols
		invalid = [ele for ele in hdr if ele not in allowed]

		if invalid:
			sys.stderr.write("Invalid column definition! -- %s\n" % invalid)
			sys.exit(1)

	# No separate check for -p combined with -d is needed here: both
	# options belong to the mutually exclusive group above, so
	# parse_args() has already rejected that combination.

def FileCheck(fname):
	"""Open fname for reading and return the open file object.

	If the file cannot be opened a message is printed and the program
	exits with status 1.
	"""
	try:
		handle = open(fname)
	except IOError:
		print ("Unable to open zilstat proc file: " + fname)
		sys.exit(1)
	return handle

if sys.platform.startswith('freebsd'):
	# Requires py-sysctl on FreeBSD
	import sysctl

	def kstat_update(pool = None, objid = None):
		"""Reload the global kstat dictionary from the FreeBSD sysctl tree.

		Without a pool, the counters under kstat.zfs.misc.zil are read and
		stored under "GLOBAL". With pool and objid, only that objset is
		read. With just a pool, every objset found under
		kstat.zfs.<pool>.dataset is read and stored under its hexadecimal
		objset id.
		"""
		global kstat
		kstat = {}
		if not pool:
			file = "kstat.zfs.misc.zil"
			k = [ctl for ctl in sysctl.filter(file)
				if ctl.type != sysctl.CTLTYPE_NODE]
			kstat_process_str(k, file, "GLOBAL", len(file + "."))
		elif objid:
			file = "kstat.zfs." + pool + ".dataset.objset-" + objid
			k = [ctl for ctl in sysctl.filter(file)
				if ctl.type != sysctl.CTLTYPE_NODE]
			kstat_process_str(k, file, objid, len(file + "."))
		else:
			file = "kstat.zfs." + pool + ".dataset"
			obj_start = len("kstat.zfs." + pool + ".")
			k = [ctl for ctl in sysctl.filter(file)
				if ctl.type != sysctl.CTLTYPE_NODE]
			for s in k:
				if not s or (s.name.find("zil") == -1 and
					s.name.find("dataset_name") == -1):
					continue
				name, value = s.name, s.value
				found = re.findall(r'0x[0-9A-F]+', name[obj_start:], re.I)
				if not found:
					# No objset id in the name, so there is no
					# objset to file this value under.
					continue
				objset_id = found[0]
				if objset_id not in kstat:
					kstat[objset_id] = dict()
				zil_start = len(file + ".objset-" + objset_id + ".")
				# str.find() returns -1 (which is truthy) when the
				# text is absent, so test membership explicitly:
				# dataset_name is kept as-is, everything else is
				# converted to int.
				kstat[objset_id][name[zil_start:]] = value \
					if "dataset_name" in name else int(value)

	def kstat_process_str(k, file, objset = "GLOBAL", zil_start = 0):
		"""Store the ZIL related sysctl entries of k in kstat[objset].

		k is a list of objects with name and value attributes, file is the
		sysctl prefix used in the error message, and zil_start is the
		number of leading characters to strip from each name to obtain the
		statistic name. Exits with status 1 when k is empty.
		"""
		global kstat
		if not k:
			print("Unable to process kstat for: " + file)
			sys.exit(1)
		kstat[objset] = dict()
		for s in k:
			if not s or (s.name.find("zil") == -1 and
				s.name.find("dataset_name") == -1):
				continue
			name, value = s.name, s.value
			# str.find() returns -1 (which is truthy) when the text is
			# absent, so test membership explicitly: dataset_name is
			# kept as-is, everything else is converted to int.
			kstat[objset][name[zil_start:]] = value \
				if "dataset_name" in name else int(value)

elif sys.platform.startswith('linux'):
	def kstat_update(pool = None, objid = None):
		"""Reload the global kstat dictionary from /proc/spl/kstat/zfs.

		Without a pool, the global zil file is read. With pool and objid,
		only that objset file is read. With just a pool, every "objset-"
		entry in the pool directory is read and stored under its objset id.
		Exits with status 1 when the pool directory does not exist.
		"""
		global kstat
		kstat = {}
		if not pool:
			file = "/proc/spl/kstat/zfs/zil"
			# Close each file as soon as its lines have been read.
			with FileCheck(file) as f:
				k = [line.strip() for line in f]
			kstat_process_str(k, file)
		elif objid:
			file = "/proc/spl/kstat/zfs/" + pool + "/objset-" + objid
			with FileCheck(file) as f:
				k = [line.strip() for line in f]
			kstat_process_str(k, file, objid)
		else:
			pool_dir = "/proc/spl/kstat/zfs/" + pool
			if not os.path.exists(pool_dir):
				print("Pool \"" + pool + "\" does not exist, Exitting")
				sys.exit(1)
			for entry in os.listdir(pool_dir):
				if entry.find("objset-") == -1:
					continue
				file = pool_dir + "/" + entry
				with FileCheck(file) as f:
					k = [line.strip() for line in f]
				kstat_process_str(k, file, entry.replace("objset-", ""))

	def kstat_process_str(k, file, objset = "GLOBAL", zil_start = 0):
		"""Store the ZIL related lines of a kstat file in kstat[objset].

		k is the list of stripped lines, each made of a name, a type and a
		value column. file is only used in the error message and zil_start
		is not used by this implementation. Exits with status 1 when k is
		empty.
		"""
		global kstat
		if not k:
			print("Unable to process kstat for: " + file)
			sys.exit(1)

		kstat[objset] = dict()
		for s in k:
			if not s or (s.find("zil") == -1 and
				s.find("dataset_name") == -1):
				continue
			# Split into at most three columns so that a dataset name
			# containing whitespace stays in one piece.
			fields = s.split(None, 2)
			if len(fields) != 3:
				# Not a "name type value" line.
				continue
			name, unused, value = fields
			kstat[objset][name] = value \
				if (name == "dataset_name") else int(value)

def zil_process_kstat():
	"""Refresh curr with a new sample of the selected ZIL statistics.

	Depending on the configuration this reads the global statistics, every
	dataset of pool_name, the datasets listed in dataset_name, or every
	dataset of every pool reported by "zpool list". Dataset names are
	resolved to objset ids with "zfs list" on the first call only; the
	result is kept in ds_pairs. Exits with status 1 when one of the
	external commands fails.
	"""
	global curr, pool_name, dataset_name, dsFlag, ds_pairs
	curr.clear()

	if gFlag:
		kstat_update()
		zil_build_dict()
		return

	if pool_name:
		kstat_update(pool_name)
		zil_build_dict(pool_name)
		return

	if dataset_name:
		if not dsFlag:
			dsFlag = True
			ds_pairs = defaultdict(list)
			for ds in dataset_name.split(','):
				try:
					objid = subprocess.check_output(
						['zfs', 'list', '-Hpo', 'objsetid', ds],
						stderr=subprocess.DEVNULL
					).decode('utf-8').strip()
				except subprocess.CalledProcessError as e:
					# Report the command exactly as it was run.
					print("Command: \"zfs list -Hpo objsetid "
						+ str(ds) + "\" failed with error code: "
						+ str(e.returncode))
					print("Please make sure that dataset \""
						+ str(ds) + "\" exists")
					sys.exit(1)
				if not objid:
					continue
				# The kstat entries are named after the objset id
				# in hexadecimal; the pool is the first component
				# of the dataset name.
				ds_pairs[ds.split('/')[0]].append(hex(int(objid)))
		for pool, objids in ds_pairs.items():
			for objid in objids:
				kstat_update(pool, objid)
				zil_build_dict(pool)
		return

	try:
		pools = subprocess.check_output(
			['zpool', 'list', '-Hpo', 'name']).decode('utf-8').split()
	except subprocess.CalledProcessError as e:
		print("Command: \"zpool list -Hpo name\" failed with error "
			"code: " + str(e.returncode))
		sys.exit(1)
	for pool in pools:
		kstat_update(pool)
		zil_build_dict(pool)

def calculate_diff():
	"""Take a new sample and store the change since the last one in diff.

	diff gets the same layout as curr. Every column of hdr other than
	time, pool, ds and obj holds the current value minus the previous
	value. An objset without a previous sample (the very first call, or a
	pool or objset that appeared since the last call) reports 0.
	"""
	global curr, diff
	# zil_process_kstat() clears curr in place, so keep a private copy.
	prev = copy.deepcopy(curr)
	zil_process_kstat()
	diff = copy.deepcopy(curr)
	keys = [cols[col][2] for col in hdr
		if col not in ('time', 'pool', 'ds', 'obj')]
	for pool in curr:
		for objset in curr[pool]:
			before = prev.get(pool, {}).get(objset)
			for key in keys:
				if before is None:
					diff[pool][objset][key] = 0
				else:
					diff[pool][objset][key] = \
						curr[pool][objset][key] - before[key]

def zil_build_dict(pool = "GLOBAL"):
	"""Copy the statistics held in kstat into curr[pool].

	Each objset of kstat becomes curr[pool][objset], extended with the
	"pool", "objset" and "time" (HH:MM:SS local time) entries. Objsets for
	which no statistics were collected are skipped.
	"""
	now = time.strftime("%H:%M:%S", time.localtime())
	for objset, stats in kstat.items():
		if not stats:
			# Nothing was collected for this objset, so there is no
			# row to build for it.
			continue
		entry = curr.setdefault(pool, {}).setdefault(objset, {})
		entry.update(stats)
		entry["pool"] = pool
		entry["objset"] = objset
		entry["time"] = now

def sign_handler_epipe(sig, frame):
	"""Signal handler that reports the signal and exits with status 0.

	The messages go to stderr rather than stdout: this handler is meant
	for SIGPIPE, where stdout is most likely the pipe that can no longer
	be written to.
	"""
	sys.stderr.write("Caught EPIPE signal: " + str(frame) + "\n")
	sys.stderr.write("Exitting...\n")
	sys.exit(0)

def main():
	"""Entry point: parse the options and print the statistics.

	With a positive interval the per-interval differences are printed
	forever, one sample every interval seconds, with the header printed
	once before the first sample. Otherwise a single sample of the current
	counters is printed. Exits with status 0 after reporting an error when
	there is nothing to show.
	"""
	global interval
	global curr
	init()
	signal.signal(signal.SIGINT, signal.SIG_DFL)
	signal.signal(signal.SIGPIPE, sign_handler_epipe)

	if interval > 0:
		hprint = False
		while True:
			calculate_diff()
			if not diff:
				print ("Error: No stats to show")
				sys.exit(0)
			if not hprint:
				print_header()
				hprint = True
			print_dict(diff)
			# Flush after every sample so that a consumer reading
			# from a pipe sees it now instead of when the output
			# buffer eventually fills up.
			sys.stdout.flush()
			time.sleep(interval)
	else:
		zil_process_kstat()
		if not curr:
			print ("Error: No stats to show")
			sys.exit(0)
		print_header()
		print_dict(curr)

# Run the tool only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
	main()