/* osgCompute - Copyright (C) 2008-2009 SVT Group
 *
 * This library is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 3 of
 * the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * The full license is in LICENSE file included with this distribution.
*/

#ifndef OSGCUDA_ARRAY
#define OSGCUDA_ARRAY 1

#include <memory.h>
#include <cuda_runtime.h>
#include <driver_types.h>
#include <osg/Image>
#include <osgCompute/Buffer>
#include "osgCuda/Context"
#include "osgCuda/Export"

// Forward declarations of the osg vector types that are used below as
// element types of the osgCuda::Array typedefs.
namespace osg
{
    class Vec2b;
    class Vec3b;
    class Vec4b;
    class Vec4ub;
    class Vec2s;
    class Vec3s;
    class Vec4s;
    class Vec2f;
    class Vec3f;
    class Vec4f;
    class Vec2d;
    class Vec3d;
    class Vec4d;
}

namespace osgCuda
{
    // Forward declaration so that the typedefs below can name the template
    // before its definition further down in this file.
    template<class T>
    class Array;

    // Convenience aliases for Array instantiated with scalar element types
    // and with the osg vector types forward-declared above.
    typedef Array<unsigned char>     UByteArray;
    typedef Array<osg::Vec4ub>       Vec4ubArray;
    typedef Array<char>              ByteArray;
    typedef Array<osg::Vec2b>        Vec2bArray;
    typedef Array<osg::Vec3b>        Vec3bArray;
    typedef Array<osg::Vec4b>        Vec4bArray;
    typedef Array<unsigned short>    UShortArray;
    typedef Array<short>             ShortArray;
    typedef Array<osg::Vec2s>        Vec2sArray;
    typedef Array<osg::Vec3s>        Vec3sArray;
    typedef Array<osg::Vec4s>        Vec4sArray;
    typedef Array<unsigned int>      UIntArray;
    typedef Array<int>               IntArray;
    typedef Array<unsigned long>     ULongArray;
    typedef Array<long>              LongArray;
    typedef Array<float>             FloatArray;
    typedef Array<osg::Vec2f>        Vec2fArray;
    typedef Array<osg::Vec3f>        Vec3fArray;
    typedef Array<osg::Vec4f>        Vec4fArray;
    typedef Array<double>            DoubleArray;
    typedef Array<osg::Vec2d>        Vec2dArray;
    typedef Array<osg::Vec3d>        Vec3dArray;
    typedef Array<osg::Vec4d>        Vec4dArray;

    // List of reference-counted images together with its iterator types.
    typedef std::vector< osg::ref_ptr<osg::Image> >                     StreamImageList;
    typedef std::vector< osg::ref_ptr<osg::Image> >::iterator           StreamImageListItr;
    typedef std::vector< osg::ref_ptr<osg::Image> >::const_iterator     StreamImageListCnstItr;

    // Allocation hints; the non-zero values are distinct bits so that they can
    // be tested with a bit mask (see the ALLOC_DYNAMIC test in allocStream()).
    enum ALLOC_HINTS
    {
        ALLOC_LINEAR  = 0,
        ALLOC_ARRAY   = 1 << 0,
        ALLOC_DYNAMIC = 1 << 1
    };

    /**
    Stream state used by osgCuda::Array. In addition to what
    osgCompute::BufferStream provides, it holds a device-side cudaArray and a
    host-side DATATYPE pointer together with the flags Array uses to decide
    when memory has to be released or copied between the two sides.
    */
    template< class DATATYPE >
    class ArrayStream : public osgCompute::BufferStream<DATATYPE>
    {
    public:
        // Device-side array; NULL until Array::allocStream() allocated it.
        cudaArray*                      _devArray;
        // True once allocStream() allocated _devArray; the destructor frees
        // _devArray only when this flag is set.
        bool                            _devArrayAllocated;
        // When true, mapArrayStream() calls syncStream() before handing out
        // the device array (provided a host pointer exists).
        bool                            _syncDevice;
        // Host-side memory; NULL until Array::allocStream() allocated it.
        DATATYPE*                       _hostPtr;
        // True once allocStream() allocated _hostPtr; the destructor frees
        // _hostPtr only when this flag is set.
        bool                            _hostPtrAllocated;
        // When true, mapStream() calls syncStream() before handing out the
        // host pointer (provided a device array exists).
        bool                            _syncHost;
        // Bit mask that allocStream() tests against ALLOC_DYNAMIC to choose
        // the host allocation function.
        unsigned int                    _allocHint;

        ArrayStream();
        virtual ~ArrayStream();

    private:
        // not allowed to call copy-constructor or copy-operator
        ArrayStream( const ArrayStream& ) {}
        ArrayStream& operator=( const ArrayStream& ) { return *this; }
    };

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // PUBLIC FUNCTIONS /////////////////////////////////////////////////////////////////////////////
    /////////////////////////////////////////////////////////////////////////////////////////////////
    //------------------------------------------------------------------------------
    /**
    Creates a stream without any memory attached: both pointers are NULL and
    all flags are cleared.
    */
    template< class DATATYPE >
    ArrayStream<DATATYPE>::ArrayStream()
        :   osgCompute::BufferStream<DATATYPE>(),
        // Listed in declaration order: members are always initialised in
        // that order, so a differing list is misleading and triggers
        // -Wreorder warnings.
        _devArray(NULL),
        _devArrayAllocated(false),
        _syncDevice(false),
        _hostPtr(NULL),
        _hostPtrAllocated(false),
        _syncHost(false),
        _allocHint(0)
    {
    }

    //------------------------------------------------------------------------------
    /**
    Hands the device array and the host memory back to the stream's context,
    but only for memory that is flagged as allocated and is non-NULL.
    */
    template< class DATATYPE >
    ArrayStream<DATATYPE>::~ArrayStream()
    {
        Context* owner = static_cast<Context*>( osgCompute::BufferStream<DATATYPE>::_context.get() );

        const bool releaseDevice = _devArrayAllocated && ( _devArray != NULL );
        if( releaseDevice )
            owner->freeMemory( _devArray );

        const bool releaseHost = _hostPtrAllocated && ( _hostPtr != NULL );
        if( releaseHost )
            owner->freeMemory( _hostPtr );
    }

    /**
    Buffer whose device side is a cudaArray (see mapArray()) and whose host
    side is a plain DATATYPE pointer (see map()). Each stream is represented
    by an ArrayStream. Initial data for a stream can be provided as an
    osg::Image or a std::vector, which is copied on mapping by setupStream().
    */
    template< class DATATYPE >
    class Array : public osgCompute::Buffer<DATATYPE>
    {

    public:
        Array();

        META_Object( osgCuda, Array )

        // Checks the number of dimensions (at most 3) and the attached setup
        // images/vectors, then forwards to osgCompute::Buffer::init().
        virtual bool init();
        // Calls clearLocal() and then osgCompute::Buffer::clear().
        virtual void clear();

