/* sch_kickass.c - Kickass qdisc implementation
 *
 * Implements the Kickass router as a qdisc. Probably not suitable for public
 * release of any kind.
 */

#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/init.h>
#include <net/netlink.h>
#include <net/genetlink.h>

#include <linux/skbuff.h>
#include <linux/netfilter_bridge.h>
#include <linux/types.h>
#include <linux/if_ether.h>
#include <linux/in.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include <linux/average.h>
#include <linux/math64.h>

#include <net/pkt_sched.h>
#include <net/ip.h>
#include <net/route.h>

#define INT_MULT 10 /* Multiply by this for more precision! */
#include "sch_kickass.h"
#include "packet_sizes.h"

/* Length of one update interval; kickass_init() multiplies this by 1E6 ns,
 * so the unit is milliseconds. */
#define TQ_LENGTH 10


/* Link capacity. kickass_init() stores it multiplied by INT_MULT. */
#define CAPACITY 1250 /* bytes per ms */
/* The initial rate is CAPACITY / INIT_RATE_FACT (then scaled by INT_MULT). */
#define INIT_RATE_FACT 20
/* Gain applied to the capacity term of the rate update in the watchdog. */
#define ALPHA_TEN 1 /* alpha * 10 */

/* Per-qdisc private state, obtained with qdisc_priv(). Rates and capacity are
 * stored multiplied by INT_MULT. */
struct kickass_sched_data {
/* Parameters */
    u64     capacity;     /* CAPACITY * INT_MULT; upper clamp for current_rate */
    u32     daemon_pid;   /* zeroed in kickass_init(); not otherwise used in this file */

/* Variables */
    u64     this_tq_sum;  /* running total of qdisc_pkt_len() over all packets seen by enqueue */
    u64     last_tq_tot;  /* value of this_tq_sum at the previous watchdog tick */
    u64     this_tq_pkt;  /* number of packets seen by enqueue */
    u32     next_tq; /* not used yet */
    u32     seq;          /* zeroed in kickass_init(); not otherwise used in this file */
    u32     packet_size;  /* lookup_size(current_rate / INT_MULT); fragment size used on dequeue */
    u64     current_rate; /* rate computed by the watchdog, scaled by INT_MULT */
    u32     emergency; /* 1 while backlog >= 3/4 of 1500 * limit; dequeue then skips fragmenting */

    struct kickass_watchdog watchdog; /* periodic hrtimer driving the rate update */
    struct ewma rtt_ewma;             /* moving average of signalled RTTs, seeded with 100 */
};

/* GLOBAL */
// WARN: ONLY SUPPORTS ONE AT A TIME
// These are shared by every instance of the qdisc, so a second instance
// overwrites the values set by the first.
struct Qdisc *g_qdisc;  /* last qdisc passed to kickass_init(); never read in this file */
ktime_t update_int;     /* watchdog period, set in kickass_init() to TQ_LENGTH ms */
int unanswered;         /* zeroed in kickass_init(); not otherwise used in this file */

/****************** UTILITY ***************************/
/*
 * ip_copy_metadata() - copy per-packet bookkeeping from one skb to another.
 *
 * Local copy of the helper in ip_output.c, which is not exported to modules.
 * Only metadata is copied (routing entry, device, classification fields,
 * netfilter and security state); no packet bytes are touched.
 *
 * @to:   skb receiving the metadata; its existing dst reference is dropped.
 * @from: skb supplying the metadata.
 */
