/* osgCompute - Copyright (C) 2008-2009 SVT Group
 *                                                                     
 * This library is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 3 of
 * the License, or (at your option) any later version.
 *                                                                     
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of 
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * The full license is in LICENSE file included with this distribution.
*/

#ifndef OSGCOMPUTE_BUFFER
#define OSGCOMPUTE_BUFFER 1

#include <new>
#include "osgCompute/Param"

namespace osg
{
    class Image;
    class Vec2b;
    class Vec3b;
    class Vec4b;
    class Vec4ub;
    class Vec2s;
    class Vec3s;
    class Vec4s;
    class Vec2f;
    class Vec3f;
    class Vec4f;
    class Vec2d;
    class Vec3d;
    class Vec4d;
}

namespace osgCompute
{
    // Forward declaration; the class template is defined further below.
    template< class DATATYPE > class Buffer;

    // Convenience names for Buffer instantiations over common scalar types
    // and the osg vector types forward-declared above.
    // Note: ByteBuffer uses plain 'char', whose signedness is implementation-defined.
    typedef Buffer<unsigned char>     UByteBuffer;
    typedef Buffer<osg::Vec4ub>       Vec4ubBuffer;
    typedef Buffer<char>              ByteBuffer;
    typedef Buffer<osg::Vec2b>        Vec2bBuffer;
    typedef Buffer<osg::Vec3b>        Vec3bBuffer;
    typedef Buffer<osg::Vec4b>        Vec4bBuffer;
    typedef Buffer<unsigned short>    UShortBuffer;
    typedef Buffer<short>             ShortBuffer;
    typedef Buffer<osg::Vec2s>        Vec2sBuffer;
    typedef Buffer<osg::Vec3s>        Vec3sBuffer;
    typedef Buffer<osg::Vec4s>        Vec4sBuffer;
    typedef Buffer<unsigned int>      UIntBuffer;
    typedef Buffer<int>               IntBuffer;
    typedef Buffer<unsigned long>     ULongBuffer;
    typedef Buffer<long>              LongBuffer;
    typedef Buffer<float>             FloatBuffer;
    typedef Buffer<osg::Vec2f>        Vec2fBuffer;
    typedef Buffer<osg::Vec3f>        Vec3fBuffer;
    typedef Buffer<osg::Vec4f>        Vec4fBuffer;
    typedef Buffer<double>            DoubleBuffer;
    typedef Buffer<osg::Vec2d>        Vec2dBuffer;
    typedef Buffer<osg::Vec3d>        Vec3dBuffer;
    typedef Buffer<osg::Vec4d>        Vec4dBuffer;

    /** Mapping state flags of a buffer stream (see BufferStream::_mapping).
    *   The combined values are the bitwise OR of their parts:
    *   MAP_HOST   == MAP_HOST_SOURCE   | MAP_HOST_TARGET   (0x11) and
    *   MAP_DEVICE == MAP_DEVICE_SOURCE | MAP_DEVICE_TARGET (0x00110000).
    *   NOTE(review): the exact meaning of "source"/"target" (read vs. write
    *   access?) is defined by the implementing subclasses - confirm there.
    */
    enum Mapping
    {
        UNMAPPED                       = 0x00000000,
        MAP_HOST                       = 0x00000011,
        MAP_HOST_SOURCE                = 0x00000001,
        MAP_HOST_TARGET                = 0x00000010,
        MAP_DEVICE                     = 0x00110000,
        MAP_DEVICE_SOURCE              = 0x00010000,
        MAP_DEVICE_TARGET              = 0x00100000,
    }; 


    /** Allocation hint bits. Buffer::setAllocHint() ORs them into the
    *   buffer's hint mask, which is copied into each newly created stream.
    *   NOTE(review): ALLOC_DYNAMIC presumably requests storage suited for
    *   frequent updates - confirm against the subclasses that evaluate it.
    */
    enum ALLOC_HINT
    {
        NO_ALLOC_HINT = 0x0,
        ALLOC_DYNAMIC = 0x1
    };

