#include <stdio.h>
#include <memory.h>
#include <timer.h>
#include <ip.h>
#include "arp.h"

/* ARP operation codes carried in the frame's opcode field
   (request = 1, reply = 2, as defined by RFC 826). */
#define ARP_OPCODE_REQUEST  1
#define ARP_OPCODE_REPLY    2


typedef struct _ARP_FRAME   ARP_FRAME;
typedef struct _FRAME       FRAME;

/*
 * Decoded, host-side view of an ARP packet.
 * The WORD fields are in host byte order (FrameEncode/FrameDecode convert).
 * The four address pointers are NOT owned by this structure:
 *  - FrameDecode points them into the received chain data;
 *  - Send points them at the descriptor/interface addresses.
 * Each hardware address is lengthHard bytes, each protocol address
 * lengthProt bytes.
 */
struct _ARP_FRAME
{
    WORD    typeHard;       /* hardware type (compared with iface->arp) */
    WORD    typeProt;       /* protocol type (matched against descr->prot) */
    BYTE    lengthHard;     /* hardware address length in bytes */
    BYTE    lengthProt;     /* protocol address length in bytes */
    WORD    opcode;         /* ARP_OPCODE_REQUEST / ARP_OPCODE_REPLY */
    BYTE    *senderHard;
    BYTE    *senderProt;
    BYTE    *targetHard;
    BYTE    *targetProt;
};

/*
 * Fixed part of the ARP header, overlaid directly on chain data by
 * FrameEncode/FrameDecode (WORD fields in network byte order).  The variable
 * length address fields follow immediately after sizeof(FRAME) bytes.
 * NOTE(review): relies on the compiler adding no padding (2+2+1+1+2 bytes)
 * and on the chain data being suitably aligned - confirm per target.
 */
struct _FRAME
{
    WORD    typeHard;
    WORD    typeProt;
    BYTE    lengthHard;
    BYTE    lengthProt;
    WORD    opcode;
};




/* Forward declarations of the file-local helpers defined below. */
static ARP_DESCR    *DescrFind(MAC_IFACE *iface, WORD prot);
static ARP_ENTRY    *EntryFind(ARP_DESCR *descr, BYTE *addr);
static ARP_ENTRY    *EntryUpdate(ARP_DESCR *descr,
                        BYTE *addrProt, BYTE *addrHard);
static ARP_ENTRY    *EntryCreate(ARP_DESCR *descr, BYTE *addrProt,
                                        BYTE *addrHard, BOOLEAN dynamic);
static void         EntryClean(TIMER_DESCR *timer, ULONG now, void *parm);
static BOOLEAN      RequestReply(ARP_DESCR *descr, BYTE *addr, ARP_ENTRY *entry);
static void         RequestTimeout(TIMER_DESCR *timer, ULONG now, void *parm);
static ARP_REQUEST  *RequestCreate(ARP_DESCR *descr, BYTE *addr);
static ARP_REQUEST  *RequestFind(ARP_DESCR *descr, BYTE *addr);
static BOOLEAN      Rcve(MAC_IFACE *iface, CHAIN *chain, MAC_HDR *hdr);
static BOOLEAN      Send(ARP_DESCR *descr, WORD opcode,
                            BYTE *addrProt, BYTE *addrHard);
static CHAIN        *FrameEncode(CHAIN *chain, ARP_FRAME *arp);
static CHAIN        *FrameDecode(CHAIN *chain, ARP_FRAME *arp);

/* Debug helper; not called from within this file. */
static void         Dump(ARP_DESCR *d);





/*
 * MAC-layer protocol registration for ARP: pairs the receive handler Rcve
 * with the MAC_PROT_ARP frame type.  Passed to MacProtRegister by ArpInit.
 */
MAC_PROT arpMac =
{
    Rcve,
    MAC_PROT_ARP
};



/* Head of the singly linked list of registered ARP descriptors
   (ArpRegister pushes at the head, ArpRemove unlinks). */
ARP_DESCR   *arpDescrList = 0;
/* Timer created by ArpInit that periodically runs EntryClean. */
TIMER_DESCR *arpCleanTimer = 0;


/*
 * One-time initialisation of the ARP module.
 *
 * Registers arpMac with the MAC layer and creates arpCleanTimer, which runs
 * EntryClean (ARP_ENTRY_TIMEOUT, TIMER_FOREVER).  May be called repeatedly;
 * returns TRUE once both steps have succeeded.
 *
 * The MAC registration is remembered separately.  If creating the timer
 * fails, a later call retries only the timer and does not register arpMac
 * with the MAC layer a second time.
 */