static void ip_copy_metadata(struct sk_buff *to, struct sk_buff *from)
{
	/* Replace the destination entry: release the old one, then copy. */
	skb_dst_drop(to);
	skb_dst_copy(to, from);

	to->dev = from->dev;
	to->protocol = from->protocol;
	to->pkt_type = from->pkt_type;
	to->priority = from->priority;
	to->mark = from->mark;

	/* Each fragment carries the same IP control-block flags. */
	IPCB(to)->flags = IPCB(from)->flags;

#ifdef CONFIG_NET_SCHED
	to->tc_index = from->tc_index;
#endif

	nf_copy(to, from);
#if IS_ENABLED(CONFIG_NETFILTER_XT_TARGET_TRACE)
	to->nf_trace = from->nf_trace;
#endif
#if defined(CONFIG_IP_VS) || defined(CONFIG_IP_VS_MODULE)
	to->ipvs_property = from->ipvs_property;
#endif

	skb_copy_secmark(to, from);
}

/*
 * ip_options_fragment() - strip IP options that must not appear in
 * non-first fragments.
 *
 * Local copy of the routine in ip_options.c. Walks the options that follow
 * the fixed IPv4 header of @skb and overwrites every option whose "copied"
 * bit is clear with IPOPT_NOOP. When the walk reaches the end of the option
 * area, the timestamp / record-route bookkeeping in the control block is
 * cleared. The walk stops early, without clearing that bookkeeping, on
 * IPOPT_END or on an option with an impossible length.
 */
void ip_options_fragment(struct sk_buff *skb)
{
	struct ip_options *opts = &(IPCB(skb)->opt);
	unsigned char *cur = skb_network_header(skb) + sizeof(struct iphdr);
	int remaining;

	for (remaining = opts->optlen; remaining > 0; ) {
		int this_len;

		if (*cur == IPOPT_END)
			return;

		/* Single-byte padding option: just step over it. */
		if (*cur == IPOPT_NOOP) {
			cur++;
			remaining--;
			continue;
		}

		this_len = cur[1];
		if (this_len < 2 || this_len > remaining)
			return;

		if (!IPOPT_COPIED(*cur))
			memset(cur, IPOPT_NOOP, this_len);

		cur += this_len;
		remaining -= this_len;
	}

	opts->ts = 0;
	opts->rr = 0;
	opts->rr_needaddr = 0;
	opts->ts_needaddr = 0;
	opts->ts_needtime = 0;
}

/*************** Functions for the timer **************/
static enum hrtimer_restart kickass_watchdog(struct hrtimer *timer)
{
    struct kickass_watchdog *wd = container_of(timer, struct kickass_watchdog,
                                               timer);
    ktime_t kt_now; 
    struct kickass_sched_data *q = qdisc_priv(wd->qdisc);

    u64 old_rate = q->current_rate;
    u32 average_rtt = ewma_read(&q->rtt_ewma);
    u32 time_step = 10; // min(10, average_rtt)
    u64 last_tq;

    u64 rdiv;
    u64 ldiv, rmin, lmin;
    u64 new_rate;
    
    last_tq = q->this_tq_sum - q->last_tq_tot;
    last_tq = (last_tq / time_step) * INT_MULT;
    q->last_tq_tot = q->this_tq_sum;
    

    rdiv = 10 * average_rtt * average_rtt * q->capacity;

    if (q->capacity > last_tq) {
        u64 cap_minus_sum = q->capacity - last_tq;
        lmin = average_rtt * old_rate * ALPHA_TEN * time_step * cap_minus_sum;
        rmin = 10 * old_rate * wd->qdisc->qstats.backlog * time_step;
        if (lmin > rmin) {
            ldiv = lmin - rmin;
            new_rate = old_rate + div64_u64(ldiv, rdiv);
        } else {
            ldiv = rmin - lmin;
            div64_u64(ldiv, rdiv);
            new_rate = old_rate - div64_u64(ldiv, rdiv);
        }
     } else {
        u64 sum_minus_cap = last_tq - q->capacity;
        lmin = average_rtt * old_rate * ALPHA_TEN * time_step * sum_minus_cap;
        rmin = 10 * old_rate * wd->qdisc->qstats.backlog * time_step;
        if (lmin > rmin) {
            ldiv = lmin - rmin;
            new_rate = old_rate - div64_u64(ldiv, rdiv);
        } else {
            ldiv = rmin - lmin;
            new_rate = old_rate + div64_u64(ldiv, rdiv);
        }
    }