    /**
    * State of one stream of a Buffer within one context. Instances are created
    * lazily by Buffer::lookupStream() through the subclass hook
    * Buffer::newStream(), which then fills in _context, _streamIdx and
    * _allocHint. Copying is disabled.
    */
    template< class DATATYPE >
    class BufferStream
    {
    public:
        // Current Mapping flags of this stream; starts as UNMAPPED.
        unsigned int                    _mapping;
        // Index of this stream inside its buffer; UINT_MAX until assigned.
        unsigned int                    _streamIdx;
        // Context this stream belongs to.
        osg::ref_ptr<Context>           _context;
        // Copy of the buffer's alloc hint at the time the stream was created.
        unsigned int                    _allocHint;
        // Set/cleared by Buffer::setNeedsSetup(); starts as true.
        bool                            _needsSetup;

        BufferStream();
        virtual ~BufferStream();

    private:
        // not allowed to call copy-constructor or copy-operator
        BufferStream( const BufferStream& ) {}
        BufferStream& operator=( const BufferStream& ) { return *this; }
    };

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // PUBLIC FUNCTIONS /////////////////////////////////////////////////////////////////////////////
    /////////////////////////////////////////////////////////////////////////////////////////////////
    //------------------------------------------------------------------------------
    template< class DATATYPE >
    BufferStream<DATATYPE>::BufferStream()
        : _mapping( UNMAPPED )
        , _streamIdx( UINT_MAX )      // "no index yet"; set by Buffer::lookupStream()
        , _allocHint( NO_ALLOC_HINT )
        , _needsSetup( true )         // a fresh stream always needs its setup pass
    {
        // _context is left default-constructed; Buffer::lookupStream() assigns it.
    }

    //------------------------------------------------------------------------------
    template< class DATATYPE >
    BufferStream<DATATYPE>::~BufferStream()
    {
        // No explicit cleanup is required at this level.
    }

    /**
    * Abstract base class for typed compute buffers with element type DATATYPE.
    *
    * The buffer is configured while it is dirty; setDimension(), setNumStreams()
    * and setAllocHint() are ignored otherwise. init() then derives the element
    * count and the byte sizes. Per-context storage lives in _streams: for each
    * context id, an array of _numStreams BufferStream pointers. Each entry is
    * created on first use by lookupStream() via the subclass hook newStream().
    * Subclasses implement mapping and the image/vector accessors.
    */
    template< class DATATYPE >
    class Buffer : public Param
    {  
    public:
        // Puts all members into their initial state via clearLocal().
        Buffer();

        // Validates dimensions and stream count and computes _numElements,
        // _streamSize and _byteSize. Returns true at once if not dirty.
        virtual bool init();
        // Releases all per-context streams and resets the configuration.
        virtual void clear();
        virtual bool isBuffer() { return true; }

        // Implemented by subclasses. NOTE(review): 'mapping' presumably takes
        // Mapping enum values - confirm in the subclasses.
        virtual DATATYPE* map( const Context& context, unsigned int mapping, unsigned int streamIdx = 0 ) const = 0;
        virtual void unmap( const Context& context, unsigned int streamIdx = 0 ) const = 0;

        // Per-stream host data accessors, implemented by subclasses.
        // NOTE(review): ownership of the passed image/vector is not visible here.
        virtual void setImage( osg::Image* streamImage, unsigned int streamIdx = 0 ) = 0;
        virtual osg::Image* getImage( unsigned int streamIdx = 0 ) = 0;
        virtual const osg::Image* getImage( unsigned int streamIdx = 0 ) const = 0;

        virtual void setVector( std::vector<DATATYPE>* streamVector, unsigned int numElements = UINT_MAX, unsigned int offset = 0, unsigned int streamIdx = 0 ) = 0;
        virtual std::vector<DATATYPE>* getVector( unsigned int streamIdx = 0 ) = 0;
        virtual const std::vector<DATATYPE>* getVector( unsigned int streamIdx = 0 ) const = 0;

        // Total bytes over all streams (computed by init()).
        virtual unsigned int getByteSize() const;
        // Bytes of one stream: sizeof(DATATYPE) * getNumElements().
		inline unsigned int getStreamSize() const;
        inline unsigned int getNumStreams() const;
        // Ignored unless the buffer is dirty.
        inline void setNumStreams( unsigned int numStreams );

        // Ignored unless the buffer is dirty; the dimension list grows on
        // demand and skipped entries are filled with 0.
        inline void setDimension( unsigned int dimIdx, unsigned int dimSize );
        inline unsigned int getDimension( unsigned int dimIdx ) const;
        inline unsigned int getNumDimensions() const;

