//
// Simple driver that demonstrates mapping physical memory into
//   a user mode process's address space
//

#include "ntddk.h"
#include "..\mapmem.h"
#include "stdarg.h"



//
// The following was not included in the October release of the
//   preliminary NT DDK, and is necessary when utilizing the
//   ZwMapViewOfSection API.
//

//
// Access-right bits for section objects. This file passes
// SECTION_ALL_ACCESS as the desired access when it opens the
// physical-memory section.
//
// NOTE(review): newer DDK headers may already define these names; if
// so, these local definitions become redundant or conflicting - confirm
// against the headers actually in use.
//

#define SECTION_QUERY       0x0001
#define SECTION_MAP_WRITE   0x0002
#define SECTION_MAP_READ    0x0004
#define SECTION_MAP_EXECUTE 0x0008
#define SECTION_EXTEND_SIZE 0x0010

//
// Union of the standard required rights and every section-specific
// right listed above.
//

#define SECTION_ALL_ACCESS (STANDARD_RIGHTS_REQUIRED|SECTION_QUERY | \
                            SECTION_MAP_WRITE                      | \
                            SECTION_MAP_READ                       | \
                            SECTION_MAP_EXECUTE                    | \
                            SECTION_EXTEND_SIZE)

//
// Value type of the InheritDisposition argument of ZwMapViewOfSection.
// This file only ever passes ViewShare.
//

typedef enum _SECTION_INHERIT {
    ViewShare = 1,
    ViewUnmap = 2
} SECTION_INHERIT;

//
// Page protection flags, used as the Protect argument of
// ZwMapViewOfSection. This file maps with PAGE_READWRITE | PAGE_NOCACHE.
//

#define PAGE_NOACCESS          0x01     // winnt
#define PAGE_READONLY          0x02     // winnt
#define PAGE_READWRITE         0x04     // winnt
#define PAGE_WRITECOPY         0x08
#define PAGE_EXECUTE           0x10
#define PAGE_EXECUTE_READ      0x20
#define PAGE_EXECUTE_READWRITE 0x40
#define PAGE_EXECUTE_WRITECOPY 0x80
#define PAGE_GUARD            0x100
#define PAGE_NOCACHE          0x200

//
// Local prototype for ZwMapViewOfSection, declared here because (per the
// note at the top of this file) the DDK headers in use did not provide it.
//
// BaseAddress, SectionOffset and ViewSize are in/out: the call in this
// file reads the chosen virtual address and the actual section offset
// back from them after mapping.
//
// NOTE(review): later DDK releases declare this routine themselves, with
// pointer-sized types for some of the ULONG parameters - confirm this
// prototype still matches before building against newer headers.
//

NTSTATUS
ZwMapViewOfSection(
    IN HANDLE SectionHandle,
    IN HANDLE ProcessHandle,
    IN OUT PVOID *BaseAddress,
    IN ULONG ZeroBits,
    IN ULONG CommitSize,
    IN OUT PLARGE_INTEGER SectionOffset OPTIONAL,
    IN OUT PULONG ViewSize,
    IN SECTION_INHERIT InheritDisposition,
    IN ULONG AllocationType,
    IN ULONG Protect
    );


//
// Function prototypes for this module
//

//
// Dispatch routine registered for both IRP_MJ_CREATE and IRP_MJ_CLOSE;
// completes the IRP with STATUS_SUCCESS.
//

NTSTATUS
MapMemCreateClose(
    IN PDEVICE_OBJECT DeviceObject,
    IN PIRP Irp
    );

//
// Dispatch routine registered for IRP_MJ_DEVICE_CONTROL; handles
// IOCTL_MAPMEM_MAP_USER_PHYSICAL_MEMORY and fails everything else with
// STATUS_INVALID_PARAMETER.
//

NTSTATUS
MapMemDispatch(
    IN PDEVICE_OBJECT DeviceObject,
    IN PIRP Irp
    );

//
// Worker for the map IOCTL. ioBuffer carries a PHYSICAL_MEMORY_INFO on
// input and receives a single PVOID on output; the two lengths are the
// sizes of the caller's input and output buffers in bytes.
//

NTSTATUS
MapMemMapTheMemory(
    IN PDEVICE_OBJECT DeviceObject,
    IN OUT PVOID      ioBuffer,
    IN ULONG          inputBufferLength,
    IN ULONG          outputBufferLength
    );

//
// Debug-only trace macro. The argument must be a complete, parenthesized
// DbgPrint argument list, i.e. it is invoked with double parentheses:
//
//     MapMemDebugPrint(("fmt %x\n", value));
//
// When DBG is not set the macro expands to nothing, so the arguments
// are not evaluated and must not have side effects.
//