BOOLEAN ArpInit(void)
{
    static BOOLEAN init = FALSE;
    static BOOLEAN macRegistered = FALSE;

    if (!init)
    {
        if (!macRegistered && MacProtRegister(&arpMac))
            macRegistered = TRUE;

        if (macRegistered)
        {
            arpCleanTimer = TimerRegister(EntryClean, 0, ARP_ENTRY_TIMEOUT,
                TIMER_FOREVER, TIMER_TYPE_SKIP);
            if (arpCleanTimer != 0)
                init = TRUE;
        }
    }

    return init;
}


/*
 * Adds an ARP descriptor to the module.
 *
 * The descriptor's entry cache and pending-request list are reset to empty,
 * and the descriptor is pushed onto the head of arpDescrList.
 * Always returns TRUE.
 */
BOOLEAN ArpRegister(ARP_DESCR *arp)
{
    /* Start from empty tables. */
    arp->entryList   = 0;
    arp->requestList = 0;

    /* Link in at the head of the global list. */
    arp->next    = arpDescrList;
    arpDescrList = arp;

    return TRUE;
}


/*
 * Unregisters an ARP descriptor previously added with ArpRegister.
 *
 * If arp is not on arpDescrList, nothing happens.  Otherwise the descriptor
 * is unlinked, and then:
 *  - every pending request is completed through RequestReply with a null
 *    entry (the same outcome as a request that timed out);
 *  - every cached entry is freed.
 * Without this cleanup, request timers would keep referencing the removed
 * descriptor, and EntryClean (which only walks arpDescrList) would never
 * free its entries.
 */
void ArpRemove(ARP_DESCR *arp)
{
    ARP_DESCR **p;
    ARP_ENTRY *e;

    /* Locate the link that points at arp.  The original code kept advancing
       after unlinking, which skipped a node or dereferenced NULL when arp
       was the last element. */
    for (p = &arpDescrList; *p != 0; p = &(*p)->next)
    {
        if (*p == arp)
            break;
    }
    if (*p == 0)
        return;

    *p = arp->next;

    /* Fail all outstanding requests.  arp is already unlinked, so a callback
       that calls ArpRequest cannot add new requests to it and the loop
       terminates. */
    while (arp->requestList != 0)
        RequestReply(arp, arp->requestList->addr, 0);

    /* Release the cached entries. */
    while (arp->entryList != 0)
    {
        e = arp->entryList;
        arp->entryList = e->next;
        DnpapFree(e);
    }
}


/*
 * Asks for addr (a protocol address) to be resolved on iface for protocol prot.
 *
 * Reply/specific are queued on the pending request for addr.  If no request
 * is pending yet, one is created, which transmits an ARP request.  Queued
 * callbacks are later invoked by RequestReply as
 * Reply(iface, entry, specific), with entry == 0 when the request times out.
 *
 * Returns FALSE when:
 *  - no descriptor matches (iface, prot);
 *  - the request cannot be created;
 *  - the queue element cannot be allocated.
 */
BOOLEAN ArpRequest(MAC_IFACE *iface, WORD prot, BYTE *addr,
                ARP_REPLY Reply, void *specific)
{
    ARP_DESCR   *descr;
    ARP_REQUEST *pending;
    ARP_QUE     *waiter;

    descr = DescrFind(iface, prot);
    if (descr == 0)
        return FALSE;

    /* Reuse an outstanding request for the same address if there is one. */
    pending = RequestFind(descr, addr);
    if (pending == 0)
        pending = RequestCreate(descr, addr);
    if (pending == 0)
        return FALSE;

    waiter = DnpapMalloc(sizeof(ARP_QUE));
    if (waiter == 0)
        return FALSE;

    /* Push the caller's callback onto the request's wait queue. */
    waiter->Reply     = Reply;
    waiter->specific  = specific;
    waiter->next      = pending->queList;
    pending->queList  = waiter;

    return TRUE;
}


/*
 * Looks up the cached entry for protocol address addr on (iface, prot).
 * Returns 0 if no descriptor matches or the address is not cached.
 */
ARP_ENTRY *ArpFind(MAC_IFACE *iface, WORD prot, BYTE *addr)
{
    ARP_DESCR *descr = DescrFind(iface, prot);

    return (descr != 0) ? EntryFind(descr, addr) : 0;
}