        // Product of all dimensions (computed by init()).
        inline unsigned int getNumElements() const;

        // ORs allocHint into the current hint mask; ignored unless dirty.
        inline void setAllocHint( unsigned int allocHint );
        inline unsigned int getAllocHint() const;

        // Mapping flags of stream 'streamIdx' in 'context'; UNMAPPED while
        // dirty. Note that this may create the stream on first use.
        virtual unsigned int getMapping( Context& context, unsigned int streamIdx = 0 ) const;

    protected:
        // Virtual calls made from clearLocal() here resolve to Buffer's own
        // clear( const Context& ), not to a subclass override.
        virtual ~Buffer() { clearLocal(); }
        // Frees the streams of all registered contexts and resets all members.
        inline void clearLocal();

        // Allocates the stream array for 'context' and registers the param.
        virtual bool init( const Context& context ) const;
        // Deletes the streams and stream array of 'context' and unregisters it.
        virtual void clear( const Context& context ) const;

        // Factory hook: creates the subclass-specific stream object.
        virtual BufferStream<DATATYPE>* newStream( const Context& context, unsigned int streamIdx ) const = 0;
        // Returns the stream for context/index, creating it on first use; NULL on error.
        inline BufferStream<DATATYPE>* lookupStream( const Context& context, unsigned int streamIdx ) const;

        // Sets BufferStream::_needsSetup of stream 'streamIdx' in every context.
        virtual void setNeedsSetup( bool needsSetup, unsigned int streamIdx ) const;

        unsigned int                                    _allocHint;
		unsigned int							        _byteSize;
        unsigned int                                    _numStreams;
        unsigned int                                    _streamSize;
        std::vector<unsigned int>                       _dimensions;
        unsigned int                                    _numElements;

        // Guards _streams in lookupStream() and setNeedsSetup().
        mutable OpenThreads::Mutex                      _mutex;
        // Indexed by context id; each entry is an array of _numStreams stream
        // pointers, or NULL if that context has no array.
        mutable std::vector<BufferStream<DATATYPE>**>   _streams;

    private:
        // copy constructor and operator should not be called
        Buffer( const Buffer&, const osg::CopyOp& ) {}
        Buffer& operator=( const Buffer& copy ) { return (*this); }
    };

	/////////////////////////////////////////////////////////////////////////////////////////////////
	// PUBLIC FUNCTIONS /////////////////////////////////////////////////////////////////////////////
	/////////////////////////////////////////////////////////////////////////////////////////////////
    //------------------------------------------------------------------------------
    template< class DATATYPE >
    Buffer<DATATYPE>::Buffer()
        :   Param(),
            _allocHint( NO_ALLOC_HINT ),
            _byteSize( 0 ),
            _numStreams( 1 ),
            _streamSize( 0 ),
            _numElements( 0 )
    {
        // Every scalar member gets a defined value before anything can read it,
        // independent of which members clearLocal() resets. _streams is still
        // empty here, so clearLocal() has no streams to release.
        clearLocal();
    }

    //------------------------------------------------------------------------------
    template< class DATATYPE >
    void Buffer<DATATYPE>::clear()
    {
        // First drop this buffer's own state (streams, sizes, hints),
        // then let the base class reset its part.
        this->clearLocal();
        Param::clear();
    }