#if DBG
#define MapMemDebugPrint(arg) DbgPrint arg
#else
#define MapMemDebugPrint(arg)
#endif



//
// The code...
//

//
// DriverEntry
//
// Driver initialization routine. Creates the device object named
// \Device\MapMem and, if that succeeds, installs the dispatch routines
// for create, close and device-control requests.
//
// Parameters:
//     DriverObject - driver object whose MajorFunction table is filled in
//     RegistryPath - not used by this routine
//
// Returns:
//     The status returned by IoCreateDevice.
//

NTSTATUS
DriverEntry(
    IN PDRIVER_OBJECT DriverObject,
    IN PUNICODE_STRING RegistryPath
    )
{
    WCHAR          nameBuffer[] = L"\\Device\\MapMem";
    UNICODE_STRING deviceName;
    PDEVICE_OBJECT createdDevice = NULL;
    NTSTATUS       status;

    MapMemDebugPrint(("MapMem: entering DriverEntry\n"));

    RtlInitUnicodeString(&deviceName, nameBuffer);

    //
    // Create the device object: no device extension, unknown device
    // type, no special characteristics, exclusive access.
    //

    status = IoCreateDevice(DriverObject,
                            0,
                            &deviceName,
                            FILE_DEVICE_UNKNOWN,
                            0,
                            TRUE,
                            &createdDevice);

    if (!NT_SUCCESS(status))
    {
        //
        // Without a device there is nothing to dispatch to; leave the
        // MajorFunction table untouched.
        //

        return status;
    }

    //
    // Create and close share a single handler; device control has
    // its own.
    //

    DriverObject->MajorFunction[IRP_MJ_CLOSE]          = MapMemCreateClose;
    DriverObject->MajorFunction[IRP_MJ_CREATE]         = MapMemCreateClose;
    DriverObject->MajorFunction[IRP_MJ_DEVICE_CONTROL] = MapMemDispatch;

    return status;
}



//
// MapMemCreateClose
//
// Handles IRP_MJ_CREATE and IRP_MJ_CLOSE. Neither request needs any
// work, so the IRP is completed immediately with STATUS_SUCCESS and
// zero bytes of information.
//
// Parameters:
//     DeviceObject - not used by this routine
//     Irp          - the create or close request to complete
//
// Returns:
//     STATUS_SUCCESS, always.
//

NTSTATUS
MapMemCreateClose(
    IN PDEVICE_OBJECT DeviceObject,
    IN PIRP Irp
    )
{
    MapMemDebugPrint(("MapMemOpenClose: enter\n"));

    //
    // The completion status must be stored in the IRP before it is
    // completed.
    //

    Irp->IoStatus.Information = 0;
    Irp->IoStatus.Status      = STATUS_SUCCESS;

    IoCompleteRequest(Irp, IO_NO_INCREMENT);

    //
    // Return a constant rather than reading the status back out of the
    // IRP: the IRP must not be touched once it has been completed.
    //

    return STATUS_SUCCESS;
}



//
// MapMemDispatch
//
// Handles IRP_MJ_DEVICE_CONTROL. The only supported control code is
// IOCTL_MAPMEM_MAP_USER_PHYSICAL_MEMORY, which is forwarded to
// MapMemMapTheMemory; any other request is completed with
// STATUS_INVALID_PARAMETER.
//
// Parameters:
//     DeviceObject - passed through to MapMemMapTheMemory
//     Irp          - the request; its system buffer is used for both
//                    input and output
//
// Returns:
//     The status the IRP was completed with. The request is never
//     left pending.
//