    //printk("rtt: %u packet size: %u rate: %llu new_rate: %llu last_tq: %llu ldiv: %llu rdiv: %llu q(t): %u\n",
    //       average_rtt, q->packet_size, old_rate, new_rate, last_tq, ldiv,
    //       rdiv, wd->qdisc->qstats.backlog);

    if (new_rate < (15 * INT_MULT))  {
        new_rate = 15 * INT_MULT;
        //printk("Bottom out... \n");
    }
    else if (new_rate > q->capacity)  {
        //printk("Max...\n");
        new_rate = q->capacity;
    }

    q->packet_size = lookup_size(new_rate/INT_MULT);
    q->current_rate = new_rate;
    

    /* Check the queue for emergency... */
    if (4 * (wd->qdisc->qstats.backlog) >= 3 * 1500 *  (wd->qdisc->limit)) {
        q->emergency = 1;
        printk("Emergency mode! %u\n", wd->qdisc->qstats.backlog);
    }
    else {
        q->emergency = 0;
    }

    //send_message_to_d(wd->qdisc);
    kt_now = hrtimer_cb_get_time(timer);
    hrtimer_forward(timer, kt_now, update_int);
    return HRTIMER_RESTART;
}

void kickass_watchdog_init(struct kickass_watchdog *wd,
                           struct Qdisc *qdisc) 
{
    hrtimer_init(&wd->timer, CLOCK_MONOTONIC, HRTIMER_MODE_ABS);
    wd->timer.function = kickass_watchdog;
    wd->qdisc = qdisc;
}

/*
 * kickass_watchdog_schedule() - arm the watchdog timer.
 *
 * @wd:      watchdog whose timer is started.
 * @expires: expiry time, passed to hrtimer_start() in absolute mode.
 */
void kickass_watchdog_schedule(struct kickass_watchdog *wd,
                               ktime_t expires)
{
    struct hrtimer *t = &wd->timer;

    hrtimer_start(t, expires, HRTIMER_MODE_ABS);
}

/*
 * kickass_watchdog_cancel() - stop the watchdog timer.
 *
 * @wd: watchdog whose timer is cancelled via hrtimer_cancel().
 */
void kickass_watchdog_cancel(struct kickass_watchdog *wd)
{
    struct hrtimer *t = &wd->timer;

    hrtimer_cancel(t);
}


/*************** qdisc functions **********************/

/*
 * kickass_init() - set up (or reconfigure) a kickass qdisc.
 *
 * @sch: qdisc being initialised.
 * @opt: optional netlink attribute carrying a struct tc_fifo_qopt; when NULL
 *       the queue limit falls back to the device tx_queue_len (minimum 1).
 *
 * This function is registered as both .init and .change. On the .change path
 * the watchdog timer is already running, so it is cancelled before being
 * re-initialised; calling hrtimer_init() on a queued timer is not allowed.
 * A previously initialised watchdog is recognised by its qdisc back-pointer,
 * which is only ever set by kickass_watchdog_init().
 * NOTE(review): this relies on the qdisc private area being zero-filled when
 * the qdisc is allocated - confirm against qdisc_alloc().
 *
 * All state, including the RTT average, is fully initialised before the
 * timer is armed so the callback never observes a half-built qdisc.
 *
 * Returns 0 on success or -EINVAL when @opt is too short.
 */