        // Returns the device array of a stream, or NULL on failure. Mappings
        // containing MAP_HOST are rejected; UNMAPPED unmaps the stream.
        virtual cudaArray* mapArray( const osgCompute::Context& context, unsigned int mapping, unsigned int streamIdx = 0 ) const;
        // Returns the host pointer of a stream, or NULL on failure. Mappings
        // containing MAP_DEVICE are rejected; UNMAPPED unmaps the stream.
        virtual DATATYPE* map( const osgCompute::Context& context, unsigned int mapping, unsigned int streamIdx = 0 ) const;
        // Unmaps the stream belonging to the given context and index.
        virtual void unmap( const osgCompute::Context& context, unsigned int streamIdx = 0 ) const;

        // NOTE(review): the image/vector accessors are defined outside the
        // visible part of this file; presumably they read and write the
        // per-stream setup data in _streamSetupList - confirm.
        virtual void setImage( osg::Image* image, unsigned int streamIdx = 0 );
        virtual osg::Image* getImage( unsigned int streamIdx = 0 );
        virtual const osg::Image* getImage( unsigned int streamIdx = 0 ) const;

        virtual void setVector( std::vector<DATATYPE>* streamVector, unsigned int numElements = UINT_MAX, unsigned int offset = 0, unsigned int streamIdx = 0 );
        virtual std::vector<DATATYPE>* getVector( unsigned int streamIdx = 0 );
        virtual const std::vector<DATATYPE>* getVector( unsigned int streamIdx = 0 ) const;

        // The descriptor returned by getChannelFormatDesc() is passed to the
        // device array allocation in allocStream(), which refuses to allocate
        // while x, y and z are all INT_MAX.
        inline void setChannelFormatDesc(cudaChannelFormatDesc& channelFormatDesc);
        inline cudaChannelFormatDesc& getChannelFormatDesc();
        inline const cudaChannelFormatDesc& getChannelFormatDesc() const;

    protected:
        virtual ~Array() { clearLocal(); }
        void clearLocal();

        // Stream-level implementations behind map(), mapArray() and unmap().
        DATATYPE* mapStream( ArrayStream<DATATYPE>& stream, unsigned int mapping ) const;
        cudaArray* mapArrayStream( ArrayStream<DATATYPE>& stream, unsigned int mapping ) const;
        void unmapStream( ArrayStream<DATATYPE>& stream ) const;

        // setupStream() copies the setup image/vector into the stream,
        // allocStream() allocates host or device memory, and syncStream()
        // copies between host memory and device array.
        bool setupStream( unsigned int mapping, ArrayStream<DATATYPE>& stream ) const;
        bool allocStream( unsigned int mapping, ArrayStream<DATATYPE>& stream ) const;
        bool syncStream( unsigned int mapping, ArrayStream<DATATYPE>& stream ) const;

        virtual osgCompute::BufferStream<DATATYPE>* newStream( const osgCompute::Context& context, unsigned int streamIdx ) const;

        // Optional initial data of one stream. setupStream() uses the image
        // when it is valid and falls back to the vector otherwise.
        struct StreamData
        {
            std::vector<DATATYPE>        _vector;
            osg::ref_ptr<osg::Image>     _image;
        };

        cudaChannelFormatDesc                       _channelFormatDesc;
        // Indexed by stream index; mutable because the const mapping
        // functions access it through non-const references.
        mutable std::vector< StreamData >           _streamSetupList;

    private:
        // copy constructor and operator should not be called
        Array( const Array&, const osg::CopyOp& ) {}
        Array& operator=( const Array& copy ) { return (*this); }
    };

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // PUBLIC FUNCTIONS /////////////////////////////////////////////////////////////////////////////
    /////////////////////////////////////////////////////////////////////////////////////////////////
    //------------------------------------------------------------------------------
    /**
    Default constructor: constructs the osgCompute::Buffer base and then runs
    clearLocal() on the new object.
    */
    template< class DATATYPE >
    Array<DATATYPE>::Array()
        : osgCompute::Buffer<DATATYPE>()
    {
        this->clearLocal();
    }

    //------------------------------------------------------------------------------
    /**
    Clears the array: first the local state via clearLocal(), afterwards the
    base class state via osgCompute::Buffer::clear().
    */
    template< class DATATYPE >
    void Array<DATATYPE>::clear()
    {
        typedef osgCompute::Buffer<DATATYPE> BaseBuffer;

        this->clearLocal();
        BaseBuffer::clear();
    }

    //------------------------------------------------------------------------------
    /**
    Validates the array setup and initializes the base buffer.

    The array must not have more than 3 dimensions. Every setup image must
    not use mipmaps and its total size in bytes must equal the size of one
    stream (product of all dimensions times sizeof(DATATYPE)). A non-empty
    setup vector with a different element count is resized to the stream's
    element count. On a validation error the array is cleared.

    @return false if a validation fails, otherwise the result of
            osgCompute::Buffer::init().
    */
    template< class DATATYPE >
    bool Array<DATATYPE>::init()
    {
        const unsigned int numDimensions = osgCompute::Buffer<DATATYPE>::getNumDimensions();
        if( numDimensions > 3 )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::init() for array \""
                << osg::Object::getName() <<"\": the maximum dimension allowed is 3."
                << std::endl;

            clear();
            return false;
        }

        unsigned int numElements = 1;
        for( unsigned int d=0; d<numDimensions; ++d )
            numElements *= osgCompute::Buffer<DATATYPE>::getDimension( d );

        unsigned int streamSize = numElements * sizeof( DATATYPE );

        // check stream data
        for( unsigned int i=0; i<_streamSetupList.size(); ++i )
        {
            StreamData& setupData = _streamSetupList[i];

            if( setupData._image.valid() )
            {
                if( setupData._image->getNumMipmapLevels() > 1 )
                {
                    // The separating space after "currently" was missing, which
                    // produced "currentlynot supported." in the log output.
                    osg::notify(osg::FATAL)
                        << "CUDA::Array::init() for array \""
                        << osg::Object::getName() <<"\": image \""
                        << setupData._image->getName() << "\" for stream \""<<i<<"\" uses MipMaps which are currently "
                        << "not supported."
                        << std::endl;

                    clear();
                    return false;
                }

                if( setupData._image->getTotalSizeInBytes() != streamSize )
                {
                    osg::notify(osg::FATAL)
                        << "CUDA::Array::init() for array \""
                        << osg::Object::getName() <<"\": size of image \""
                        << setupData._image->getName() << "\" does not match the array size."
                        << std::endl;

                    clear();
                    return false;
                }
            }
            else if( !setupData._vector.empty() )
            {
                // A setup vector of the wrong length is adapted, not rejected.
                if( setupData._vector.size() != numElements )
                    setupData._vector.resize( numElements );
            }
        }