NTSTATUS
MapMemDispatch(
    IN PDEVICE_OBJECT DeviceObject,
    IN PIRP Irp
    )
{
    PIO_STACK_LOCATION stack;
    NTSTATUS           status;

    //
    // Assume failure; only the one recognized IOCTL overrides this.
    //

    Irp->IoStatus.Information = 0;
    Irp->IoStatus.Status      = STATUS_INVALID_PARAMETER;

    //
    // The current stack location holds the function code and the
    // device-control parameters (control code and buffer lengths).
    //

    stack = IoGetCurrentIrpStackLocation(Irp);

    if (stack->MajorFunction == IRP_MJ_DEVICE_CONTROL)
    {
        if (stack->Parameters.DeviceIoControl.IoControlCode ==
                IOCTL_MAPMEM_MAP_USER_PHYSICAL_MEMORY)
        {
            status = MapMemMapTheMemory(
                         DeviceObject,
                         Irp->AssociatedIrp.SystemBuffer,
                         stack->Parameters.DeviceIoControl.InputBufferLength,
                         stack->Parameters.DeviceIoControl.OutputBufferLength);

            Irp->IoStatus.Status = status;

            if (NT_SUCCESS(status))
            {
                //
                // Report sizeof(PVOID) bytes of output so the caller
                // knows a valid pointer is being returned.
                //

                Irp->IoStatus.Information = sizeof(PVOID);

                MapMemDebugPrint(("MapMem: memory successfully mapped :)\n"));
            }
            else
            {
                MapMemDebugPrint(("MapMem: memory map failed :(\n"));
            }
        }
        else
        {
            MapMemDebugPrint(("MapMem: unknown IRP_MJ_DEVICE_CONTROL\n"));
        }
    }

    //
    // Capture the status before completing the request: the IRP must
    // not be touched after IoCompleteRequest returns.
    //

    status = Irp->IoStatus.Status;

    IoCompleteRequest(Irp, IO_NO_INCREMENT);

    return status;
}



//
// MapMemMapTheMemory
//
// Maps a range of bus-relative physical memory into the address space
// of the calling process through the \Device\PhysicalMemory section.
//
// Parameters:
//     DeviceObject       - not used by this routine
//     ioBuffer           - on input, a PHYSICAL_MEMORY_INFO describing
//                          the interface type, bus number, bus-relative
//                          address and length to map; on successful
//                          return, its first sizeof(PVOID) bytes hold
//                          the resulting address
//     inputBufferLength  - size of the input data in ioBuffer, in bytes
//     outputBufferLength - size available for output in ioBuffer, in bytes
//
// Returns:
//     STATUS_SUCCESS                - the address was written to ioBuffer
//     STATUS_INSUFFICIENT_RESOURCES - input or output buffer too small
//                                     (status kept as-is because existing
//                                     callers may test for it)
//     STATUS_UNSUCCESSFUL           - bus address translation failed or
//                                     the translated range is unusable
//     otherwise the failure status of ZwOpenSection or ZwMapViewOfSection
//

