/************************************************************************************

Filename    :   CAPI_D3D1X_Util.cpp
Content     :   D3DX10 utility classes for rendering
Created     :   September 10, 2012
Authors     :   Andrew Reisse

Copyright   :   Copyright 2014 Oculus VR, LLC All Rights reserved.

Licensed under the Oculus VR Rift SDK License Version 3.2 (the "License"); 
you may not use the Oculus VR Rift SDK except in compliance with the License, 
which is provided at the time of installation or download, or which 
otherwise accompanies this software in either electronic or hard copy form.

You may obtain a copy of the License at

http://www.oculusvr.com/licenses/LICENSE-3.2 

Unless required by applicable law or agreed to in writing, the Oculus VR SDK 
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

************************************************************************************/

#include "CAPI_D3D1X_Util.h"

#include <d3dcompiler.h>

namespace OVR { namespace CAPI { namespace D3D_NS {


//-------------------------------------------------------------------------------------
// ***** ShaderFill

// Activates this fill: first sets the shader set for the given primitive type,
// then binds every non-NULL texture of the vertex, compute and fragment stages
// (in that order) to the slot matching its array index.
void ShaderFill::Set(PrimitiveType prim) const
{
    // Number of texture slots walked per stage.
    const int slotCount = 8;

    Shaders->Set(prim);

    // Vertex-stage textures.
    for (int slot = 0; slot < slotCount; slot++)
    {
        if (VsTextures[slot] != NULL)
            VsTextures[slot]->Set(slot, Shader_Vertex);
    }

    // Compute-stage textures.
    for (int slot = 0; slot < slotCount; slot++)
    {
        if (CsTextures[slot] != NULL)
            CsTextures[slot]->Set(slot, Shader_Compute);
    }

    // Fragment (pixel) stage textures.
    for (int slot = 0; slot < slotCount; slot++)
    {
        if (PsTextures[slot] != NULL)
            PsTextures[slot]->Set(slot, Shader_Fragment);
    }
}


//-------------------------------------------------------------------------------------
// ***** Buffer

// Destructor. The body is intentionally empty: no explicit cleanup is done here.
// NOTE(review): the D3D resources are presumably released by their smart-pointer
// members - confirm against the class declaration.
Buffer::~Buffer()
{
}

bool Buffer::Data(int use, const void *buffer, size_t size, int computeBufferStride /*=-1*/)
{
    if (D3DBuffer && Size >= size)
    {
        if (Dynamic)
        {
            if (!buffer)
                return true;

            void* v = Map(0, size, Map_Discard);
            if (v)
            {
                memcpy(v, buffer, size);
                Unmap(v);
                return true;
            }
        }
        else
        {
            OVR_ASSERT (!(use & Buffer_ReadOnly));
            pParams->pContext->UpdateSubresource(D3DBuffer, 0, NULL, buffer, 0, 0);
            return true;
        }
    }
    if (D3DBuffer)
    {
        D3DBuffer = NULL;
        Size = 0;
        Use = 0;
        Dynamic = false;
    }
    D3DSrv = NULL;
#if (OVR_D3D_VERSION >= 11)
    D3DUav = NULL;
#endif

    D3D1X_(BUFFER_DESC) desc;
    memset(&desc, 0, sizeof(desc));
    if (use & Buffer_ReadOnly)
    {
        desc.Usage = D3D1X_(USAGE_IMMUTABLE);
        desc.CPUAccessFlags = 0;
    }
    else
    {
        desc.Usage = D3D1X_(USAGE_DYNAMIC);
        desc.CPUAccessFlags = D3D1X_(CPU_ACCESS_WRITE);
        Dynamic = true;
    }

    switch(use & Buffer_TypeMask)
    {
    case Buffer_Vertex:  desc.BindFlags = D3D1X_(BIND_VERTEX_BUFFER); break;
    case Buffer_Index:   desc.BindFlags = D3D1X_(BIND_INDEX_BUFFER);  break;
    case Buffer_Uniform:
        desc.BindFlags = D3D1X_(BIND_CONSTANT_BUFFER);
        size = ((size + 15) & ~15);
        break;
    case Buffer_Compute:
#if (OVR_D3D_VERSION >= 11)
        // There's actually a bunch of options for buffers bound to a CS.
        // Right now this is the most appropriate general-purpose one. Add more as needed.

        // NOTE - if you want D3D1X_(CPU_ACCESS_WRITE), it MUST be either D3D1X_(USAGE_DYNAMIC) or D3D1X_(USAGE_STAGING).
        // TODO: we want a resource that is rarely written to, in which case we'd need two surfaces - one a STAGING
        // that the CPU writes to, and one a DEFAULT, and we CopyResource from one to the other. Hassle!
        // Setting it as D3D1X_(USAGE_DYNAMIC) will get the job done for now.
        // Also for fun - you can't have a D3D1X_(USAGE_DYNAMIC) buffer that is also a D3D1X_(BIND_UNORDERED_ACCESS).
        OVR_ASSERT ( !(use & Buffer_ReadOnly) );
        desc.BindFlags = D3D1X_(BIND_SHADER_RESOURCE);
        desc.Usage     = D3D1X_(USAGE_DYNAMIC);
        desc.MiscFlags = D3D1X_(RESOURCE_MISC_BUFFER_STRUCTURED);
        desc.CPUAccessFlags = D3D1X_(CPU_ACCESS_WRITE);
        OVR_ASSERT ( computeBufferStride > 0 );
        desc.StructureByteStride = computeBufferStride; // sizeof(DistortionComputePin);

        Dynamic = true;
        size = ((size + 15) & ~15);
#else
        OVR_UNUSED ( computeBufferStride );
        OVR_ASSERT ( false );  // No compute shaders in DX10
#endif
        break;
    }

    desc.ByteWidth = (unsigned)size;

    D3D1X_(SUBRESOURCE_DATA) sr;
    sr.pSysMem = buffer;
    sr.SysMemPitch = 0;
    sr.SysMemSlicePitch = 0;

    D3DBuffer = NULL;
    HRESULT hr = pParams->pDevice->CreateBuffer(&desc, buffer ? &sr : NULL, &D3DBuffer.GetRawRef());
    if (SUCCEEDED(hr))
    {
        Use = use;
        Size = desc.ByteWidth;
    }
    else
    {
        OVR_ASSERT ( false );
        return false;
    }

    if ( ( use & Buffer_TypeMask ) == Buffer_Compute )
    {
        HRESULT hres = pParams->pDevice->CreateShaderResourceView ( D3DBuffer, NULL, &D3DSrv.GetRawRef() );
        if ( SUCCEEDED(hres) )
        {
#if (OVR_D3D_VERSION >= 11)
#if 0           // Right now we do NOT ask for UAV access (see flags above).
            hres = Ren->Device->CreateUnorderedAccessView ( D3DBuffer, NULL, &D3DUav.GetRawRef() );
            if ( SUCCEEDED(hres) )
            {
                // All went well.
            }
#endif
#endif
        }

        if ( !SUCCEEDED(hres) )
        {
            OVR_ASSERT ( false );
            Use = 0;
            Size = 0;
            return false;
        }
    }

    return true;

}

void*  Buffer::Map(size_t start, size_t size, int flags)
{
    OVR_UNUSED(size);

    D3D1X_(MAP) mapFlags = D3D1X_(MAP_WRITE);
    if (flags & Map_Discard)    
        mapFlags = D3D1X_(MAP_WRITE_DISCARD);    
    if (flags & Map_Unsynchronized)    
        mapFlags = D3D1X_(MAP_WRITE_NO_OVERWRITE);

#if (OVR_D3D_VERSION == 10)
    void* map;
    if (SUCCEEDED(D3DBuffer->Map(mapFlags, 0, &map)))    
        return ((char*)map) + start;    
#else
    D3D11_MAPPED_SUBRESOURCE map;
    if (SUCCEEDED(pParams->pContext->Map(D3DBuffer, 0, mapFlags, 0, &map)))
        return ((char*)map.pData) + start;
#endif

    return NULL;
}

// Ends a mapping started with Map(). The pointer argument is not used; the whole
// buffer (subresource 0 on the context path) is unmapped. Always returns true.
bool   Buffer::Unmap(void *m)
{
    OVR_UNUSED(m);

    // First form unmaps via the buffer object, second via the device context;
    // presumably D3DSELECT_10_11 picks one by D3D version - confirm in the header.
    D3DSELECT_10_11( D3DBuffer->Unmap(),
                     pParams->pContext->Unmap(D3DBuffer, 0) );
    return true;
}


//-------------------------------------------------------------------------------------
// Shaders

// Creates the D3D vertex shader from `size` bytes of compiled shader code at
// `shader`, storing it in D3DShader. Returns true when creation succeeded.
template<> bool ShaderImpl<Shader_Vertex, ID3D1xVertexShader>::Load(void* shader, size_t size)
{
    const HRESULT result = pParams->pDevice->CreateVertexShader(shader, size  D3D11_COMMA_0, &D3DShader);
    return SUCCEEDED(result);
}
// Creates the D3D pixel shader from `size` bytes of compiled shader code at
// `shader`, storing it in D3DShader. Returns true when creation succeeded.
template<> bool ShaderImpl<Shader_Pixel, ID3D1xPixelShader>::Load(void* shader, size_t size)
{
    const HRESULT result = pParams->pDevice->CreatePixelShader(shader, size  D3D11_COMMA_0, &D3DShader);
    return SUCCEEDED(result);
}
#if (OVR_D3D_VERSION>=11)
// Creates the D3D compute shader from `size` bytes of compiled shader code at
// `shader`, storing it in D3DShader. Returns true when creation succeeded.
template<> bool ShaderImpl<Shader_Compute, ID3D1xComputeShader>::Load(void* shader, size_t size)
{
    const HRESULT result = pParams->pDevice->CreateComputeShader(shader, size  D3D11_COMMA_0, &D3DShader);
    return SUCCEEDED(result);
}
#endif

// Makes this shader the current vertex shader on the device context.
// The primitive type argument is ignored.
template<> void ShaderImpl<Shader_Vertex, ID3D1xVertexShader>::Set(PrimitiveType) const
{
    pParams->pContext->VSSetShader(D3DShader D3D11_COMMA_0 D3D11_COMMA_0 );
}
// Makes this shader the current pixel shader on the device context.
// The primitive type argument is ignored.
template<> void ShaderImpl<Shader_Pixel, ID3D1xPixelShader>::Set(PrimitiveType) const
{
    pParams->pContext->PSSetShader(D3DShader D3D11_COMMA_0 D3D11_COMMA_0 ) ;
}
#if (OVR_D3D_VERSION>=11)
// Makes this shader the current compute shader on the device context.
// The primitive type argument is ignored.
template<> void ShaderImpl<Shader_Compute, ID3D1xComputeShader>::Set(PrimitiveType) const
{
    pParams->pContext->CSSetShader(D3DShader D3D11_COMMA_0 D3D11_COMMA_0 ) ;
}
#endif

// Binds the D3D buffer owned by `buffer` as the single vertex-stage constant
// buffer at slot `i`. `buffer` must not be NULL.
template<> void ShaderImpl<Shader_Vertex, ID3D1xVertexShader>::SetUniformBuffer(Buffer* buffer, int i)
{
    // One buffer is bound, starting at slot i.
    pParams->pContext->VSSetConstantBuffers(i, 1, &buffer->D3DBuffer.GetRawRef());
}
// Binds the D3D buffer owned by `buffer` as the single pixel-stage constant
// buffer at slot `i`. `buffer` must not be NULL.
template<> void ShaderImpl<Shader_Pixel, ID3D1xPixelShader>::SetUniformBuffer(Buffer* buffer, int i)
{
    // One buffer is bound, starting at slot i.
    pParams->pContext->PSSetConstantBuffers(i, 1, &buffer->D3DBuffer.GetRawRef());
}
#if (OVR_D3D_VERSION>=11)
// Binds the D3D buffer owned by `buffer` as the single compute-stage constant
// buffer at slot `i`. `buffer` must not be NULL.
template<> void ShaderImpl<Shader_Compute, ID3D1xComputeShader>::SetUniformBuffer(Buffer* buffer, int i)
{
    // One buffer is bound, starting at slot i.
    pParams->pContext->CSSetConstantBuffers(i, 1, &buffer->D3DBuffer.GetRawRef());
}
#endif

//-------------------------------------------------------------------------------------
// ***** Shader Base

// Constructs a shader for the given stage, remembering the render parameters
// pointer. The shader starts with no uniform storage and no reflection table;
// both are set up later by InitUniforms().
ShaderBase::ShaderBase(RenderParams* rp, ShaderStage stage) :
    Shader(stage),
    pParams(rp),
    UniformData(NULL),
    UniformsSize(0),
    UniformRefl(NULL),
    UniformReflSize(0)
{
}

// Releases the CPU-side uniform storage. The reflection table is only
// referenced, never owned, so it is just forgotten.
ShaderBase::~ShaderBase()
{
    if (UniformData != NULL)
    {
        OVR_FREE(UniformData);
    }
    UniformData = NULL;

    // UniformRefl does not need to be freed
    UniformRefl = NULL;
}

bool ShaderBase::SetUniform(const char* name, int n, const float* v)
{
    for(unsigned i = 0; i < UniformReflSize; i++)
    {
        if (!strcmp(UniformRefl[i].Name, name))
        {
            memcpy(UniformData + UniformRefl[i].Offset, v, n * sizeof(float));
            return 1;
        }
    }
    return 0;
}

// Writes the uniform called `name` by copying the uniform's full reflected Size
// (in bytes) from `v` into the CPU-side uniform block. `n` is ignored.
// Returns true if the uniform was found, false otherwise.
//
// NOTE(review): because the byte count comes from the reflection table and not
// from `n`, `v` must point at no fewer than Size readable bytes - confirm callers.
bool ShaderBase::SetUniformBool(const char* name, int n, const bool* v) 
{
    OVR_UNUSED(n);

    for (unsigned idx = 0; idx < UniformReflSize; ++idx)
    {
        if (strcmp(UniformRefl[idx].Name, name) != 0)
            continue;

        memcpy(UniformData + UniformRefl[idx].Offset, v, UniformRefl[idx].Size);
        return true;
    }

    return false;
}

// Installs a uniform reflection table and allocates the matching CPU-side
// uniform block.
//
//   refl     - array of uniform descriptions (not copied, not owned); NULL clears
//              the shader's uniform state.
//   reflSize - number of entries in `refl`; 0 is treated like a NULL table.
//
// The block size is taken from the last entry (Offset + Size).
// NOTE(review): this assumes the last table entry has the highest offset - confirm
// for all reflection tables passed in.
//
// Any storage from a previous call is released first, and the new block is
// zero-filled so that uniforms which are never set are uploaded as zeros instead
// of uninitialized heap contents. If the allocation fails the shader is left
// with no uniforms.
void ShaderBase::InitUniforms(const Uniform* refl, size_t reflSize)
{
    // Always start from a clean state; this also prevents leaking the block
    // when InitUniforms is called more than once.
    if (UniformData)
    {
        OVR_FREE(UniformData);
        UniformData = NULL;
    }
    UniformRefl = NULL;
    UniformReflSize = 0;
    UniformsSize = 0;

    // No reflection data. An empty table must not reach refl[reflSize - 1].
    if (!refl || reflSize == 0)
        return;

    UniformRefl = refl;
    UniformReflSize = reflSize;

    UniformsSize = UniformRefl[UniformReflSize-1].Offset + UniformRefl[UniformReflSize-1].Size;
    UniformData = (unsigned char*)OVR_ALLOC(UniformsSize);
    if (!UniformData)
    {
        // Allocation failed: make sure SetUniform/UpdateBuffer see "no uniforms".
        UniformRefl = NULL;
        UniformReflSize = 0;
        UniformsSize = 0;
        return;
    }

    memset(UniformData, 0, UniformsSize);
}

// Uploads the CPU-side uniform block into `buf` as a uniform (constant) buffer.
// Does nothing when the shader has no uniform data. The result of the upload
// is not reported to the caller.
void ShaderBase::UpdateBuffer(Buffer* buf)
{
    if (!UniformsSize)
        return;

    buf->Data(Buffer_Uniform, UniformData, UniformsSize);
}


//-------------------------------------------------------------------------------------
// ***** Texture
// 
// Constructs a texture wrapper that records the render parameters, the texture
// size, the sampler state and the sample count. No D3D texture or views are
// created here (Tex/TexSv/TexRtv/TexDsv start out NULL). `fmt` is not used.
Texture::Texture(RenderParams* rp, int fmt, const Sizei texSize,
                 ID3D1xSamplerState* sampler, int samples)
    : pParams(rp), Tex(NULL), TexSv(NULL), TexRtv(NULL), TexDsv(NULL),
    TextureSize(texSize),
    Sampler(sampler),
    Samples(samples)
{
    OVR_UNUSED(fmt);    
}

// Destructor. The body is intentionally empty: no explicit cleanup is done here.
// NOTE(review): the texture and its views are presumably released by their
// smart-pointer members - confirm against the class declaration.
Texture::~Texture()
{
}

void Texture::Set(int slot, ShaderStage stage) const
{    
    ID3D1xShaderResourceView* texSv = TexSv.GetPtr();

    switch(stage)
    {
    case Shader_Fragment:
        pParams->pContext->PSSetShaderResources(slot, 1, &texSv);
        pParams->pContext->PSSetSamplers(slot, 1, &Sampler.GetRawRef());        
        break;

    case Shader_Vertex:
        pParams->pContext->VSSetShaderResources(slot, 1, &texSv);
        pParams->pContext->VSSetSamplers(slot, 1, &Sampler.GetRawRef());
        break;

#if (OVR_D3D_VERSION >= 11)
    case Shader_Compute:
        pParams->pContext->CSSetShaderResources(slot, 1, &texSv);
        pParams->pContext->CSSetSamplers(slot, 1, &Sampler.GetRawRef());
        break;
#endif
    default: OVR_ASSERT ( false ); break;
    }
}


//-------------------------------------------------------------------------------------
// ***** GpuTimer
// 
// D3DQUERY_EXEC issues a query command in a version-independent way:
//  - D3D11 build: the command is called on the device context, with the query
//    as its first argument followed by any extra arguments.
//  - otherwise:   the command is called on the query object itself, with only
//    the extra arguments.
// Call sites pass an empty trailing argument for commands that take no extra
// parameters (e.g. Begin/End).
// NOTE(review): in the D3D11 form this relies on the preprocessor dropping the
// comma before an empty __VA_ARGS__ - confirm for every compiler in use.
#if (OVR_D3D_VERSION == 11)
#define D3DQUERY_EXEC(_context_, _query_, _command_, ...)  _context_->_command_(_query_, __VA_ARGS__)
#else
#define D3DQUERY_EXEC(_context_, _query_, _command_, ...)  _query_->_command_(__VA_ARGS__)
#endif


// Stores the device (used to create queries) and the device context (used to
// issue them). `content` is the device context.
void GpuTimer::Init(ID3D1xDevice* device, ID3D1xDeviceContext* content)
{
    D3dDevice = device;
    Context = content;    
}

void GpuTimer::BeginQuery()
{
    if(GotoNextFrame(LastQueuedFrame) == LastTimedFrame)
    {
        OVR_ASSERT(false); // too many queries queued
        return;
    }

    LastQueuedFrame = GotoNextFrame(LastQueuedFrame);

    GpuQuerySets& newQuerySet = QuerySets[LastQueuedFrame];
    if(newQuerySet.DisjointQuery == NULL)
    {
        // Create the queries
        D3D1x_QUERY_DESC desc;
        desc.Query = D3D1X_(QUERY_TIMESTAMP_DISJOINT);
        desc.MiscFlags = 0;
        VERIFY_HRESULT(D3dDevice->CreateQuery(&desc, &newQuerySet.DisjointQuery));

        desc.Query = D3D1X_(QUERY_TIMESTAMP);
        VERIFY_HRESULT(D3dDevice->CreateQuery(&desc, &newQuerySet.TimeStartQuery));
        VERIFY_HRESULT(D3dDevice->CreateQuery(&desc, &newQuerySet.TimeEndQuery));
    }

    OVR_ASSERT(!newQuerySet.QueryStarted);
    OVR_ASSERT(!newQuerySet.QueryAwaitingTiming);

    
    D3DQUERY_EXEC(Context, QuerySets[LastQueuedFrame].DisjointQuery, Begin, );  // First start a disjoint query
    D3DQUERY_EXEC(Context, QuerySets[LastQueuedFrame].TimeStartQuery, End, );   // Insert start timestamp
    
    newQuerySet.QueryStarted = true;
    newQuerySet.QueryAwaitingTiming = false;
    //newQuerySet.QueryTimed = false;
}

// Finishes the timing started by the matching BeginQuery(): inserts the end
// timestamp, closes the disjoint query and marks the slot as awaiting read-back.
// Does nothing when there is no started query for the current slot.
void GpuTimer::EndQuery()
{
    // Bail out unless the current slot really has a query in flight. The previous
    // guard (LastQueuedFrame > 0 && ...) did not protect slot 0 or a negative
    // index, so QuerySets could be indexed and End issued on queries that were
    // never created/begun.
    // NOTE(review): the negative check assumes LastQueuedFrame is signed and may
    // hold a "nothing queued yet" sentinel - confirm in the class declaration.
    if (LastQueuedFrame < 0 || !QuerySets[LastQueuedFrame].QueryStarted)
        return;

    GpuQuerySets& doneQuerySet = QuerySets[LastQueuedFrame];
    OVR_ASSERT(!doneQuerySet.QueryAwaitingTiming);

    // Insert the end timestamp
    D3DQUERY_EXEC(Context, doneQuerySet.TimeEndQuery, End, );

    // End the disjoint query
    D3DQUERY_EXEC(Context, doneQuerySet.DisjointQuery, End, );

    doneQuerySet.QueryStarted = false;
    doneQuerySet.QueryAwaitingTiming = true;
}

float GpuTimer::GetTiming(bool blockUntilValid)
{
    float time = -1.0f;

    // loop until we hit a query that is not ready yet, or we have read all queued queries
    while(LastTimedFrame != LastQueuedFrame)
    {
        int timeTestFrame = GotoNextFrame(LastTimedFrame);

        GpuQuerySets& querySet = QuerySets[timeTestFrame];

        OVR_ASSERT(!querySet.QueryStarted && querySet.QueryAwaitingTiming);

        UINT64 startTime = 0;
        UINT64 endTime = 0;
        D3D1X_(QUERY_DATA_TIMESTAMP_DISJOINT) disjointData;

        if(blockUntilValid)
        {
            while(D3DQUERY_EXEC(Context, querySet.TimeStartQuery, GetData, &startTime, sizeof(startTime), 0) != S_OK);
            while(D3DQUERY_EXEC(Context, querySet.TimeEndQuery, GetData, &endTime, sizeof(endTime), 0) != S_OK);
            while(D3DQUERY_EXEC(Context, querySet.DisjointQuery, GetData, &disjointData, sizeof(disjointData), 0) != S_OK);
        }
        else
        {
// Early return if we fail to get data for any of these
            if(D3DQUERY_EXEC(Context, querySet.TimeStartQuery, GetData, &startTime, sizeof(startTime), 0) != S_OK)    return time;
            if(D3DQUERY_EXEC(Context, querySet.TimeEndQuery, GetData, &endTime, sizeof(endTime), 0) != S_OK)          return time;
            if(D3DQUERY_EXEC(Context, querySet.DisjointQuery, GetData, &disjointData, sizeof(disjointData), 0) != S_OK)    return time;
        }

        querySet.QueryAwaitingTiming = false;
        LastTimedFrame = timeTestFrame; // successfully retrieved the timing data

        if(disjointData.Disjoint == false)
        {
            UINT64 delta = endTime - startTime;
            float frequency = (float)(disjointData.Frequency);
            time = (delta / frequency);
        }
    }
    
    return time;
}

}}} // OVR::CAPI::D3DX