/*
 * Returns the first registered descriptor bound to both iface and prot,
 * or 0 if none exists.
 */
static ARP_DESCR *DescrFind(MAC_IFACE *iface, WORD prot)
{
    ARP_DESCR *cur = arpDescrList;

    while (cur != 0 && !(cur->iface == iface && cur->prot == prot))
        cur = cur->next;

    return cur;
}


/*
 * Searches descr's entry cache for protocol address addr.
 * Compares descr->addrLength bytes and returns the matching entry, or 0.
 */
static ARP_ENTRY *EntryFind(ARP_DESCR *descr, BYTE *addr)
{
    ARP_ENTRY *cur;

    for (cur = descr->entryList; cur != 0; cur = cur->next)
    {
        if (memcmp(cur->addrProt, addr, descr->addrLength) == 0)
            return cur;
    }
    return 0;
}



/*
 * Refreshes the hardware address of an existing dynamic entry.
 *
 * If addrProt is cached and the entry is dynamic, addrHard is copied into it
 * (iface->addrLength bytes) and its update flag is set, so the next
 * EntryClean run keeps it.  Returns the entry.  Returns 0, leaving the cache
 * untouched, when addrProt is unknown or its entry is static.
 */
static ARP_ENTRY *EntryUpdate(ARP_DESCR *descr,
                        BYTE *addrProt, BYTE *addrHard)
{
    ARP_ENTRY *match = descr->entryList;

    while (match != 0
           && memcmp(match->addrProt, addrProt, descr->addrLength) != 0)
        match = match->next;

    if (match != 0 && match->dynamic)
    {
        memcpy(match->addrHard, addrHard, descr->iface->addrLength);
        match->update = TRUE;
        return match;
    }
    return 0;
}


/*
 * Inserts or overwrites the cache entry for addrProt.
 *
 * If no entry exists, a new one is allocated and pushed at the head of
 * descr->entryList.  In both cases the entry's hardware address is set to
 * addrHard, its dynamic flag to `dynamic`, and update to TRUE.
 * An existing entry is overwritten whatever its current dynamic flag is.
 * Returns the entry, or 0 if allocation fails.
 */
static ARP_ENTRY *EntryCreate(ARP_DESCR *descr, BYTE *addrProt,
                                        BYTE *addrHard, BOOLEAN dynamic)
{
    ARP_ENTRY *entry = descr->entryList;

    while (entry != 0
           && memcmp(entry->addrProt, addrProt, descr->addrLength) != 0)
        entry = entry->next;

    if (entry == 0)
    {
        ARP_ENTRY *fresh = DnpapMalloc(sizeof(ARP_ENTRY));

        if (fresh == 0)
            return 0;
        memcpy(fresh->addrProt, addrProt, descr->addrLength);
        fresh->next      = descr->entryList;
        descr->entryList = fresh;
        entry = fresh;
    }

    memcpy(entry->addrHard, addrHard, descr->iface->addrLength);
    entry->dynamic = dynamic;
    entry->update  = TRUE;
    return entry;
}

/*
 * Timer callback (registered by ArpInit) that ages the ARP caches of all
 * registered descriptors.
 *
 * For each dynamic entry:
 *  - update flag clear (not refreshed since the previous run): the entry is
 *    unlinked and freed;
 *  - update flag set: the flag is cleared, so the entry expires on the next
 *    run unless EntryUpdate/EntryCreate refreshes it.
 * Static entries are never touched.
 *
 * The original loop never advanced the cursor past a static entry, which
 * made this callback spin forever as soon as one static entry existed.
 * timer, now and parm are unused.
 */
static void EntryClean(TIMER_DESCR *timer, ULONG now, void *parm)
{
    ARP_DESCR *d;
    ARP_ENTRY **p, *e;

    for (d = arpDescrList; d != 0; d = d->next)
    {
        p = &d->entryList;
        while ((e = *p) != 0)
        {
            if (e->dynamic && !e->update)
            {
                /* Expired: unlink and free; p already names the successor. */
                *p = e->next;
                DnpapFree(e);
                continue;
            }
            if (e->dynamic)
                e->update = FALSE;
            p = &e->next;
        }
    }
}