        return osgCompute::Buffer<DATATYPE>::init();
    }

    //------------------------------------------------------------------------------
    /**
    Maps the device array of the stream that belongs to the given context and
    stream index.

    @param context   context the stream is looked up for; the calling thread
                     must be the thread assigned to this context.
    @param mapping   requested mapping. Must not contain MAP_HOST. If it is
                     UNMAPPED the stream is unmapped instead.
    @param streamIdx index of the stream.
    @return the device array, or NULL if the array is dirty, the thread does
            not match, the mapping is rejected, no stream is found, the
            stream was unmapped or the mapping itself fails.
    */
    template< class DATATYPE >
    cudaArray* Array<DATATYPE>::mapArray( const osgCompute::Context& context, unsigned int mapping, unsigned int streamIdx /*= 0*/ ) const
    {
        // NOTE(review): this is reported as INFO while map() and unmap()
        // report the same condition as FATAL - confirm this is intended.
        if( osgCompute::Param::isDirty() )
        {
            osg::notify(osg::INFO)
                << "CUDA::Array::mapArray() for array \""
                << osg::Object::getName() <<"\": array is dirty."
                << std::endl;

            return NULL;
        }

        if( static_cast<const Context*>(&context)->getAssignedThread() != OpenThreads::Thread::CurrentThread() )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::mapArray() for array \""
                << osg::Object::getName() <<"\": calling thread differs from the context's thread."
                << std::endl;

            return NULL;
        }

        if( mapping & osgCompute::MAP_HOST )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::mapArray() for array \""
                << osg::Object::getName() <<"\": cannot map array to host. Call map() instead."
                << std::endl;

            return NULL;
        }

        ArrayStream<DATATYPE>* stream = static_cast<ArrayStream<DATATYPE>*>( osgCompute::Buffer<DATATYPE>::lookupStream(context,streamIdx) );
        if( NULL == stream )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::mapArray() for array \""
                << osg::Object::getName() <<"\": cannot receive ArrayStream for context \""
                << context.getId() << "\" and stream \""
                << streamIdx << "\"."
                << std::endl;
            return NULL;
        }

        cudaArray* ptr = NULL;
        if( mapping != osgCompute::UNMAPPED )
            ptr = mapArrayStream( *stream, mapping );
        else
            unmapStream( *stream );

        return ptr;
    }

    //------------------------------------------------------------------------------
    /**
    Maps the host memory of the stream that belongs to the given context and
    stream index.

    @param context   context the stream is looked up for; the calling thread
                     must be the thread assigned to this context.
    @param mapping   requested mapping. Must not contain MAP_DEVICE. If it is
                     UNMAPPED the stream is unmapped instead.
    @param streamIdx index of the stream.
    @return the host pointer, or NULL if the array is dirty, the thread does
            not match, the mapping is rejected, no stream is found, the
            stream was unmapped or the mapping itself fails.
    */
    template< class DATATYPE >
    DATATYPE* Array<DATATYPE>::map( const osgCompute::Context& context, unsigned int mapping, unsigned int streamIdx /*= 0*/ ) const
    {
        if( osgCompute::Param::isDirty() )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::map() for array \""
                << osg::Object::getName() <<"\": array is dirty."
                << std::endl;

            return NULL;
        }

        if( static_cast<const Context*>(&context)->getAssignedThread() != OpenThreads::Thread::CurrentThread() )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::map() for array \""
                << osg::Object::getName() <<"\": calling thread differs from the context assigned thread."
                << std::endl;

            return NULL;
        }

        if( mapping & osgCompute::MAP_DEVICE )
        {
            // The diagnostic used to name mapArray() although it is emitted here.
            osg::notify(osg::FATAL)
                << "CUDA::Array::map() for array \""
                << osg::Object::getName() <<"\": cannot map buffer to device. Call mapArray() instead."
                << std::endl;

            return NULL;
        }

        ArrayStream<DATATYPE>* stream = static_cast<ArrayStream<DATATYPE>*>( osgCompute::Buffer<DATATYPE>::lookupStream(context,streamIdx) );
        if( NULL == stream )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::map() for array \""
                << osg::Object::getName() <<"\": cannot receive ArrayStream for context \""
                << context.getId() << "\" and stream \""
                << streamIdx << "\"."
                << std::endl;
            return NULL;
        }

        DATATYPE* ptr = NULL;
        if( mapping != osgCompute::UNMAPPED )
            ptr = mapStream( *stream, mapping );
        else
            unmapStream( *stream );

        return ptr;
    }

    //------------------------------------------------------------------------------
    /**
    Unmaps the stream that belongs to the given context and stream index.
    Nothing happens (apart from a log message) if the array is dirty, the
    calling thread is not the thread assigned to the context or no stream
    is found.

    @param context   context the stream is looked up for.
    @param streamIdx index of the stream.
    */
    template< class DATATYPE >
    void Array<DATATYPE>::unmap( const osgCompute::Context& context, unsigned int streamIdx /*= 0*/ ) const
    {
        // All diagnostics below used to name map() instead of unmap().
        if( osgCompute::Param::isDirty() )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::unmap() for array \""
                << osg::Object::getName() <<"\": array is dirty."
                << std::endl;

            return;
        }

        if( static_cast<const Context*>(&context)->getAssignedThread() != OpenThreads::Thread::CurrentThread() )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::unmap() for array \""
                << osg::Object::getName() <<"\": calling thread differs from the context's thread."
                << std::endl;

            return;
        }

        ArrayStream<DATATYPE>* stream = static_cast<ArrayStream<DATATYPE>*>( osgCompute::Buffer<DATATYPE>::lookupStream(context,streamIdx) );
        if( NULL == stream )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::unmap() for array \""
                << osg::Object::getName() <<"\": could not receive ArrayStream for context \""
                << context.getId() << "\" and stream \""
                << streamIdx << "\"."
                << std::endl;

            return;
        }

        unmapStream( *stream );
    }

    //------------------------------------------------------------------------------
    /**
    Maps the device array of a single stream.

    If the stream already has the requested mapping its device array is
    returned right away. Otherwise a different existing mapping is unmapped,
    the device array is allocated on first use, pending setup data is copied
    (setupStream()), a pending host-to-device synchronization is performed
    (syncStream()) and finally the subload callback, if any, is invoked with
    load() after a fresh allocation and with subload() otherwise.

    The mapping is only kept on the stream when all steps succeed. Otherwise
    the stream is left UNMAPPED, so that a later call with the same mapping
    retries the failed step instead of taking the early return above.

    @param stream  stream to map.
    @param mapping requested mapping; must contain MAP_DEVICE.
    @return the device array or NULL on failure.
    */
    template< class DATATYPE >
    cudaArray* Array<DATATYPE>::mapArrayStream( ArrayStream<DATATYPE>& stream, unsigned int mapping ) const
    {
        ///////////////////
        // PROOF MAPPING //
        ///////////////////
        if( stream._mapping == mapping )
        {
            return stream._devArray;
        }
        else if( stream._mapping != osgCompute::UNMAPPED )
        {
            unmapStream( stream );
        }

        if( !(mapping & osgCompute::MAP_DEVICE) )
        {
            // Do not record a mapping this function cannot serve.
            stream._mapping = osgCompute::UNMAPPED;

            osg::notify(osg::WARN)
                << "CUDA::Array::mapArrayStream() for array \""<<osg::Object::getName()<<"\": wrong mapping was specified. Use one of the following: "
                << "DEVICE_SOURCE, DEVICE_TARGET, DEVICE."
                << std::endl;

            return NULL;
        }

        stream._mapping = mapping;

        //////////////
        // MAP DATA //
        //////////////
        bool firstLoad = false;
        if( NULL == stream._devArray )
        {
            ////////////////////////////
            // ALLOCATE DEVICE-MEMORY //
            ////////////////////////////
            // The array is checked as well, so a reported success without
            // an actual array is treated as a failure.
            if( !allocStream( mapping, stream ) || NULL == stream._devArray )
            {
                stream._mapping = osgCompute::UNMAPPED;
                return NULL;
            }

            firstLoad = true;
        }

        //////////////////
        // SETUP STREAM //
        //////////////////
        if( _streamSetupList.size() > stream._streamIdx &&
            (!_streamSetupList[stream._streamIdx]._vector.empty() ||
            _streamSetupList[stream._streamIdx]._image.valid()) &&
            stream._needsSetup )
        {
            if( !setupStream( mapping, stream ) )
            {
                stream._mapping = osgCompute::UNMAPPED;
                return NULL;
            }
        }

        /////////////////
        // SYNC STREAM //
        /////////////////
        if( stream._syncDevice && NULL != stream._hostPtr )
        {
            if( !syncStream( mapping, stream ) )
            {
                stream._mapping = osgCompute::UNMAPPED;
                return NULL;
            }
        }

        cudaArray* ptr = stream._devArray;

        //////////////////
        // LOAD/SUBLOAD //
        //////////////////
        if( osgCompute::Param::getSubloadCallback() && NULL != ptr )
        {
            const osgCompute::BufferSubloadCallback* callback = osgCompute::Param::getSubloadCallback()->asBufferSubloadCallback();
            if( callback )
            {
                // load or subload data before returning the device array
                if( firstLoad )
                    callback->load( reinterpret_cast<void*>(ptr), stream._streamIdx, mapping, *this, *stream._context );
                else
                    callback->subload( reinterpret_cast<void*>(ptr), stream._streamIdx, mapping, *this, *stream._context );
            }
        }

        return ptr;
    }

    //------------------------------------------------------------------------------
    /**
    Maps the host memory of a single stream.

    If the stream already has the requested mapping its host pointer is
    returned right away. Otherwise a different existing mapping is unmapped,
    the host memory is allocated on first use, pending setup data is copied
    (setupStream()), a pending synchronization with the device array is
    performed (syncStream()) and finally the subload callback, if any, is
    invoked with load() after a fresh allocation and with subload()
    otherwise.

    The mapping is only kept on the stream when all steps succeed. Otherwise
    the stream is left UNMAPPED, so that a later call with the same mapping
    retries the failed step instead of taking the early return above.

    @param stream  stream to map.
    @param mapping requested mapping; must contain MAP_HOST.
    @return the host pointer or NULL on failure.
    */
    template< class DATATYPE >
    DATATYPE* Array<DATATYPE>::mapStream( ArrayStream<DATATYPE>& stream, unsigned int mapping ) const
    {
        ///////////////////
        // PROOF MAPPING //
        ///////////////////
        if( stream._mapping == mapping )
        {
            return stream._hostPtr;
        }
        else if( stream._mapping != osgCompute::UNMAPPED )
        {
            unmapStream( stream );
        }

        if( !(mapping & osgCompute::MAP_HOST) )
        {
            // Do not record a mapping this function cannot serve.
            stream._mapping = osgCompute::UNMAPPED;

            osg::notify(osg::WARN)
                << "CUDA::Array::mapStream() for array \""<<osg::Object::getName()<<"\": wrong mapping was specified. Use one of the following: "
                << "HOST_SOURCE, HOST_TARGET, HOST."
                << std::endl;

            return NULL;
        }

        stream._mapping = mapping;

        //////////////
        // MAP DATA //
        //////////////
        bool firstLoad = false;
        if( NULL == stream._hostPtr )
        {
            //////////////////////////
            // ALLOCATE HOST-MEMORY //
            //////////////////////////
            if( !allocStream( mapping, stream ) || NULL == stream._hostPtr )
            {
                stream._mapping = osgCompute::UNMAPPED;
                return NULL;
            }

            firstLoad = true;
        }

        //////////////////
        // SETUP STREAM //
        //////////////////
        if( _streamSetupList.size() > stream._streamIdx &&
            (!_streamSetupList[stream._streamIdx]._vector.empty() ||
            _streamSetupList[stream._streamIdx]._image.valid()) &&
            stream._needsSetup )
        {
            if( !setupStream( mapping, stream ) )
            {
                stream._mapping = osgCompute::UNMAPPED;
                return NULL;
            }
        }

        /////////////////
        // SYNC STREAM //
        /////////////////
        if( stream._syncHost && NULL != stream._devArray )
        {
            if( !syncStream( mapping, stream ) )
            {
                stream._mapping = osgCompute::UNMAPPED;
                return NULL;
            }
        }

        DATATYPE* ptr = stream._hostPtr;

        //////////////////
        // LOAD/SUBLOAD //
        //////////////////
        if( osgCompute::Param::getSubloadCallback() && NULL != ptr )
        {
            const osgCompute::BufferSubloadCallback* callback = osgCompute::Param::getSubloadCallback()->asBufferSubloadCallback();
            if( callback )
            {
                // load or subload data before returning the host pointer
                if( firstLoad )
                    callback->load( ptr, stream._streamIdx, mapping, *this, *stream._context );
                else
                    callback->subload( ptr, stream._streamIdx, mapping, *this, *stream._context );
            }
        }

        return ptr;
    }

    //------------------------------------------------------------------------------
    /**
    Copies the setup data of a stream (its image if valid, otherwise its
    vector) into the stream's memory.

    For a mapping containing MAP_DEVICE the data is copied into the device
    array and the host side is flagged for synchronization. Otherwise, for a
    mapping containing MAP_HOST, it is copied into the host memory and the
    device side is flagged. On success the stream no longer needs setup.

    @param mapping mapping that selects the destination.
    @param stream  stream to set up; the destination memory must already
                   be allocated.
    @return true on success; false if the mapping selects no destination, no
            setup data is available or the copy fails.
    */
    template< class DATATYPE >
    bool Array<DATATYPE>::setupStream( unsigned int mapping, ArrayStream<DATATYPE>& stream ) const
    {
        // MAP_DEVICE takes precedence over MAP_HOST, as before.
        const bool toDevice = (mapping & osgCompute::MAP_DEVICE) != 0;
        const bool toHost = !toDevice && (mapping & osgCompute::MAP_HOST) != 0;
        if( !toDevice && !toHost )
            return false;

        // Guard the index; the visible callers check it, but nothing else does.
        if( !(_streamSetupList.size() > stream._streamIdx) )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::setupStream() for array \""<< osg::Object::getName()
                << "\": Cannot receive valid data pointer."
                << std::endl;

            return false;
        }

        StreamData& setupData = _streamSetupList[stream._streamIdx];
        const bool fromImage = setupData._image.valid();

        void* data = NULL;
        if( fromImage )
            data = setupData._image->data();
        else if( !setupData._vector.empty() )
            data = &setupData._vector.front();

        if( data == NULL )
        {
            osg::notify(osg::FATAL)
                << "CUDA::Array::setupStream() for array \""<< osg::Object::getName()
                << "\": Cannot receive valid data pointer."
                << std::endl;

            return false;
        }

        cudaError res;
        const char* copyName = NULL;
        if( toDevice )
        {
            // NOTE(review): syncStream() uses dimension specific copies for
            // 2D/3D arrays while this path always uses cudaMemcpyToArray() -
            // confirm this is valid for multi-dimensional arrays.
            copyName = "cudaMemcpyToArray()";
            res = cudaMemcpyToArray(stream._devArray,0,0,data, osgCompute::Buffer<DATATYPE>::getStreamSize(), cudaMemcpyHostToDevice);
        }
        else
        {
            // Copy from the resolved pointer. The image used to be read
            // unconditionally here, which dereferenced an invalid image
            // whenever the setup data was a vector.
            copyName = "cudaMemcpy()";
            res = cudaMemcpy( stream._hostPtr, data, osgCompute::Buffer<DATATYPE>::getStreamSize(), cudaMemcpyHostToHost );
        }

        if( cudaSuccess != res )
        {
            // Only touch the image for its name if there is one.
            osg::notify(osg::FATAL)
                << "CUDA::Array::setupStream() for array \""<< osg::Object::getName()
                << "\": " << copyName << " failed for "
                << (fromImage ? "image \"" : "stream vector")
                << (fromImage ? setupData._image->getName().c_str() : "")
                << (fromImage ? "\"" : "")
                << " within context \""
                << stream._context->getId() << "\" and stream \""
                << stream._streamIdx << "\". Returned code is " << std::hex<<res<<"."
                << std::endl;

            return false;
        }

        // the side that did not receive the data must be synchronized
        if( toDevice )
            stream._syncHost = true;
        else
            stream._syncDevice = true;

        stream._needsSetup = false;
        return true;
    }

    //------------------------------------------------------------------------------
    /**
    Allocates the memory of a stream for the requested side.

    For a mapping containing MAP_HOST host memory of one stream size is
    allocated through the stream's context: with mallocDeviceHostMemory()
    if the stream's allocation hint contains ALLOC_DYNAMIC, otherwise with
    mallocHostMemory(). Otherwise, for a mapping containing MAP_DEVICE, a
    1D, 2D or 3D device array is allocated according to the number of
    dimensions, using the channel format descriptor of this array. Memory
    that already exists is kept.

    @param mapping mapping that selects the side to allocate.
    @param stream  stream that receives the memory.
    @return true if the memory exists afterwards; false if the mapping
            selects no side, the channel format descriptor is unset (x, y
            and z all INT_MAX), the number of dimensions is not 1, 2 or 3,
            or the allocation fails.
    */
    template< class DATATYPE >
    bool Array<DATATYPE>::allocStream( unsigned int mapping, ArrayStream<DATATYPE>& stream ) const
    {
        if( mapping & osgCompute::MAP_HOST )
        {
            if( stream._hostPtr != NULL )
                return true;

            const char* allocName = NULL;
            if( (stream._allocHint & ALLOC_DYNAMIC) == ALLOC_DYNAMIC )
            {
                allocName = "mallocDeviceHostMemory()";
                stream._hostPtr = reinterpret_cast<DATATYPE*>(
                    static_cast<Context*>(stream._context.get())->mallocDeviceHostMemory( osgCompute::Buffer<DATATYPE>::getStreamSize() ) );
            }
            else
            {
                allocName = "mallocHostMemory()";
                stream._hostPtr = reinterpret_cast<DATATYPE*>(
                    static_cast<Context*>(stream._context.get())->mallocHostMemory( osgCompute::Buffer<DATATYPE>::getStreamSize() ) );
            }

            if( NULL == stream._hostPtr )
            {
                osg::notify(osg::FATAL)
                    << "CUDA::Array::allocStream() for array \""
                    << osg::Object::getName()<<"\": something goes wrong within "<<allocName<<" within Context \""<<stream._context->getId()
                    << "\" and Stream \""
                    << stream._streamIdx << "\"."
                    << std::endl;

                return false;
            }

            stream._hostPtrAllocated = true;
            return true;
        }
        else if( mapping & osgCompute::MAP_DEVICE )
        {
            if( stream._devArray != NULL )
                return true;

            const cudaChannelFormatDesc& desc = getChannelFormatDesc();
            if( desc.x == INT_MAX && desc.y == INT_MAX && desc.z == INT_MAX )
            {
                osg::notify(osg::FATAL)
                    << "CUDA::Array::allocStream() for array \""<<osg::Object::getName()<<"\": no valid ChannelFormatDesc found."
                    << std::endl;

                return false;
            }

            const unsigned int numDimensions = osgCompute::Buffer<DATATYPE>::getNumDimensions();
            const char* allocName = NULL;
            if( numDimensions == 3 )
            {
                allocName = "mallocDevice3DArray()";
                // extents of at most 1 are passed as 0
                stream._devArray = static_cast<Context*>(stream._context.get())->mallocDevice3DArray(
                                        osgCompute::Buffer<DATATYPE>::getDimension(0),
                                        (osgCompute::Buffer<DATATYPE>::getDimension(1) <= 1)? 0 : osgCompute::Buffer<DATATYPE>::getDimension(1),
                                        (osgCompute::Buffer<DATATYPE>::getDimension(2) <= 1)? 0 : osgCompute::Buffer<DATATYPE>::getDimension(2),
                                        desc );
            }
            else if( numDimensions == 2 )
            {
                allocName = "mallocDevice2DArray()";
                stream._devArray = static_cast<Context*>(stream._context.get())->mallocDevice2DArray(
                                        osgCompute::Buffer<DATATYPE>::getDimension(0),
                                        (osgCompute::Buffer<DATATYPE>::getDimension(1) <= 1)? 0 : osgCompute::Buffer<DATATYPE>::getDimension(1),
                                        desc );
            }
            else if( numDimensions == 1 )
            {
                // This branch used to test for 2 dimensions again and was
                // therefore unreachable: 1D arrays were never allocated, yet
                // the function reported success.
                allocName = "mallocDeviceArray()";
                stream._devArray = static_cast<Context*>(stream._context.get())->mallocDeviceArray(
                                        osgCompute::Buffer<DATATYPE>::getDimension(0),
                                        desc );
            }
            else
            {
                // Never report success without having allocated anything.
                osg::notify(osg::FATAL)
                    << "CUDA::Array::allocStream() for array \""<<osg::Object::getName()<<"\": unsupported number of dimensions \""
                    << numDimensions << "\" within context \""
                    << stream._context->getId() << "\" and stream \""
                    << stream._streamIdx << "\"."
                    << std::endl;

                return false;
            }

            if( NULL == stream._devArray )
            {
                osg::notify(osg::FATAL)
                    << "CUDA::Array::allocStream() for array \""<<osg::Object::getName()<<"\": something goes wrong within "<<allocName<<" within context \""
                    << stream._context->getId() << "\" and stream \""
                    << stream._streamIdx << "\"."
                    << std::endl;

                return false;
            }

            stream._devArrayAllocated = true;
            return true;
        }

        return false;
    }

    //------------------------------------------------------------------------------
    // Synchronizes the host memory and the CUDA array of a single stream.
    //
    // mapping: selects the copy direction. If it contains osgCompute::MAP_DEVICE
    //          the host memory (stream._hostPtr) is copied into the CUDA array
    //          (stream._devArray) and stream._syncDevice is cleared. Otherwise,
    //          if it contains osgCompute::MAP_HOST, the CUDA array is copied
    //          back into the host memory and stream._syncHost is cleared.
    // stream:  the stream whose host pointer and device array are used.
    //
    // Returns true if the copy succeeded. Returns false if the CUDA copy call
    // failed (a FATAL notification is written) or if mapping contains neither
    // MAP_DEVICE nor MAP_HOST.
    //
    // The copy function is chosen by the number of buffer dimensions:
    // 1 -> cudaMemcpyToArray/FromArray, 2 -> cudaMemcpy2DToArray/FromArray,
    // anything else -> cudaMemcpy3D.
    template< class DATATYPE >
    bool Array<DATATYPE>::syncStream( unsigned int mapping, ArrayStream<DATATYPE>& stream ) const
    {
        cudaError res;

        // Size of one row of the host memory in bytes. The host memory is
        // treated as tightly packed, so this is also the host pitch. The 2D
        // array copy functions expect the transfer width in bytes as well.
        const size_t rowBytes = osgCompute::Buffer<DATATYPE>::getDimension(0) * sizeof(DATATYPE);

        if( mapping & osgCompute::MAP_DEVICE )
        {
            // host -> device
            if( osgCompute::Buffer<DATATYPE>::getNumDimensions() == 1 )
            {
                res = cudaMemcpyToArray( stream._devArray, 0, 0, stream._hostPtr, osgCompute::Buffer<DATATYPE>::getStreamSize(), cudaMemcpyHostToDevice );
                if( cudaSuccess != res )
                {
                    // std::dec restores the number base of the shared notify stream.
                    osg::notify(osg::FATAL)
                        << "CUDA::Array::syncStream() for array \""<<osg::Object::getName()
                        << "\": something goes wrong on cudaMemcpyToArray() to device within context \""
                        << stream._context->getId() << "\" and stream \""
                        << stream._streamIdx << "\". Returned code is "
                        << std::hex<<res<<std::dec<<"."
                        << std::endl;
                    return false;
                }
            }
            else if( osgCompute::Buffer<DATATYPE>::getNumDimensions() == 2 )
            {
                // width is given in bytes, height in rows
                res = cudaMemcpy2DToArray( stream._devArray,
                    0, 0,
                    stream._hostPtr,
                    rowBytes,
                    rowBytes,
                    osgCompute::Buffer<DATATYPE>::getDimension(1),
                    cudaMemcpyHostToDevice );
                if( cudaSuccess != res )
                {
                    osg::notify(osg::FATAL)
                        << "CUDA::Array::syncStream() for array \""<<osg::Object::getName()
                        << "\": something goes wrong on cudaMemcpy2DToArray() to device within context \""
                        << stream._context->getId() << "\" and stream \""
                        << stream._streamIdx << "\". Returned code is "
                        << std::hex<<res<<std::dec<<"."
                        << std::endl;

                    return false;
                }
            }
            else
            {
                // Describe the tightly packed host memory as a pitched pointer.
                cudaPitchedPtr pitchPtr = {0};
                pitchPtr.pitch = rowBytes;
                pitchPtr.ptr = (void*)stream._hostPtr;
                pitchPtr.xsize = osgCompute::Buffer<DATATYPE>::getDimension(0);
                pitchPtr.ysize = osgCompute::Buffer<DATATYPE>::getDimension(1);

                // A CUDA array takes part in the copy, so the extent is
                // expressed in elements of that array.
                cudaExtent extent = {0};
                extent.width = osgCompute::Buffer<DATATYPE>::getDimension(0);
                extent.height = osgCompute::Buffer<DATATYPE>::getDimension(1);
                extent.depth = osgCompute::Buffer<DATATYPE>::getDimension(2);

                cudaMemcpy3DParms copyParams = {0};
                copyParams.srcPtr = pitchPtr;
                copyParams.dstArray = stream._devArray;
                copyParams.extent = extent;
                copyParams.kind = cudaMemcpyHostToDevice;

                res = cudaMemcpy3D( &copyParams );
                if( cudaSuccess != res )
                {
                    osg::notify(osg::FATAL)
                        << "CUDA::Array::syncStream() for array \""<<osg::Object::getName()
                        << "\": something goes wrong on cudaMemcpy3D() to device within context \""
                        << stream._context->getId() << "\" and stream \""
                        << stream._streamIdx << "\". Returned code is "
                        << std::hex<<res<<std::dec<<"."
                        << std::endl;

                    return false;
                }
            }

            stream._syncDevice = false;
            return true;
        }
        else if( mapping & osgCompute::MAP_HOST )
        {
            // device -> host
            if( osgCompute::Buffer<DATATYPE>::getNumDimensions() == 1 )
            {
                res = cudaMemcpyFromArray( stream._hostPtr, stream._devArray, 0, 0, osgCompute::Buffer<DATATYPE>::getStreamSize(), cudaMemcpyDeviceToHost );
                if( cudaSuccess != res )
                {
                    osg::notify(osg::FATAL)
                        << "CUDA::Array::syncStream() for array \""
                        << osg::Object::getName()<<"\": something goes wrong within cudaMemcpyFromArray() to host within context \""
                        << stream._context->getId() << "\" and stream \""
                        << stream._streamIdx << "\". Returned code is "
                        << std::hex<<res<<std::dec<<"."
                        << std::endl;

                    return false;
                }
            }
            else if( osgCompute::Buffer<DATATYPE>::getNumDimensions() == 2 )
            {
                // width is given in bytes, height in rows
                res = cudaMemcpy2DFromArray(
                    stream._hostPtr,
                    rowBytes,
                    stream._devArray,
                    0, 0,
                    rowBytes,
                    osgCompute::Buffer<DATATYPE>::getDimension(1),
                    cudaMemcpyDeviceToHost );
                if( cudaSuccess != res )
                {
                    osg::notify(osg::FATAL)
                        << "CUDA::Array::syncStream() for array \""<<osg::Object::getName()
                        << "\": something goes wrong on cudaMemcpy2DFromArray() to host within context \""
                        << stream._context->getId() << "\" and stream \""
                        << stream._streamIdx << "\". Returned code is "
                        << std::hex<<res<<std::dec<<"."
                        << std::endl;

                    return false;
                }
            }
            else
            {
                // Describe the tightly packed host memory as a pitched pointer.
                cudaPitchedPtr pitchPtr = {0};
                pitchPtr.pitch = rowBytes;
                pitchPtr.ptr = (void*)stream._hostPtr;
                pitchPtr.xsize = osgCompute::Buffer<DATATYPE>::getDimension(0);
                pitchPtr.ysize = osgCompute::Buffer<DATATYPE>::getDimension(1);

                // A CUDA array takes part in the copy, so the extent is
                // expressed in elements of that array.
                cudaExtent extent = {0};
                extent.width = osgCompute::Buffer<DATATYPE>::getDimension(0);
                extent.height = osgCompute::Buffer<DATATYPE>::getDimension(1);
                extent.depth = osgCompute::Buffer<DATATYPE>::getDimension(2);

                cudaMemcpy3DParms copyParams = {0};
                copyParams.srcArray = stream._devArray;
                copyParams.dstPtr = pitchPtr;
                copyParams.extent = extent;
                copyParams.kind = cudaMemcpyDeviceToHost;

                res = cudaMemcpy3D( &copyParams );
                if( cudaSuccess != res )
                {
                    osg::notify(osg::FATAL)
                        << "CUDA::Array::syncStream() for array \""<<osg::Object::getName()
                        << "\": something goes wrong on cudaMemcpy3D() to host within context \""
                        << stream._context->getId() << "\" and stream \""
                        << stream._streamIdx << "\". Returned code is "
                        << std::hex<<res<<std::dec<<"."
                        << std::endl;

                    return false;
                }
            }

            stream._syncHost = false;
            return true;
        }

        // Neither MAP_DEVICE nor MAP_HOST was requested.
        return false;
    }

    //------------------------------------------------------------------------------
    // Ends the current mapping of a stream.
    //
    // If the mapping being released contains MAP_HOST_TARGET bits, the stream's
    // _syncDevice flag is raised. Otherwise, if it contains MAP_DEVICE_TARGET
    // bits, the _syncHost flag is raised. At most one of the two flags is
    // touched per call. Afterwards the stream is marked as UNMAPPED.
    template< class DATATYPE >
    void Array<DATATYPE>::unmapStream( ArrayStream<DATATYPE>& stream ) const
    {
        const bool hostWasTarget   = ( stream._mapping & osgCompute::MAP_HOST_TARGET ) != 0;
        const bool deviceWasTarget = ( stream._mapping & osgCompute::MAP_DEVICE_TARGET ) != 0;

        // The host target takes precedence over the device target.
        if( hostWasTarget )
            stream._syncDevice = true;
        else if( deviceWasTarget )
            stream._syncHost = true;

        stream._mapping = osgCompute::UNMAPPED;
    }

    //------------------------------------------------------------------------------
    // Assigns an image as the setup source of the stream with index streamIdx.
    //
    // image:     the image to use, or NULL to remove the setup image.
    // streamIdx: index of the stream; the setup list grows if it is too short.
    //
    // While the parameter is not dirty a non-NULL image is validated first: it
    // must not use mipmaps and its total size in bytes must equal the stream
    // size of the buffer. If validation fails a FATAL notification is written
    // and the stored setup data is left unchanged.
    // On success any setup vector stored for the stream is cleared and the
    // stream is flagged as needing setup exactly when image is not NULL.
    template< class DATATYPE >
    void Array<DATATYPE>::setImage( osg::Image* image, unsigned int streamIdx )
    {
        // Make sure a setup slot exists for the requested stream.
        if( _streamSetupList.size() <= streamIdx )
            _streamSetupList.resize( streamIdx + 1 );

        // NOTE(review): validation is skipped while the parameter is dirty,
        // presumably because the buffer size is not final yet - confirm.
        if( !osgCompute::Param::isDirty() && NULL != image)
        {
            if( image->getNumMipmapLevels() > 1 )
            {
                osg::notify(osg::FATAL)
                    << "CUDA::Array::setImage() for array \""
                    << osg::Object::getName() <<"\": image \""
                    << image->getName() << "\" for stream \""<<streamIdx<<"\" uses MipMaps which are currently "
                    << "not supported."
                    << std::endl;

                return;
            }

            if( image->getTotalSizeInBytes() != osgCompute::Buffer<DATATYPE>::getStreamSize() )
            {
                osg::notify(osg::FATAL)
                    << "CUDA::Array::setImage() for array \""
                    << osg::Object::getName() <<"\": size of image \""
                    << image->getName() << "\" does not match the buffer size."
                    << std::endl;

                return;
            }
        }

        // An image and a vector are never stored for the same stream.
        _streamSetupList[streamIdx]._image = image;
        _streamSetupList[streamIdx]._vector.clear();

        osgCompute::Buffer<DATATYPE>::setNeedsSetup( (image != NULL), streamIdx );
    }

    //------------------------------------------------------------------------------
    // Returns the setup image stored for the stream with index streamIdx, or
    // NULL if no setup slot exists for that index. The result may also be NULL
    // when the slot exists but holds no image.
    template< class DATATYPE >
    osg::Image* Array<DATATYPE>::getImage( unsigned int streamIdx )
    {
        osg::Image* storedImage = NULL;

        if( streamIdx < _streamSetupList.size() )
            storedImage = _streamSetupList[streamIdx]._image.get();

        return storedImage;
    }

    //------------------------------------------------------------------------------
    // Returns the setup image stored for the stream with index streamIdx, or
    // NULL if no setup slot exists for that index (const overload).
    template< class DATATYPE >
    const osg::Image* Array<DATATYPE>::getImage( unsigned int streamIdx ) const
    {
        // Compare against size() directly: "size() - 1" wraps around to a
        // huge value for an empty list and would let every index through.
        if( _streamSetupList.size() <= streamIdx )
            return NULL;

        return _streamSetupList[streamIdx]._image.get();
    }

    //------------------------------------------------------------------------------
    // Copies elements of a vector into the setup vector of a stream.
    //
    // data:        source vector, or NULL to clear the stored setup vector.
    // numElements: number of elements to copy; UINT_MAX means all elements of
    //              data. The count is limited to the size of data.
    // offset:      element index in the stored setup vector at which the
    //              copied elements are placed.
    // streamIdx:   index of the stream; the setup list grows if it is too short.
    //
    // While the parameter is not dirty the stored vector is sized to the number
    // of buffer elements and the copy is clipped to that range. While it is
    // dirty the stored vector grows as needed to hold offset + count elements.
    // If nothing is left to copy the function returns without touching the
    // stored image or the setup flag. Otherwise the setup image of the stream
    // is removed and the stream is flagged as needing setup exactly when data
    // is not NULL.
    template< class DATATYPE >
    void Array<DATATYPE>::setVector( std::vector<DATATYPE>* data, unsigned int numElements, unsigned int offset, unsigned int streamIdx )
    {
        if( _streamSetupList.size() <= streamIdx )
            _streamSetupList.resize( streamIdx + 1, StreamData() );

        if( data != NULL )
        {
            unsigned int numElementsToCopy = (numElements == UINT_MAX)? data->size() : numElements;

            // Never read beyond the end of the source vector.
            if( numElementsToCopy > data->size() )
                numElementsToCopy = data->size();

            if( numElementsToCopy == 0 )
                return;

            if( !osgCompute::Param::isDirty() )
            {
                // if streamsize is known then check for overwrites
                const unsigned int bufferElements = osgCompute::Buffer<DATATYPE>::getNumElements();

                // An offset at or beyond the end leaves nothing to copy; it
                // would also make the subtraction below wrap around.
                if( offset >= bufferElements )
                    return;

                if( _streamSetupList[streamIdx]._vector.size() < bufferElements )
                    _streamSetupList[streamIdx]._vector.resize( bufferElements );

                if( numElementsToCopy > (bufferElements - offset) )
                    numElementsToCopy = (bufferElements - offset);
            }
            else
            {
                // Add in size_t so that offset + count cannot wrap around in
                // unsigned int arithmetic.
                const size_t requiredSize = static_cast<size_t>(offset) + numElementsToCopy;
                if( _streamSetupList[streamIdx]._vector.size() < requiredSize )
                    _streamSetupList[streamIdx]._vector.resize( requiredSize );
            }

            memcpy( &_streamSetupList[streamIdx]._vector.at(offset), &data->front(), numElementsToCopy * sizeof(DATATYPE) );
        }
        else
        {
            _streamSetupList[streamIdx]._vector.clear();
        }

        // An image and a vector are never stored for the same stream.
        _streamSetupList[streamIdx]._image = NULL;
        osgCompute::Buffer<DATATYPE>::setNeedsSetup( (data != NULL), streamIdx);
    }

    //------------------------------------------------------------------------------
    // Returns a pointer to the setup vector stored for the stream with index
    // streamIdx, or NULL if no setup slot exists for that index.
    template< class DATATYPE >
    std::vector<DATATYPE>* Array<DATATYPE>::getVector( unsigned int streamIdx )
    {
        if( streamIdx < _streamSetupList.size() )
            return &_streamSetupList[streamIdx]._vector;

        return NULL;
    }

    //------------------------------------------------------------------------------
    // Returns a read-only pointer to the setup vector stored for the stream
    // with index streamIdx, or NULL if no setup slot exists for that index.
    template< class DATATYPE >
    const std::vector<DATATYPE>* Array<DATATYPE>::getVector( unsigned int streamIdx ) const
    {
        const std::vector<DATATYPE>* storedVector = NULL;

        if( streamIdx < _streamSetupList.size() )
            storedVector = &_streamSetupList[streamIdx]._vector;

        return storedVector;
    }

    //------------------------------------------------------------------------------
    // Returns a modifiable reference to the channel format descriptor stored
    // in this array. Changes made through the reference act directly on the
    // stored descriptor.
    template< class DATATYPE >
    cudaChannelFormatDesc& osgCuda::Array<DATATYPE>::getChannelFormatDesc()
    {
        return this->_channelFormatDesc;
    }

    //------------------------------------------------------------------------------
    // Returns a read-only reference to the channel format descriptor stored in
    // this array.
    template< class DATATYPE >
    const cudaChannelFormatDesc& osgCuda::Array<DATATYPE>::getChannelFormatDesc() const
    {
        return this->_channelFormatDesc;
    }

    //------------------------------------------------------------------------------
    // Stores a copy of the given channel format descriptor.
    //
    // The descriptor is only accepted while the parameter is dirty; otherwise
    // the call has no effect.
    template< class DATATYPE >
    void osgCuda::Array<DATATYPE>::setChannelFormatDesc(cudaChannelFormatDesc& channelFormatDesc)
    {
        if( osgCompute::Param::isDirty() )
        {
            _channelFormatDesc = channelFormatDesc;
        }
    }

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // PROTECTED FUNCTIONS //////////////////////////////////////////////////////////////////////////
    /////////////////////////////////////////////////////////////////////////////////////////////////
    //------------------------------------------------------------------------------
    // Resets the state owned by this class: the channel format descriptor is
    // overwritten with a fill pattern and all per-stream setup data is dropped.
    template< class DATATYPE >
    void Array<DATATYPE>::clearLocal()
    {
        // memset() converts its fill value to unsigned char, so every byte of
        // the descriptor becomes 0xFF (the low byte of INT_MAX).
        // NOTE(review): the individual fields therefore do not compare equal
        // to INT_MAX - confirm how the "unset" state is detected elsewhere.
        memset( &_channelFormatDesc, INT_MAX, sizeof(_channelFormatDesc) );

        _streamSetupList.clear();
    }

    //------------------------------------------------------------------------------
    // Creates a new ArrayStream object on the heap and returns it through the
    // osgCompute::BufferStream base pointer. Neither the context nor the
    // stream index is used here.
    template< class DATATYPE >
    osgCompute::BufferStream<DATATYPE>* Array<DATATYPE>::newStream( const osgCompute::Context& context, unsigned int streamIdx ) const
    {
        ArrayStream<DATATYPE>* createdStream = new ArrayStream<DATATYPE>;
        return createdStream;
    }
}

#endif //OSGCUDA_ARRAY