NTSTATUS
MapMemMapTheMemory(
    IN PDEVICE_OBJECT DeviceObject,
    IN OUT PVOID      ioBuffer,
    IN ULONG          inputBufferLength,
    IN ULONG          outputBufferLength
    )
{
    PPHYSICAL_MEMORY_INFO ppmi = (PPHYSICAL_MEMORY_INFO) ioBuffer;

    INTERFACE_TYPE     interfaceType;
    ULONG              busNumber;
    PHYSICAL_ADDRESS   physicalAddress;
    ULONG              length;
    UNICODE_STRING     physicalMemoryUnicodeString;
    OBJECT_ATTRIBUTES  objectAttributes;
    HANDLE             physicalMemoryHandle  = NULL;
    ULONG              inIoSpace, inIoSpace2;
    NTSTATUS           ntStatus;
    PHYSICAL_ADDRESS   physicalAddressBase;
    PHYSICAL_ADDRESS   physicalAddressEnd;
    PHYSICAL_ADDRESS   viewBase;
    PHYSICAL_ADDRESS   mappedLength;
    BOOLEAN            translateBaseAddress;
    BOOLEAN            translateEndAddress;
    PVOID              virtualAddress;



    if ( ( inputBufferLength  < sizeof (PHYSICAL_MEMORY_INFO) ) ||
         ( outputBufferLength < sizeof (PVOID) ) )
    {
       MapMemDebugPrint(("MapMem: Insufficient input or output buffer\n"));

       ntStatus = STATUS_INSUFFICIENT_RESOURCES;

       goto done;
    }

    //
    // Copy the request out of the shared buffer first; the same buffer
    // is overwritten with the result further down.
    //

    interfaceType   = ppmi->interfaceType;
    busNumber       = ppmi->busNumber;
    physicalAddress = ppmi->physicalAddress;
    length          = ppmi->length;


    //
    // Open a handle to the physical memory section:
    //
    // - Create the name
    // - Initialize the data to find the object
    // - Open a handle to the object and check the status
    //
    // The handle is all that ZwMapViewOfSection needs. No pointer
    // reference is taken on the section object: nothing here uses the
    // object pointer, and an unreleased reference would leak on every
    // call.
    //

    RtlInitUnicodeString(&physicalMemoryUnicodeString,
                         L"\\Device\\PhysicalMemory");

    InitializeObjectAttributes (&objectAttributes,
                                &physicalMemoryUnicodeString,
                                OBJ_CASE_INSENSITIVE,
                                (HANDLE) NULL,
                                (PSECURITY_DESCRIPTOR) NULL);

    ntStatus = ZwOpenSection (&physicalMemoryHandle,
                              SECTION_ALL_ACCESS,
                              &objectAttributes);

    if (!NT_SUCCESS(ntStatus))
    {
        MapMemDebugPrint(("MapMem: ZwOpenSection failed\n"));

        goto done;
    }

    //
    // Have a second variable used for the second HalTranslateBusAddress
    // call.
    //

    inIoSpace = inIoSpace2 = 0;


    //
    // Initialize the physical addresses that will be translated
    //

    physicalAddressEnd = RtlLargeIntegerAdd (physicalAddress,
                                             RtlConvertUlongToLargeInteger(
                                                 length));

    //
    // Translate the physical addresses.
    //

    translateBaseAddress =
        HalTranslateBusAddress (interfaceType,
                                busNumber,
                                physicalAddress,
                                &inIoSpace,
                                &physicalAddressBase);

    translateEndAddress =
        HalTranslateBusAddress (interfaceType,
                                busNumber,
                                physicalAddressEnd,
                                &inIoSpace2,
                                &physicalAddressEnd);

    if ( !(translateBaseAddress && translateEndAddress) )
    {
        MapMemDebugPrint(("MapMem: HalTranslateBusAddress failed\n"));

        ntStatus = STATUS_UNSUCCESSFUL;

        goto close_handle;
    }

    //
    // Calculate the length of the memory to be mapped
    //

    mappedLength = RtlLargeIntegerSubtract (physicalAddressEnd,
                                            physicalAddressBase);


    //
    // Only the low 32 bits of the length are used below. A non-zero
    // high part means the translated end lies below the translated
    // base, or the span does not fit in a ULONG; either way the low
    // part alone would describe the wrong range, so refuse to map it.
    //

    if (mappedLength.HighPart != 0)
    {
        MapMemDebugPrint(("MapMem: mappedLength.HighPart != 0\n"));

        ntStatus = STATUS_UNSUCCESSFUL;

        goto close_handle;
    }

    //
    // A zero mapped length means the caller asked for zero bytes (the
    // requested length is not validated before this point) or the two
    // translations collapsed onto the same address.
    //

    if (mappedLength.LowPart == 0)
    {
        MapMemDebugPrint(("MapMem: mappedLength.LowPart == 0\n"));

        ntStatus = STATUS_UNSUCCESSFUL;

        goto close_handle;
    }

    length = mappedLength.LowPart;


    //
    // If the address is in io space, just return the address, otherwise
    // go through the mapping mechanism
    //

    if (inIoSpace)
    {
        *((PVOID *) ioBuffer) = (PVOID) physicalAddressBase.LowPart;
    }

    else
    {
        //
        // initialize view base that will receive the physical mapped
        // address after the MapViewOfSection call.
        //

        viewBase = physicalAddressBase;


        //
        // Let ZwMapViewOfSection pick an address
        //

        virtualAddress = NULL;



        //
        // Map the section into the current process ((HANDLE) -1 is the
        // current-process pseudo-handle).
        //

        ntStatus = ZwMapViewOfSection (physicalMemoryHandle,
                                       (HANDLE) -1,
                                       &virtualAddress,
                                       0L,
                                       length,
                                       &viewBase,
                                       &length,
                                       ViewShare,
                                       0,
                                       PAGE_READWRITE | PAGE_NOCACHE);

        if (!NT_SUCCESS(ntStatus))
        {
            MapMemDebugPrint(("MapMem: ZwMapViewOfSection failed\n"));

            goto close_handle;
        }

        //
        // Mapping the section above rounded the physical address down to the
        // nearest 64 K boundary. Now return a virtual address that sits where
        // we want by adding in the offset from the beginning of the section.
        //
        // The adjustment is done with ordinary pointer arithmetic; a cast
        // expression is not an lvalue in standard C, so it cannot be the
        // target of "+=".
        //

        virtualAddress = (PVOID) ((char *) virtualAddress +
                                  (physicalAddressBase.LowPart -
                                   viewBase.LowPart));

        *((PVOID *) ioBuffer) = virtualAddress;

    }

    ntStatus = STATUS_SUCCESS;



close_handle:

    //
    // The view, if one was mapped, stays valid after the section handle
    // is closed.
    //

    ZwClose(physicalMemoryHandle);



done:

    MapMemDebugPrint(("\t ntStatus = %x\n", ntStatus));
    return ntStatus;
}