    //------------------------------------------------------------------------------
    /**
    * Validates the configuration and computes _numElements, _streamSize and
    * _byteSize. Returns true at once if the buffer is not dirty. Returns false
    * (after a FATAL notification) if no dimension was set, if the stream count
    * is zero, if any dimension is zero, or if the resulting byte size does not
    * fit into an unsigned int. On failure the size members are left unchanged.
    */
    template< class DATATYPE >
    bool Buffer<DATATYPE>::init()
    {
        if( !isDirty() )
            return true;

        if( _dimensions.empty() )
        {
            osg::notify(osg::FATAL)  
                << "Buffer::init() for Buffer \""<<getName()<<"\": No Dimensions specified."                  
                << std::endl;

            return false;
        }

        if( !(_numStreams > 0) )
        {
            osg::notify(osg::FATAL)  
                << "Buffer::init() for Buffer \""<<getName()<<"\": Number of streams must be specified."                  
                << std::endl;

            return false;
        }

        ///////////////////////
        // COMPUTE BYTE SIZE //
        ///////////////////////
        unsigned int numElements = 1;
        for( unsigned int d=0; d<_dimensions.size(); ++d )
        {
            const unsigned int dimSize = _dimensions[d];

            // setDimension() zero-fills skipped indices, so a zero here usually
            // means a dimension was never set; it would yield an empty buffer.
            if( 0 == dimSize )
            {
                osg::notify(osg::FATAL)  
                    << "Buffer::init() for Buffer \""<<getName()<<"\": Dimension "<<d<<" has size zero."
                    << std::endl;

                return false;
            }

            // Reject silent unsigned wrap-around of the element count.
            if( numElements > UINT_MAX / dimSize )
            {
                osg::notify(osg::FATAL)  
                    << "Buffer::init() for Buffer \""<<getName()<<"\": Number of elements exceeds the supported range."
                    << std::endl;

                return false;
            }

            numElements *= dimSize;
        }

        // sizeof(DATATYPE) * numElements * _numStreams must fit into an unsigned
        // int as well (_numStreams > 0 was checked above).
        if( numElements > (UINT_MAX / sizeof(DATATYPE)) / _numStreams )
        {
            osg::notify(osg::FATAL)  
                << "Buffer::init() for Buffer \""<<getName()<<"\": Byte size exceeds the supported range."
                << std::endl;

            return false;
        }

        _numElements = numElements;
        _streamSize = static_cast<unsigned int>( sizeof(DATATYPE) * numElements );
        _byteSize = _streamSize * _numStreams;

        return Param::init();
    }

	//------------------------------------------------------------------------------
	template< class DATATYPE >
	unsigned int Buffer<DATATYPE>::getByteSize() const
	{
		// Size of all streams together, as computed by init().
		const unsigned int totalBytes = this->_byteSize;
		return totalBytes;
	}

    //------------------------------------------------------------------------------
	template< class DATATYPE >
	inline void Buffer<DATATYPE>::setNumStreams( unsigned int numStreams )
	{
		// The stream count can only be changed while isDirty() reports true;
		// otherwise the request is silently ignored.
		if( isDirty() )
			this->_numStreams = numStreams;
	}

	//------------------------------------------------------------------------------
	template< class DATATYPE >
	inline unsigned int Buffer<DATATYPE>::getNumStreams() const
	{
		// Number of streams allocated per context.
		return this->_numStreams;
	}

	//------------------------------------------------------------------------------
	template< class DATATYPE >
	inline unsigned int Buffer<DATATYPE>::getStreamSize() const
	{
		// Byte size of a single stream (sizeof(DATATYPE) times the element
		// count), as computed by init().
		return this->_streamSize;
	}

	//------------------------------------------------------------------------------
	template< class DATATYPE >
	inline void Buffer<DATATYPE>::setDimension( unsigned int dimIdx, unsigned int dimSize )
	{
		// Dimensions are frozen once the buffer is no longer dirty.
		if( isDirty() )
		{
			// Grow the dimension list on demand; new entries start at 0.
			if( dimIdx >= _dimensions.size() )
				_dimensions.resize( dimIdx + 1, 0 );

			_dimensions[dimIdx] = dimSize;
		}
	}

	//------------------------------------------------------------------------------
	/**
	* Returns the size of dimension 'dimIdx', or 0 if that dimension does not exist.
	*/
	template< class DATATYPE >
	inline unsigned int Buffer<DATATYPE>::getDimension( unsigned int dimIdx ) const
	{
		// Compare against size() directly: "size()-1" wraps around to the
		// maximum value when no dimension is set, which would let any index
		// through to an out-of-bounds read.
		if( dimIdx >= _dimensions.size() )
			return 0;

		return _dimensions[dimIdx];
	}

	//------------------------------------------------------------------------------
	template< class DATATYPE >
	inline unsigned int Buffer<DATATYPE>::getNumDimensions() const
	{
		// size() may be wider than unsigned int; make the narrowing explicit.
		return static_cast<unsigned int>( _dimensions.size() );
	}

    //------------------------------------------------------------------------------
    template< class DATATYPE >
    inline unsigned int Buffer<DATATYPE>::getNumElements() const
    {
        // Product of all dimensions, as computed by init().
        const unsigned int elementCount = this->_numElements;
        return elementCount;
    }