/*
 * Completes the pending request for protocol address addr on descr.
 *
 * The request is unlinked and its timer removed.  Every queued callback is
 * invoked as Reply(descr->iface, entry, specific), where entry == 0 signals
 * failure (e.g. retries exhausted).  The queue elements and the request are
 * then freed.
 *
 * Returns TRUE if a pending request was found and completed, FALSE if none
 * matched addr.
 *
 * The request is unlinked BEFORE the callbacks run.  A callback that calls
 * ArpRequest may prepend a new request to descr->requestList.  The original
 * code unlinked through a saved link pointer afterwards, and could remove
 * that new request instead of this one.
 */
static BOOLEAN RequestReply(ARP_DESCR *descr, BYTE *addr, ARP_ENTRY *entry)
{
    ARP_REQUEST **p, *request;
    ARP_QUE     *q, *que;

    for (p = &descr->requestList; *p != 0; p = &(*p)->next)
    {
        if (memcmp((*p)->addr, addr, descr->addrLength) == 0)
            break;
    }
    request = *p;
    if (request == 0)
        return FALSE;

    /* Detach first; addr may point into request, but it is not used again. */
    *p = request->next;

    /* RequestCreate (original version) could leave timer == 0. */
    if (request->timer != 0)
        TimerRemove(request->timer);

    q = request->queList;
    while (q != 0)
    {
        que = q;
        q = q->next;
        if (que->Reply != 0)
            que->Reply(descr->iface, entry, que->specific);
        DnpapFree(que);
    }
    DnpapFree(request);
    return TRUE;
}



/*
 * Starts resolution of protocol address addr on descr.
 *
 * Broadcasts an ARP request, then allocates a request record with:
 *  - ARP_REQUEST_RETRIES retries;
 *  - an empty wait queue;
 *  - a RequestTimeout timer (ARP_REQUEST_TIMEOUT, TIMER_FOREVER).
 * The record is pushed onto descr->requestList.
 *
 * Returns the request, or 0 if sending, allocation or timer registration
 * fails.  A request without a timer would never retry or expire, so that
 * case is now treated as failure (the original stored a null timer).
 */
static ARP_REQUEST *RequestCreate(ARP_DESCR *descr, BYTE *addr)
{
    ARP_REQUEST *r;

    if (!Send(descr, ARP_OPCODE_REQUEST, addr, 0))
        return 0;
    r = DnpapMalloc(sizeof(ARP_REQUEST));
    if (r == 0)
        return 0;
    memcpy(r->addr, addr, descr->addrLength);
    r->retry    = ARP_REQUEST_RETRIES;
    r->queList  = 0;
    /* Fully initialise the record before handing it to the timer. */
    r->descr    = descr;
    r->timer    = TimerRegister(RequestTimeout, r,
        ARP_REQUEST_TIMEOUT, TIMER_FOREVER, TIMER_TYPE_SKIP);
    if (r->timer == 0)
    {
        DnpapFree(r);
        return 0;
    }
    r->next             = descr->requestList;
    descr->requestList  = r;

    return r;
}

/*
 * Returns the pending request on descr whose target protocol address equals
 * addr (descr->addrLength bytes compared), or 0 if none is outstanding.
 */
static ARP_REQUEST *RequestFind(ARP_DESCR *descr, BYTE *addr)
{
    ARP_REQUEST *cur = descr->requestList;

    while (cur != 0 && memcmp(cur->addr, addr, descr->addrLength) != 0)
        cur = cur->next;

    return cur;
}


/*
 * Timer callback for a pending request; parm is the ARP_REQUEST.
 *
 * Decrements the retry counter.  While retries remain, the ARP request is
 * re-sent.  When the counter reaches zero, the request is completed through
 * RequestReply with a null entry, i.e. all waiters are told resolution
 * failed.  timer and now are unused.
 */
static void RequestTimeout(TIMER_DESCR *timer, ULONG now, void *parm)
{
    ARP_REQUEST *request = parm;

    if (--request->retry == 0)
        RequestReply(request->descr, request->addr, 0);
    else
        Send(request->descr, ARP_OPCODE_REQUEST, request->addr, 0);
}


/*
 * Debug helper: prints descr d's entry cache and pending requests to stdout.
 *
 * Protocol addresses are printed as dot-separated decimal bytes
 * (d->addrLength of them).  Hardware addresses are printed as
 * colon-separated two-digit hex bytes (d->iface->addrLength of them).
 * For 4-byte protocol and 6-byte hardware addresses the output is identical
 * to the previous hard-coded format.  Other lengths no longer read past, or
 * stop short of, the real address.
 */
