//=====================================================================
//
//  malloc.cpp
//
//  memory allocation routines
//
//  These are clunky and slow, but they should work.
//
//  Protected Mode version
//
//  Copyright (c) 1994, Kevin Morgan, All rights reserved.
//
//=====================================================================
#define PROTECT

#include <stdio.h>
#include <alloc.h>
#include <dos.h>
#include <string.h>

#include <stdlib.h>
#include <stdarg.h>
#include <conio.h>

#ifdef PROTECT
#define MHUGE far
#define INITIAL_CHUNK_NODES 2000
#include "dpmish.h"
#else
#define INITIAL_CHUNK_NODES 700
#define MHUGE huge
#endif

//#define DEBUG

#define pause()

//=====================================================================
//
// _protheaplen
//
// Can be pre-initialized (it is defined elsewhere) to specify how much
// heap the protected-mode CreateHeap() requests at startup.  CreateHeap()
// obtains it from DPMI in pieces of at most 0x10000 and counts this
// variable down as each piece is obtained, so after startup it no longer
// holds the original request.  Units are presumably bytes (chunk sizes
// are byte counts) - confirm against Dpmi.getMappedMemory.
//
//=====================================================================
extern unsigned long _protheaplen;

// MemoryChunk
//
// Descriptor for one contiguous region of the heap.  Live descriptors are
// linked into a circular, doubly-linked ring; spare descriptors sit on the
// freeNodes list, chained through `next` only.
class MemoryChunk {
public:
	// ptr: start of the region.  Protected mode stores selector<<16 | offset,
	//      real mode stores the linear address segment*16 + offset.
	// sz:  size of the region in bytes.
	unsigned long ptr, sz;
	MemoryChunk *next, *prev;	// ring neighbours
	int inuse;			// nonzero while the region is handed out by farmalloc()
};

// realMem: search start of the circular chunk ring for farmalloc() and
//          findChunk(); returnExcess() moves it to the most recently
//          split-off free tail.  0 means no heap has been created.
MemoryChunk *realMem = 0;
// rootMem: head of the ring; insertChunk() replaces it whenever a node with
//          a lower address is inserted, and farheapcheck1() walks from it.
MemoryChunk *rootMem = 0;
// freeNodes: stack of spare descriptors chained through ->next,
//            pushed by freeChunkNode() and popped by getChunkNode().
MemoryChunk *freeNodes=0;		// unused memory nodes
// memoryLeft: bytes in free chunks; reported by coreleft()/farcoreleft().
unsigned long memoryLeft = 0;
// heapErrors: corruption reports issued by checkHeap(), which stops at 20.
int heapErrors = 0;

// forward declarations for functions defined later in this file
extern "C" void CreateHeap(void);
void checkHeap(char *);
void CreateChunkNodes();

// makeSegPtr
//
// Turn a chunk address as stored in MemoryChunk::ptr back into a pointer.
// Protected mode stores the far pointer bits directly, so a cast suffices.
// Real mode stores a linear address, which is split into a paragraph
// segment and a 0..15 offset.
void *makeSegPtr(unsigned long addr)
{
#ifdef PROTECT
	void *flat = (void *) addr;
	return flat;
#else
	return MK_FP((unsigned)(addr >> 4), (unsigned)(addr & 0xF));
#endif
}

// fillHeap
//
// DEBUG builds only: stamp the first 8 words (16 bytes) of chunk p with the
// 0xbad guard pattern that farheapcheck1() later verifies.  With BRUTAL
// also defined, a free chunk is stamped over its whole extent (sz/2 words).
// Without DEBUG this does nothing.
#pragma argsused
void fillHeap(MemoryChunk *p)
{
#ifdef DEBUG
	unsigned MHUGE *word = (unsigned MHUGE *) makeSegPtr(p->ptr);
	unsigned long remaining = 8;
#ifdef BRUTAL
	if (p->inuse == 0 && p->sz / 2 > remaining)
		remaining = p->sz / 2;
#endif
	while (remaining-- > 0)
		*word++ = 0xbad;
#endif
}

// freeChunkNode
//
// Return descriptor p to the spare pool by pushing it onto the head of
// the freeNodes stack.  Only p->next is touched.
void freeChunkNode(MemoryChunk *p)
{
	MemoryChunk *oldHead = freeNodes;
	freeNodes = p;
	p->next = oldHead;
}