static int kickass_init(struct Qdisc *sch, struct nlattr *opt)
{
    struct kickass_sched_data *q = qdisc_priv(sch);

    if (opt == NULL) {
        u32 limit = qdisc_dev(sch)->tx_queue_len ? : 1;
        sch->limit = limit;
    } else {
        struct tc_fifo_qopt *ctl = nla_data(opt);

        if (nla_len(opt) < sizeof(*ctl))
            return -EINVAL;

        sch->limit = ctl->limit;
    }

    /* Re-entry through .change: stop the running timer first. */
    if (q->watchdog.qdisc == sch)
        kickass_watchdog_cancel(&q->watchdog);

    /* Zero out */
    q->this_tq_sum = 0;
    q->this_tq_pkt = 0;
    q->last_tq_tot = 0;
    q->current_rate = (CAPACITY / INIT_RATE_FACT) * INT_MULT;
    q->packet_size = lookup_size(q->current_rate/INT_MULT);
    q->daemon_pid = 0;
    q->seq = 0;
    q->emergency = 0;
    q->capacity = CAPACITY * INT_MULT;
    unanswered = 0;

    /* Set up the moving avg */
    ewma_init(&q->rtt_ewma, 2048, 64); /* Requires a fixed weight */
    ewma_add(&q->rtt_ewma, 100); /* Add 100 ms as base */

    g_qdisc = sch;

    /* Set up the timer last. TQ_LENGTH is in ms; 1000000 ns per ms.
     * NOTE(review): update_int is passed as an absolute expiry, so the first
     * tick fires as soon as the timer is started. */
    update_int = ktime_set(0, TQ_LENGTH * 1000000UL);
    kickass_watchdog_init(&q->watchdog, sch);
    kickass_watchdog_schedule(&q->watchdog, update_int);

    return 0;
}

static int kickass_enqueue(struct sk_buff *skb, struct Qdisc *sch)
{
    struct kickass_sched_data *q = qdisc_priv(sch);
    u32 signaled_rtt;

    struct iphdr *iph;

    q->this_tq_sum += qdisc_pkt_len(skb);
    q->this_tq_pkt++;

    /* Add to the RTT average if necessary */
    iph = ip_hdr(skb);
    if ((ntohs(iph->frag_off) & IP_DF) &&
        (is_signal_rtt(ntohs(iph->tot_len)) == 1)) {
        signaled_rtt = lookup_rtt(ntohs(iph->tot_len));
        ewma_add(&q->rtt_ewma, signaled_rtt);
    }
        
	if (likely(skb_queue_len(&sch->q) < sch->limit))
		return qdisc_enqueue_tail(skb, sch);

	return qdisc_reshape_fail(skb, sch);
}

/*
 * kickass_enqueue_head() - put a packet back at the front of the queue.
 *
 * @sch: qdisc whose queue receives the packet.
 * @skb: packet to insert at the head.
 *
 * Adds qdisc_pkt_len(@skb) to the backlog and links the packet in front of
 * everything already queued. The queue limit is not checked. Always
 * returns 1.
 *
 * NOTE(review): nothing in this file sets the qdisc packet length on the
 * freshly allocated fragments that kickass_fragment() passes in - confirm
 * that the backlog accounting is what is intended.
 */
static int kickass_enqueue_head(struct Qdisc *sch, struct sk_buff *skb)
{
    sch->qstats.backlog += qdisc_pkt_len(skb);
    __skb_queue_head(&sch->q, skb);

    return 1;
}

struct sk_buff *kickass_fragment(struct Qdisc *sch, struct sk_buff *skb, u64 size)
{
    unsigned int len, iphl, left, ll_rs, mac_len;
    unsigned int true_len;
    unsigned int offset;
    struct iphdr *iph;
    //struct tcphdr *old;
    //struct tcphdr *new;
    int protocol = ip_hdr(skb)->protocol;
    int ptr;
    struct sk_buff *skb2;
    struct sk_buff *skb3;
 
    /* emergency, legacy mode */

    /* No data yet */
    if (size == 0) {
        return skb;
    }

    /* We only care about TCP traffic */
    if (protocol != 6) {
        return skb;
    }
    len = skb->len;
    iph = ip_hdr(skb);
    iphl = iph->ihl * 4;
    mac_len = skb->network_header - skb->mac_header;
    true_len = len - iphl - mac_len;