    //------------------------------------------------------------------------------
    template< class DATATYPE >
    inline void Buffer<DATATYPE>::setAllocHint( unsigned int allocHint )
    {
        // Hints accumulate: the given bits are OR-ed into the current mask,
        // and only while the buffer is still dirty.
        if( isDirty() )
            this->_allocHint |= allocHint;
    }

    //------------------------------------------------------------------------------
    template< class DATATYPE >
    inline unsigned int Buffer<DATATYPE>::getAllocHint() const
    {
        // Accumulated ALLOC_HINT bits.
        return this->_allocHint;
    }

    //------------------------------------------------------------------------------
    template< class DATATYPE >
    unsigned int Buffer<DATATYPE>::getMapping( osgCompute::Context& context, unsigned int streamIdx ) const
    {
        // Nothing is reported as mapped while the buffer is dirty.
        if( isDirty() )
            return osgCompute::UNMAPPED;

        // lookupStream() creates the stream on first use if necessary.
        const BufferStream<DATATYPE>* found = lookupStream( context, streamIdx );
        if( found != NULL )
            return found->_mapping;

        osg::notify(osg::FATAL)  
            << "Buffer::getMapping() for Buffer \""
            << getName() <<"\": Could not receive BufferStream for Context \""
            << context.getId() << "\" and Stream \""<<streamIdx<<"\"."
            << std::endl;

        return osgCompute::UNMAPPED;
    }

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // PROTECTED FUNCTIONS //////////////////////////////////////////////////////////////////////////
    /////////////////////////////////////////////////////////////////////////////////////////////////
    //------------------------------------------------------------------------------
    /**
    * Releases the streams of every registered context and resets all size and
    * configuration members to their initial values.
    */
    template< class DATATYPE >
    void Buffer<DATATYPE>::clearLocal()
    {
        /////////////////
        // DELETE DATA //
        /////////////////
        for( unsigned int ctx = 0; ctx < _streams.size(); ++ctx )
        {
            // NOTE(review): if Context::instance() returns NULL for an index that
            // still holds a stream array, that array is not freed here.
            Context* context = Context::instance( ctx );
            if( !context )
                continue;

            // Clear context streams
            clear( *context );
        }

        _streams.clear();

        // Clear other members
        _dimensions.clear();
        _numElements = 0;
        _numStreams = 1;
        _streamSize = 0;    // was never reset before, leaving getStreamSize() stale
        _byteSize = 0;
        _allocHint = NO_ALLOC_HINT;
    }

    //------------------------------------------------------------------------------
    /**
    * Makes sure a stream-pointer array (with _numStreams NULL entries) exists
    * for 'context', then registers the param with the context via Param::init().
    * Returns false if the array could not be allocated.
    */
    template< class DATATYPE >
    bool Buffer<DATATYPE>::init( const Context& context ) const
    {
        if( _streams.size()<=context.getId() )
            _streams.resize(context.getId()+1,NULL);

        // Allocate stream array for context
        if( NULL == _streams[context.getId()] )
        {
            // nothrow: a plain new-expression throws instead of returning NULL,
            // which made the failure branch below unreachable. This class reports
            // errors through return values, so keep it that way.
            _streams[context.getId()] = new (std::nothrow) BufferStream<DATATYPE>*[_numStreams];

            if( NULL == _streams[context.getId()] )
            {
                osg::notify( osg::FATAL )  
                    << "Buffer::init( \"CONTEXT\" ) for Buffer \"" << getName()
                    << "\": DataArray could not be allocated for context \"" 
                    << context.getId() << "\"."
                    << std::endl;

                return false;
            }

            // Streams are created on first use, so all slots start empty.
            for( unsigned int str =0; str < _numStreams; ++str )
                _streams[context.getId()][str] = NULL;
        }

        // Register param if valid stream-array 
        // is allocated
        return Param::init( context );
    }