// mprintf
//
// Debug printf.  With DEBUG defined, it formats into a stack buffer and
// writes the text to DOS handle 1 with _dos_write (no stdio streams).
// Without DEBUG it compiles to an empty function.
//
// vsprintf does not bound its output, so every formatted message must fit
// in the buffer (including the NUL).  The buffer was 80 bytes, which
// "%s" callers such as checkHeap() could overrun.
#pragma argsused
void mprintf(char *fmt, ...)
{
#ifdef DEBUG
	va_list argptr;
	unsigned count;
	char buf[256];		// keep messages well below this size
	va_start(argptr, fmt);
	vsprintf(buf, fmt, argptr);
	va_end(argptr);
	_dos_write(1, buf, strlen(buf), &count);
#endif
}


// getChunkNode
//
// Pop a spare descriptor and initialise it as a one-node ring describing
// `sz` bytes at address `segp` with the given in-use flag.  If the spare
// pool is empty, CreateChunkNodes() is asked to refill it first.  Returns 0
// when no descriptor can be obtained.
MemoryChunk * getChunkNode(unsigned long segp, unsigned long sz, int inUse)
{
	if (freeNodes == 0)
		CreateChunkNodes();

	MemoryChunk *node = freeNodes;
	if (node == 0)
		return 0;

	freeNodes = node->next;
	node->ptr = segp;
	node->sz = sz;
	node->inuse = inUse;
	node->next = node->prev = node;	// self-linked until inserted somewhere
	mprintf("getChunkNode at(%08lx) for %08lx\r\n", segp, sz);
	return node;
}


// findChunk
//
// Locate the in-use chunk whose user pointer is `ptr`.  The pointer is
// converted to the same encoding as MemoryChunk::ptr.  In DEBUG builds user
// pointers sit 16 bytes past the chunk start (after the guard header), so
// that offset is removed.  The ring is then walked once from realMem.
// Returns the matching node, or 0 when none matches, including when no heap
// was ever created (realMem == 0).
MemoryChunk *findChunk(void *ptr)
{
#ifdef PROTECT
	unsigned long addr = (unsigned long) ptr;
#else
	unsigned long addr = long(FP_SEG(ptr))*16 + FP_OFF(ptr);
#endif

#ifdef DEBUG
	addr -= 16;
#endif

	// no heap: nothing can have been allocated (and realMem must not be walked)
	if (realMem == 0)
		return 0;

	// find the in-use chunk that starts at addr
	MemoryChunk *m = realMem;
	do {
		if (m->inuse != 0 && m->ptr == addr)
			return m;
		m = m->next;
	} while (m != realMem);
	return 0;	// could not find
}

// insertChunk
//
// Link `leftover` into the circular, doubly-linked ring headed by `root`,
// keeping the ring in ascending `ptr` order starting at root.
// - If the ring is empty, `leftover` becomes a one-node ring.
// - If `leftover` has the lowest address, it goes after the tail
//   (root->prev) and becomes the new root.  Linking it after root instead
//   would leave the ring unsorted once it holds three or more nodes.
// - Otherwise it goes after the last node whose ptr <= leftover->ptr.
void insertChunk(MemoryChunk *& root, MemoryChunk *leftover)
{
	if (root == 0) {
		root = leftover;
		root->next = root;
		root->prev = root;
		return;
	}

	// find m such that m->ptr <= leftover->ptr < m->next->ptr,
	// treating the tail as the predecessor of a new lowest node
	MemoryChunk *m;
	if (leftover->ptr < root->ptr) {
		m = root->prev;
	}
	else {
		m = root;
		while (m->next != root && m->next->ptr <= leftover->ptr)
			m = m->next;
	}

	MemoryChunk *nextChunk = m->next;
	leftover->next = nextChunk;
	leftover->prev = m;
	m->next = leftover;
	nextChunk->prev = leftover;

	if (leftover->ptr < root->ptr)
		root = leftover;
}

// returnExcess
//
// Shrink chunk m to `resizeTo` bytes and turn the remainder into a new
// free chunk linked directly after m.  realMem is moved to that remainder,
// and the remainder is stamped by fillHeap().  Returns m, or 0 (with m
// unchanged) when no descriptor is available for the remainder.
MemoryChunk *returnExcess(MemoryChunk *m, unsigned long resizeTo)
{
	MemoryChunk *tail = getChunkNode(m->ptr + resizeTo, m->sz - resizeTo, 0);
	if (tail == 0)
		return 0;	// couldn't allocate a descriptor

	MemoryChunk *after = m->next;
	m->sz = resizeTo;
	tail->prev = m;
	tail->next = after;
	after->prev = tail;
	m->next = tail;

	realMem = tail;
	fillHeap(tail);
	return m;
}