    /* Is this packet big enough for us to care? */
    if ((len < 75) || (true_len < size)) {
        return skb;
    }
    /* Does it say not to fragment ? */
    if (ntohs(iph->frag_off) & IP_DF) {
        return skb;
    }

    /************ Begin Fragmenting *****************/
    offset = (ntohs(iph->frag_off) & IP_OFFSET) << 3; // better be 0
    /* Was it the first packet?*/
    if (offset != 0) {
        //printk("Already a fragment! Leaving...\n");
        return skb;
    }
    ptr = iphl + mac_len;
    left = true_len - size; // true size is the size in payload 
    len = size;
    // Some kind of checksum thing?
    if ((skb->ip_summed == CHECKSUM_PARTIAL) && skb_checksum_help(skb)) {
        printk("Uh oh!\n");
        return skb;
    }
    //ll_rs = LL_RESERVED_SPACE_EXTRA(rt->dst.dev, nf_bridge_pad(skb));
    ll_rs = skb->mac_header; // Assuming the header starts here
    /* Allocate a new buffer */
    if ((skb2 = alloc_skb(len + iphl + mac_len + ll_rs, GFP_ATOMIC)) == NULL) {
        return skb; // TODO: Non silent failure better?
    }
    /* Copy all the metadata */
    ip_copy_metadata(skb2, skb);
    /* Uhm for lower level? */
    skb_reserve(skb2, ll_rs);
    /* copy the buffer */
    skb_put(skb2, len + iphl + mac_len); //Push tail down 
    /* Copy the header... dunno what's in there */
    memcpy(skb2->head, skb->head, skb->mac_header); //data starts at top of mac
    /* Set all three headers */
    skb_reset_mac_header(skb2);
    skb2->network_header  = skb2->mac_header + mac_len;
    skb2->transport_header = skb2->network_header + iphl; 
    /* Set the owner */
    if (skb->sk) {
        skb_set_owner_w(skb2, skb->sk);
    }
    /* Actually copy everything */ 
    skb_copy_from_linear_data(skb, skb_mac_header(skb2), mac_len);
    skb_copy_from_linear_data_offset(skb, mac_len, skb_network_header(skb2), iphl);
    /* copy the transport header and data */
    //skb_copy_from_linear_data_offset(skb, mac_len + iphl, skb_transport_header(skb2), len); 
    skb_copy_bits(skb, mac_len + iphl, skb_transport_header(skb2), len);

    /* Set the new header options */
    iph = ip_hdr(skb2);
    iph->frag_off = htons((offset >> 3));
    /* Clears options not allowed further down so later packets that copy from 
     * skb will have them cleared. */
    if (offset == 0) {
        ip_options_fragment(skb);
    }
    /* Set MF */
    if (left > 0) {
      iph->frag_off |= htons(IP_MF);
    }
    /* Set the length, compute the checksum */
    iph->tot_len = htons(len + iphl);
    ip_send_check(iph);

    /* Advance the lengths */
    ptr += len;
    offset += len;
   
    /*Remainder fragment */
    if ((skb3 = alloc_skb(left + iphl + mac_len + ll_rs, GFP_ATOMIC)) == NULL) {
       //???? 
    }
    ip_copy_metadata(skb3, skb);
    skb_reserve(skb3, ll_rs);
    skb_put(skb3, left + iphl + mac_len); //Push tail down 
    
    /* Copy the header... dunno what's in there */
    memcpy(skb3->head, skb->head, skb->mac_header); //data starts at top of mac
    skb_reset_mac_header(skb3);
    skb3->network_header  = skb3->mac_header + mac_len;
    skb3->transport_header = skb3->network_header + iphl; 
    /* Set the owner */
    if (skb->sk) {
        skb_set_owner_w(skb3, skb->sk);
    }
    /* Actually copy everything */ 
    skb_copy_from_linear_data(skb, skb_mac_header(skb3), mac_len);
    skb_copy_from_linear_data_offset(skb, mac_len, skb_network_header(skb3), iphl);
    /* Only copy what's left */
    //skb_copy_from_linear_data_offset(skb, mac_len + iphl + len, skb_transport_header(skb3), left); 
    skb_copy_bits(skb, mac_len + iphl + len, skb_transport_header(skb3), left);
    