static void Dump(ARP_DESCR *d)
{
    ARP_ENTRY   *e;
    ARP_REQUEST *r;
    int         i;

    printf("ARP table for iface %s\n", d->iface->descr);
    for (e = d->entryList; e != 0; e = e->next)
    {
        for (i = 0; i < (int)d->addrLength; i++)
            printf("%s%d", i ? "." : "", e->addrProt[i]);
        printf(" =");
        for (i = 0; i < (int)d->iface->addrLength; i++)
            printf("%s%02x", i ? ":" : " ", e->addrHard[i]);
        printf("\n");
    }
    for (r = d->requestList; r != 0; r = r->next)
    {
        for (i = 0; i < (int)d->addrLength; i++)
            printf("%s%d", i ? "." : "", r->addr[i]);
        printf("  %d\n", r->retry);
    }
}



/*
 * MAC-layer receive handler for ARP frames (installed via arpMac).
 *
 * Processing steps:
 *  1. Decode the frame.
 *  2. Accept it only if its hardware type and address lengths match iface
 *     and a registered descriptor for its protocol type.
 *  3. Refresh the sender's dynamic cache entry.
 *  4. If the target protocol address is ours:
 *     - make sure the sender is cached (a new dynamic entry is created
 *       unless a static entry already exists);
 *     - answer a REQUEST with a REPLY, or use a REPLY to complete our
 *       pending request.
 *
 * Returns TRUE unless sending the reply or completing a request fails.
 * Frames that are not for us, or cannot be decoded, return TRUE.
 *
 * Fixes over the original:
 *  - The frame's address lengths are untrusted input.  They are now checked
 *    against iface/descriptor before memcmp/memcpy use the local lengths;
 *    previously a short frame caused over-reads.
 *  - A failed EntryCreate no longer leads to a NULL dereference.
 *  - Static entries are no longer overwritten by received traffic.
 *    EntryUpdate already refuses them, but EntryCreate replaced them.
 *  - Unknown opcodes are ignored instead of being treated as replies.
 * hdr is unused.
 */
static BOOLEAN Rcve(MAC_IFACE *iface, CHAIN *chain, MAC_HDR *hdr)
{
    ARP_FRAME arp;
    ARP_DESCR *d;
    ARP_ENTRY *e;
    CHAIN     *new;
    BOOLEAN   success = TRUE;

    new = FrameDecode(chain, &arp);
    if (new == 0)
        return success;

    d = 0;
    if (arp.typeHard == iface->arp && arp.lengthHard == iface->addrLength)
        d = DescrFind(iface, arp.typeProt);

    if (d != 0 && arp.lengthProt == d->addrLength)
    {
        e = EntryUpdate(d, arp.senderProt, arp.senderHard);

        if (memcmp(d->addr, arp.targetProt, d->addrLength) == 0)
        {
            /* EntryUpdate returns 0 for unknown AND static entries; keep a
               static entry as-is rather than overwriting it. */
            if (e == 0)
                e = EntryFind(d, arp.senderProt);
            if (e == 0)
                e = EntryCreate(d, arp.senderProt, arp.senderHard, TRUE);

            if (arp.opcode == ARP_OPCODE_REQUEST)
                success = Send(d, ARP_OPCODE_REPLY,
                               arp.senderProt, arp.senderHard);
            else if (arp.opcode == ARP_OPCODE_REPLY)
                success = (e != 0)
                        ? RequestReply(d, arp.senderProt, e)
                        : FALSE;    /* out of memory: leave request pending */
        }
    }

    if (new != chain)
        ChainFree(new);
    return success;
}





/*
 * Builds and transmits one ARP frame on descr's interface.
 *
 * Sender fields are our own hardware/protocol addresses; the target protocol
 * address is addrProt.
 *  - REQUEST: target hardware address and MAC destination are the
 *    interface's broadcast address (addrHard is ignored).
 *  - Other opcodes: both are addrHard.
 *
 * Returns MacSend's result, or FALSE if the frame could not be encoded.
 * The encoded chain is freed after MacSend returns.
 *
 * Fix: the MAC header source was set to descr->addr, the *protocol* address
 * (the same value used as senderProt).  It now carries the interface
 * hardware address, matching hdr.dst, which is a hardware address.
 */