#ifdef PROTECT
// CreateChunkNodes (protected mode)
//
// Allocate one DPMI data segment large enough for INITIAL_CHUNK_NODES
// descriptors, with the size rounded up to a 16-byte multiple.  Every
// descriptor in it is pushed onto freeNodes.  Any descriptors still on
// freeNodes are dropped first.  On DPMI failure the list is left empty.
void CreateChunkNodes()
{
	freeNodes = 0;

	unsigned nodeSz = INITIAL_CHUNK_NODES * sizeof(MemoryChunk);
	nodeSz = (nodeSz + 15) & ~15u;

	unsigned selector;
	if (Dpmi.getMappedMemory(nodeSz, selector, AccessRightsData) != DPMI_OK) {
		mprintf("cannot get memory for MemoryChunks\n");
		return;
	}

	mprintf("freeNodes at(%04x,%04x) for %04x\r\n", selector, 0, nodeSz);

	MemoryChunk *pool = (MemoryChunk *) MK_FP(selector, 0);
	for (int k = 0; k < INITIAL_CHUNK_NODES; ++k)
		freeChunkNode(pool + k);
}

// CreateHeap (protected mode)
//
// Startup routine (registered by the "#pragma startup CreateHeap" line
// below the real-mode variant).
// 1. Builds the descriptor pool.
// 2. Requests _protheaplen from DPMI in pieces of at most 0x10000, one
//    selector per piece.
// 3. Links each piece into the ring as a single free chunk whose ptr is
//    selector<<16.
// If DPMI or the descriptor pool runs out part-way, the pieces already
// obtained are kept rather than thrown away.  realMem always ends up equal
// to rootMem; it is 0 only if nothing was obtained, in which case
// farmalloc() returns 0.
void CreateHeap(void)
{
	mprintf("create heap\n");
	freeNodes = 0;
	CreateChunkNodes();
	realMem = 0;
	while (_protheaplen > 0) {
		unsigned selector;
		long sz = 0x10000l;
		if (sz > _protheaplen) sz = _protheaplen;
		if (Dpmi.getMappedMemory(sz, selector, AccessRightsData) != DPMI_OK) {
			mprintf("cannot get memory\n");
			break;		// keep the pieces already linked in
		}
		MemoryChunk *chunk = getChunkNode(long(selector) << 16, sz, 0);
		if (!chunk) {
			// Out of descriptors: the selector just obtained stays
			// allocated but unused (TODO: release it through Dpmi).
			mprintf("cannot get descriptor for heap memory\n");
			break;
		}
		fillHeap(chunk);
		insertChunk(rootMem, chunk);
		memoryLeft += sz;
		_protheaplen -= sz;
	}
	realMem = rootMem;
}
#else
// CreateChunkNodes (real mode)
//
// Does nothing.  In real mode the descriptor pool is carved out of the heap
// block once, inside CreateHeap(), and is never refilled (presumably by
// design).  When the pool is exhausted, getChunkNode() returns 0 and
// farmalloc() reports "no more descriptors".
void CreateChunkNodes(void)
{
}

void CreateHeap(void)
{
	unsigned segp;
	unsigned sz = 64000;
	if (_dos_allocmem(sz, &segp)!=0) {
		sz = segp;
		if (_dos_allocmem(sz, &segp)!=0)
			return;
	}
    
	{
		freeNodes = 0;
		unsigned nodeSz = INITIAL_CHUNK_NODES * sizeof(MemoryChunk);
		nodeSz = (nodeSz+15);
		nodeSz = nodeSz - (nodeSz&15);

		mprintf("freeNodes at(%04x,%04x) for %04x\r\n", segp,0, nodeSz);

		MemoryChunk *freeArray = (MemoryChunk *) MK_FP(segp, 0);
		segp += (nodeSz>>4);
		sz -= (nodeSz>>4);
		int i;
		for (i=0;i<INITIAL_CHUNK_NODES;i++)
			freeChunkNode( &freeArray[i] );
	}

	realMem = getChunkNode( long(segp)<<4, long(sz)<<4, 0);
    rootMem = realMem;
	memoryLeft += (long(sz)<<4);
    fillHeap(rootMem);
    checkHeap("CreatHeap");
}
#endif

#pragma startup CreateHeap 4