    //------------------------------------------------------------------------------
    template< class DATATYPE >
    void Buffer<DATATYPE>::clear( const Context& context ) const
    {
        const bool hasStreamArray =
            _streams.size() > context.getId() && NULL != _streams[context.getId()];

        if( hasStreamArray )
        {
            BufferStream<DATATYPE>** streamArray = _streams[context.getId()];

            // Delete every stream that was lazily created for this context.
            for( unsigned int idx = 0; idx < getNumStreams(); ++idx )
            {
                BufferStream<DATATYPE>*& slot = streamArray[idx];
                if( slot != NULL )
                {
                    delete slot;
                    slot = NULL;
                }
            }

            // Then release the per-context array itself.
            delete [] streamArray;
            _streams[context.getId()] = NULL;
        }

        // Unregister context
        Param::clear( context );
    }

    //------------------------------------------------------------------------------
    /**
    * Returns the stream 'streamIdx' of 'context'. The context's stream array
    * and the stream itself (via newStream()) are created on first use.
    * Returns NULL (after a FATAL notification) on an invalid index or when
    * allocation fails.
    */
    template< class DATATYPE >
    inline BufferStream<DATATYPE>* Buffer<DATATYPE>::lookupStream( const Context& context, unsigned int streamIdx ) const
    {
        // The per-context array has exactly _numStreams slots (0 .. _numStreams-1),
        // so streamIdx == _numStreams is already out of range.
        if( streamIdx >= _numStreams )
        {
            osg::notify( osg::FATAL )  
                << "Buffer::lookupStream() for Buffer \"" << getName()
                << "\": Index of Stream does not exist."
                << std::endl;

            return NULL;
        }

        OpenThreads::ScopedLock<OpenThreads::Mutex> lock(_mutex);

        //////////////////
        // CONTEXT INIT //
        //////////////////
        // Init stream array and register at context
        if( _streams.size()<=context.getId() ||
            _streams[context.getId()] == NULL )
            init( context );

        // init() can fail to allocate the array; never index a NULL array.
        if( _streams.size()<=context.getId() ||
            _streams[context.getId()] == NULL )
        {
            osg::notify( osg::FATAL )  
                << "Buffer::lookupStream() for Buffer \"" << getName()
                << "\": Stream array could not be initialized for context \""
                << context.getId() << "\"."
                << std::endl;

            return NULL;
        }

        BufferStream<DATATYPE>** ctxStreams = _streams[context.getId()];

        /////////////////////
        // ALLOCATE STREAM //
        /////////////////////
        if( NULL == ctxStreams[ streamIdx ] )
        {
            // lazy allocation of the required stream within the context
            ctxStreams[ streamIdx ] = newStream( context, streamIdx );
            if( ctxStreams[ streamIdx ] == NULL )
            {
                osg::notify( osg::FATAL )  
                    << "Buffer::lookupStream() for Buffer \"" << getName()
                    << "\": Data could not be allocated for stream \""
                    << streamIdx << "\" and context \"" << context.getId() << "\"."
                    << std::endl;

                return NULL;
            }

            // Setup stream params
            ctxStreams[ streamIdx ]->_context = const_cast<osgCompute::Context*>( &context );
            ctxStreams[ streamIdx ]->_streamIdx = streamIdx;
            ctxStreams[ streamIdx ]->_allocHint = getAllocHint();
        }

        return ctxStreams[ streamIdx ];
    }

    //------------------------------------------------------------------------------
    /**
    * Sets BufferStream::_needsSetup of stream 'streamIdx' in every context that
    * has already created that stream. Streams not yet created are skipped.
    */
    template< class DATATYPE >
    void Buffer<DATATYPE>::setNeedsSetup( bool needsSetup, unsigned int streamIdx ) const
    {
        // Each per-context array has exactly _numStreams slots.
        if( streamIdx >= _numStreams )
        {
            osg::notify( osg::FATAL )  
                << "Buffer::setNeedsSetup() for Buffer \"" << getName()
                << "\": Index of Stream does not exist."
                << std::endl;

            return;
        }

        OpenThreads::ScopedLock<OpenThreads::Mutex> lock(_mutex);
        for( unsigned int c=0; c<_streams.size(); ++c )
        {
            // Contexts without a stream array (never initialized, or cleared)
            // hold NULL in _streams and must be skipped.
            BufferStream<DATATYPE>** ctxStreams = _streams[c];
            if( NULL == ctxStreams )
                continue;

            if( ctxStreams[streamIdx] )
                ctxStreams[streamIdx]->_needsSetup = needsSetup;
        }
    }
}

#endif //OSGCOMPUTE_BUFFER