static BOOLEAN Send(ARP_DESCR *descr, WORD opcode,
                            BYTE *addrProt, BYTE *addrHard)
{
    ARP_FRAME arp;
    MAC_HDR   hdr;
    CHAIN     *chain;
    BOOLEAN   success = FALSE;

    /* Do not hand MacSend any MAC_HDR fields left uninitialised. */
    memset(&hdr, 0, sizeof(hdr));

    arp.typeHard    = descr->iface->arp;
    arp.typeProt    = descr->prot;
    arp.lengthHard  = (BYTE)descr->iface->addrLength;
    arp.lengthProt  = (BYTE)descr->addrLength;
    arp.opcode      = opcode;
    arp.senderHard  = descr->iface->addr;
    arp.senderProt  = descr->addr;
    arp.targetProt  = addrProt;

    hdr.src         = descr->iface->addr;
    hdr.type        = MAC_PROT_ARP;

    if (opcode == ARP_OPCODE_REQUEST)
    {
        arp.targetHard  = descr->iface->addrBroadcast;
        hdr.dst         = descr->iface->addrBroadcast;
    }
    else
    {
        arp.targetHard  = addrHard;
        hdr.dst         = addrHard;
    }

    chain = FrameEncode(0, &arp);
    if (chain != 0)
    {
        success = MacSend(descr->iface, chain, &hdr);
        ChainFree(chain);
    }

    return success;
}




/*
 * Serialises arp into space pushed onto the front of chain.
 *
 * Layout:
 *  - the fixed FRAME header, with WORD fields converted to network order;
 *  - sender hardware, sender protocol, target hardware and target protocol
 *    addresses, using lengthHard/lengthProt bytes each.
 * Returns the (possibly new) chain head, or 0 if ChainPush fails.
 */
static CHAIN *FrameEncode(CHAIN *chain, ARP_FRAME *arp)
{
    FRAME *fixed;
    BYTE  *cursor;
    BYTE  *field[4];
    int   width[4];
    int   k;

    fixed = (FRAME *)ChainPush(&chain,
        sizeof(FRAME) + 2 * (arp->lengthHard + arp->lengthProt));
    if (fixed == 0)
        return 0;

    fixed->typeHard   = IpH2NWord(arp->typeHard);
    fixed->typeProt   = IpH2NWord(arp->typeProt);
    fixed->lengthHard = arp->lengthHard;
    fixed->lengthProt = arp->lengthProt;
    fixed->opcode     = IpH2NWord(arp->opcode);

    /* Address fields in wire order. */
    field[0] = arp->senderHard;  width[0] = arp->lengthHard;
    field[1] = arp->senderProt;  width[1] = arp->lengthProt;
    field[2] = arp->targetHard;  width[2] = arp->lengthHard;
    field[3] = arp->targetProt;  width[3] = arp->lengthProt;

    cursor = (BYTE *)fixed + sizeof(FRAME);
    for (k = 0; k < 4; k++)
    {
        memcpy(cursor, field[k], width[k]);
        cursor += width[k];
    }

    return chain;
}



/*
 * Parses an ARP frame from the front of chain into arp.
 *
 * The fixed header is converted to host byte order.  The four address
 * pointers in arp point into the popped chain data; nothing is copied.
 * Returns the remaining chain, or 0 if either the fixed header or the
 * address area cannot be popped (truncated frame).
 *
 * Fixes:
 *  - The second ChainPop was unchecked, so a truncated frame produced
 *    pointers derived from NULL that Rcve then dereferenced.
 *  - The opcode now uses IpN2HWord like the other fields (network to host).
 */
static CHAIN *FrameDecode(CHAIN *chain, ARP_FRAME *arp)
{
    FRAME   *f;
    BYTE    *p;

    f = (FRAME *)ChainPop(&chain, sizeof(FRAME));
    if (f == 0)
        return 0;
    arp->typeHard   = IpN2HWord(f->typeHard);
    arp->typeProt   = IpN2HWord(f->typeProt);
    arp->lengthHard = f->lengthHard;
    arp->lengthProt = f->lengthProt;
    arp->opcode     = IpN2HWord(f->opcode);

    p = (BYTE *)ChainPop(&chain, 2*arp->lengthProt + 2*arp->lengthHard);
    if (p == 0)
    {
        /* NOTE(review): if the first ChainPop replaced the chain, the
           intermediate chain may need ChainFree here - confirm ChainPop's
           ownership rules. */
        return 0;
    }
    arp->senderHard = p;
    p += arp->lengthHard;
    arp->senderProt = p;
    p += arp->lengthProt;
    arp->targetHard = p;
    p += arp->lengthHard;
    arp->targetProt = p;

    return chain;
}