// roundUp
//
// Size of the chunk needed for a request of `sz` bytes: rounded up to a
// multiple of 16.  In DEBUG builds another 16 bytes are added first, for
// the guard header placed in front of the user pointer.
unsigned long roundUp(unsigned long sz)
{
#ifdef DEBUG
	const unsigned long guard = 16;
#else
	const unsigned long guard = 0;
#endif
	return (sz + 15 + guard) & ~15UL;
}

// farmalloc
//
// Allocate `sz` bytes.
// - The ring is scanned from realMem, and the first free chunk of at least
//   roundUp(sz) bytes is taken.
// - A larger chunk is split, and its tail is returned to the ring by
//   returnExcess(), which also moves realMem to that tail.
// Returns 0 when:
// - the heap was never created;
// - the request is too large to round;
// - no free chunk is big enough;
// - no descriptor is left for the split.
//
// A zero-byte request is served as a one-byte request.  A zero-sized chunk
// would share its address with the free remainder split off behind it, so
// a later allocation could return the same pointer.
void *farmalloc(unsigned long sz)
{
	if (!realMem) return 0;
	mprintf("malloc(%04lx)\r\n", sz);
	unsigned long roundedSz = roundUp(sz ? sz : 1);
	if (roundedSz < sz)
		return 0;	// roundUp() wrapped: request is near ULONG_MAX

	// find a free memory chunk
	MemoryChunk *m = realMem;
	for (;;) {
		if (m->inuse==0 && m->sz>=roundedSz)
			break;
		m = m->next;
		if (m==realMem) {
			mprintf("malloc(%04lx) no more memory %04lx left\r\n", sz, memoryLeft);
			return 0;	// could not allocate
		}
	}

	if (roundedSz<m->sz) {	// get remainder node
		if (!returnExcess(m, roundedSz)) {
			mprintf("malloc(%04lx) no more descriptors\r\n", sz);
			return 0;
		}
	}

	// note that all pointers have a zero offset.
	m->inuse = 1;
	memoryLeft -= roundedSz;
	mprintf("malloc(%04lx, %04lx <- %04lx) %04lx left\r\n", m->ptr, roundedSz, sz, memoryLeft);
#ifdef DEBUG
	fillHeap(m);
	checkHeap("after malloc");
	return makeSegPtr(m->ptr+16);	// user data starts after the 16-byte guard
#else
	return makeSegPtr(m->ptr);
#endif
}

// malloc: C library entry point; all the work is done by farmalloc().
void *malloc(size_t sz)
{
	void *block = farmalloc(sz);
	return block;
}

void *farrealloc(void *p, unsigned long sz)
{
	if (!p) return farmalloc(sz);
	unsigned long roundSz = roundUp(sz);
	MemoryChunk *oldChunk = findChunk(p);
	if (oldChunk->sz==roundSz)
		return p;
#ifdef BAD
	if (oldChunk->sz<roundSz) {	// shrink region
		if (!returnExcess(oldChunk, roundSz)) {
			mprintf("farrealloc(%04x) no more descriptors\r\n", sz);
			return 0;
		}
		return p;
	}
#endif
	void *newRegion = farmalloc(sz);
	if (!newRegion) return 0;
	memcpy(newRegion, p, sz);
	farfree(p);
	return newRegion;
}

// realloc: C library entry point; all the work is done by farrealloc().
void *realloc(void *p, size_t sz)
{
	void *resized = farrealloc(p, sz);
	return resized;
}

// traceback
//
// Debug aid, called by farfree() on a bad free.  Walks up to five frames of
// the BP chain in the stack segment.  For each frame it prints, through
// mprintf (so output appears only with DEBUG):
// - the saved BP;
// - the return address p[2]:p[1];
// - that segment relative to _psp + 0x10 (presumably the program's load
//   segment);
// - the first five words of the frame.
// The walk stops early once the saved BP no longer moves up the stack.
// In protected mode the body compiles away.
// NOTE(review): assumes far-call frames laid out as [BP]=saved BP,
// [BP+2]=return IP, [BP+4]=return CS - confirm for the memory model used.
#pragma argsused
void traceback(unsigned bp)
{
#ifndef PROTECT
	int i;
	extern unsigned _psp;

	mprintf("    cs=%04x ss=%04x psp=%04x\r\n", _CS, _SS, _psp );
	for (i=0;i<5;i++) {
		unsigned *p = (unsigned *) MK_FP(_SS,bp);
		mprintf("    return to bp=%04x , pc=%04x:%04x (%04x:%04x rel)\r\n        ", p[0], p[2], p[1], p[2]-_psp-0x10, p[1] );
		int j;
		for (j=0;j<5;j++)
			mprintf(" %04x", p[j]);
		mprintf("\r\n");
		if (p[0]<=bp) break;	// chain no longer ascending: stop
		bp = p[0];
	}
#endif
}