    iph = ip_hdr(skb3);
    iph->frag_off = htons((offset >> 3));
    iph->tot_len = htons(left + iphl);
    ip_send_check(iph);

    /* Now reque the remainder */
    kickass_enqueue_head(sch, skb3); 

    /* Get rid of the original */
    consume_skb(skb);

    return skb2;
}

/*
 * kickass_dequeue() - hand the next packet to the device.
 *
 * Takes the packet at the head of the queue and, unless the qdisc is in
 * emergency mode, passes it through kickass_fragment() with the current
 * fragment size. Returns NULL when the queue is empty.
 */
static struct sk_buff *kickass_dequeue(struct Qdisc *sch)
{
    struct kickass_sched_data *priv = qdisc_priv(sch);
    struct sk_buff *head = qdisc_dequeue_head(sch);

    /* Nothing queued, or emergency mode: pass through without fragmenting. */
    if (head == NULL || priv->emergency == 1)
        return head;

    return kickass_fragment(sch, head, priv->packet_size);
}


static int kickass_dump(struct Qdisc *sch, struct sk_buff *skb)
{
	struct tc_fifo_qopt opt = { .limit = sch->limit };
    //struct kickass_sched_data *q = qdisc_priv(sch);

	if (nla_put(skb, TCA_OPTIONS, sizeof(opt), &opt))
		goto nla_put_failure;
	return skb->len;

nla_put_failure:
	return -1;
}

/*
 * kickass_destroy() - tear down a kickass qdisc.
 *
 * The only resource owned by the private data is the watchdog timer, which
 * is cancelled here.
 */
static void kickass_destroy(struct Qdisc *sch)
{
    struct kickass_sched_data *priv = qdisc_priv(sch);

    kickass_watchdog_cancel(&priv->watchdog);
}

/*
 * Operations table registered with the packet scheduler under the id
 * "kickass". Peek, drop and reset use the generic single-queue helpers;
 * .change reuses kickass_init().
 */
struct Qdisc_ops kickass_qdisc_ops __read_mostly = {
	.id        = "kickass",
	.priv_size = sizeof(struct kickass_sched_data),
	.enqueue   = kickass_enqueue,
	.dequeue   = kickass_dequeue,
	.peek      = qdisc_peek_head,
	.drop      = qdisc_queue_drop,
	.init      = kickass_init,
	.reset     = qdisc_reset_queue,
	.destroy   = kickass_destroy,
	.change    = kickass_init,
	.dump      = kickass_dump,
	.owner     = THIS_MODULE,
};

/*
 * kickass_module_init() - module entry point.
 *
 * Registers the kickass qdisc operations. Returns 0 on success, otherwise
 * the error code reported by register_qdisc() so the module loader shows the
 * real cause of the failure.
 */
static int __init kickass_module_init(void)
{
    int rq;

    printk(KERN_INFO "Loading and registering Kickass\n");

    rq = register_qdisc(&kickass_qdisc_ops);
    if (rq != 0) {
        printk(KERN_ERR "Error: Failed to register qdisc!\n");
        return rq;
    }

    return 0;
}

/*
 * kickass_module_exit() - module exit point.
 *
 * Logs the removal and unregisters the kickass qdisc operations. No generic
 * netlink family is registered by this file, so there is nothing else to
 * undo.
 */
static void __exit kickass_module_exit(void)
{
    printk(KERN_INFO "Removing Kickass\n");

    unregister_qdisc(&kickass_qdisc_ops);
}

/* Entry and exit hooks run when the module is loaded and unloaded. */
module_init(kickass_module_init);
module_exit(kickass_module_exit);

/* Module Info */
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Marcel Flores");
MODULE_DESCRIPTION("A qdisc implementation for kickass cc.");