// mergeChunks
//
// If m and the chunk after it in the ring are both free, and the second
// begins exactly where m ends, absorb the second into m and return its
// descriptor to freeNodes.  Returns 1 when a merge happened, else 0.
// If realMem or rootMem referred to the absorbed node, it is moved to m so
// neither is left pointing at a recycled descriptor.
int mergeChunks(MemoryChunk *m)
{
	MemoryChunk *nextChunk = m->next;
	// one-node ring: m->next is m itself; "merging" would recycle m
	// while it is still in the ring
	if (nextChunk == m)
		return 0;

	// combine with upper region
	if (m->inuse==0 && nextChunk->inuse==0 && nextChunk->ptr == m->ptr+m->sz) {
		m->next = nextChunk->next;
		m->next->prev = m;
		m->sz += nextChunk->sz;
		if (nextChunk==realMem) realMem = m;
		if (nextChunk==rootMem) rootMem = m;
		freeChunkNode(nextChunk);
		return 1;
	}
	return 0;
}

void farfree(void far *ptr)
{
	if (!ptr) return;

	MemoryChunk *m = findChunk(ptr);
	if (!m) {
		mprintf("free(%04x:%04x) block not allocated free\r\n", FP_SEG(ptr), FP_OFF(ptr) );
		traceback(_BP);
		return;
	}
	m->inuse = 0;
	unsigned long sz = m->sz;
	memoryLeft += sz;

    MemoryChunk *prev = m->prev;
    mergeChunks(m);
	if (mergeChunks(prev)) m = prev;
#ifdef DEBUG
    fillHeap(m);
    checkHeap("after free");
//	mprintf("free(%04x:%04x,%04lx)\r\n", FP_SEG(ptr), FP_OFF(ptr), sz );
#endif
}

// free: C library entry point.  farfree() does the work, including
// ignoring a 0 pointer.
void free(void *p)
{
	farfree(p);
	return;
}

int farheapcheck1(void)
{
    MemoryChunk *p = rootMem;
    for (;;) {
        MemoryChunk *next = p->next;
#ifndef PROTECT
        if (next!=rootMem) {
            unsigned long nextptr = p->ptr + p->sz;
            if (nextptr!=next->ptr) {
                mprintf("MemoryChunk list corrupt! %04lx + %04lx != %04lx %04lx\r\n",
                        p->ptr, p->sz, nextptr, next->ptr);
                return _HEAPCORRUPT;
            }
        }
#endif

#ifdef DEBUG
        unsigned MHUGE *q = (unsigned MHUGE *) makeSegPtr(p->ptr);
        unsigned long i;
        for (i=0;i<8;i++)
            if ((*q++) != 0xbad) {
                mprintf("Memory at %06lx corrupt\r\n", p->ptr);
                return _HEAPCORRUPT;
            }
#ifdef BRUTAL
        if (p->inuse==0)
            for (i=8;i<p->sz/2;i++)
                if ((*q++) != 0xbad) {
                    mprintf("Memory at %06lx corrupt\r\n", p->ptr);
                    return _HEAPCORRUPT;
                }
#endif
#endif
        p = next;
        if (p==rootMem) break;
    }
	return _HEAPOK;
}

// farheapcheck: stub that always reports _HEAPOK.  The actual ring walk
// lives in farheapcheck1(), which is only reached through checkHeap().
int farheapcheck(void)
{
	const int status = _HEAPOK;
	return status;
}

// heapcheck: near-heap entry point; reports whatever farheapcheck() reports.
int heapcheck(void)
{
	int status = farheapcheck();
	return status;
}

// coreleft: bytes currently free in the heap (the memoryLeft counter).
unsigned long coreleft()
{
	unsigned long freeBytes = memoryLeft;
	return freeBytes;
}

// farcoreleft: far-heap flavour of coreleft(); reports the same counter.
unsigned long farcoreleft()
{
	unsigned long freeBytes = memoryLeft;
	return freeBytes;
}

// checkHeap
//
// Run farheapcheck1().  On failure, print a message naming `where` and
// count it in heapErrors.  After 20 reports, further checks are skipped
// to avoid flooding the console.
void checkHeap(char *where)
{
	if (heapErrors >= 20)
		return;
	if (farheapcheck1() == _HEAPOK)
		return;
	mprintf("heap corrupt from %s\n", where);
	++heapErrors;
}


