/* osgCompute - Copyright (C) 2008-2009 SVT Group
 *                                                                     
 * This library is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 3 of
 * the License, or (at your option) any later version.
 *                                                                     
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of 
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * The full license is in LICENSE file included with this distribution.
*/

#ifndef OSGCUDA_CONSTANT
#define OSGCUDA_CONSTANT 1

#include <OpenThreads/Thread>
#include <osgCompute/Constant>
#include "osgCuda/Context"
#include "osgCuda/Export"

// Forward declarations of the osg vector types used by the Constant typedefs
// below, so that this header does not need to include the osg vector headers.
namespace osg {
    // signed / unsigned byte vectors
    class Vec2b;
    class Vec3b;
    class Vec4b;
    class Vec4ub;

    // short vectors
    class Vec2s;
    class Vec3s;
    class Vec4s;

    // single-precision float vectors
    class Vec2f;
    class Vec3f;
    class Vec4f;

    // double-precision vectors
    class Vec2d;
    class Vec3d;
    class Vec4d;
}

namespace osgCuda
{
    template<class T>
    class Constant;

    // Convenience aliases for the supported constant value types, grouped by
    // component type (byte, short, int, long, float, double).
    typedef Constant<unsigned char>     UByteConstant;
    typedef Constant<char>              ByteConstant;
    typedef Constant<osg::Vec2b>        Vec2bConstant;
    typedef Constant<osg::Vec3b>        Vec3bConstant;
    typedef Constant<osg::Vec4b>        Vec4bConstant;
    // osg::Vec4ub is forward-declared alongside the other vector types, so it
    // gets a matching alias as well.
    typedef Constant<osg::Vec4ub>       Vec4ubConstant;
    typedef Constant<unsigned short>    UShortConstant;
    typedef Constant<short>             ShortConstant;
    typedef Constant<osg::Vec2s>        Vec2sConstant;
    typedef Constant<osg::Vec3s>        Vec3sConstant;
    typedef Constant<osg::Vec4s>        Vec4sConstant;
    typedef Constant<unsigned int>      UIntConstant;
    typedef Constant<int>               IntConstant;
    typedef Constant<unsigned long>     ULongConstant;
    typedef Constant<long>              LongConstant;
    typedef Constant<float>             FloatConstant;
    typedef Constant<osg::Vec2f>        Vec2fConstant;
    typedef Constant<osg::Vec3f>        Vec3fConstant;
    typedef Constant<osg::Vec4f>        Vec4fConstant;
    typedef Constant<double>            DoubleConstant;
    typedef Constant<osg::Vec2d>        Vec2dConstant;
    typedef Constant<osg::Vec3d>        Vec3dConstant;
    typedef Constant<osg::Vec4d>        Vec4dConstant;

    /**
     * Per-context storage for a Constant: one copy of the value plus a flag
     * recording whether that copy has been filled in.
     */
	template< class DATATYPE >
    class ConstantData : public osgCompute::ConstantData<DATATYPE>
    {
    public:
        // Context-local copy of the constant value.
        DATATYPE	_data;
        // Set once _data has received a value (see Constant::data() and Constant::setData()).
		bool		_isValid;

        ConstantData();
        virtual ~ConstantData();

    private:
        // Copying is not allowed. These are declared but intentionally left
        // undefined (pre-C++11 non-copyable idiom). An accidental copy from
        // inside the class therefore fails to link, instead of silently
        // producing an object whose members were never initialized.
        ConstantData( const ConstantData& );
        ConstantData& operator=( const ConstantData& );
    };

	/////////////////////////////////////////////////////////////////////////////////////////////////
	// PUBLIC FUNCTIONS /////////////////////////////////////////////////////////////////////////////
	/////////////////////////////////////////////////////////////////////////////////////////////////
	//------------------------------------------------------------------------------
	// Creates an empty per-context slot. The stored value is value-initialized
	// (zero for built-in scalar types rather than indeterminate) and flagged as
	// not yet valid.
	template< class DATATYPE >
	ConstantData<DATATYPE>::ConstantData()
        : osgCompute::ConstantData<DATATYPE>(),
          _data(),
          _isValid(false)
	{
	}

	//------------------------------------------------------------------------------
	// Destructor. _data is held by value, so there is nothing to release here;
	// the base class destructor runs afterwards as usual.
	template< class DATATYPE >
	ConstantData<DATATYPE>::~ConstantData()
	{ /* intentionally empty */ }


    /**
     * CUDA constant parameter holding a single value of type DATATYPE.
     *
     * A host-side copy of the value (_initialData) lives on this object.
     * Each context gets its own ConstantData slot (created by newData()), which
     * data() seeds from the host-side value the first time it is needed.
     */
	template< class DATATYPE >
    class Constant : public osgCompute::Constant<DATATYPE>
    {
    public:
        Constant();

        META_Object( osgCuda, Constant )

        /** Resets the local host-side state, then the base-class state. */
        virtual void clear();
        /** Returns the context-local copy of the value, or NULL on error. */
        virtual DATATYPE* data( const osgCompute::Context& context ) const;

        /** Stores the host-side value; if the constant is not dirty it is also
         *  copied into every existing per-context slot. */
		virtual void setData( const DATATYPE& data );
        /** Returns a pointer to the host-side value. */
		virtual DATATYPE* getData();
		virtual const DATATYPE* getData() const;

    protected:
        virtual ~Constant() { clearLocal(); }
        void clearLocal();

        virtual ConstantData<DATATYPE>* newData( const osgCompute::Context& context ) const;

        // Host-side value used to seed each context's copy in data().
		DATATYPE _initialData;
        // True after setData() until the next clear().
		bool	 _isDataValid;

    private:
        // Copying is not supported. This constructor keeps a body because the
        // clone() generated by META_Object presumably constructs through it
        // (TODO confirm against osg/Object). It yields a fresh, fully
        // initialized object rather than one whose _initialData/_isDataValid
        // are indeterminate.
        Constant( const Constant&, const osg::CopyOp& )
            : osgCompute::Constant<DATATYPE>(),
              _initialData(),
              _isDataValid(false)
        {}
        Constant& operator=( const Constant& ) { return (*this); }
    };

	/////////////////////////////////////////////////////////////////////////////////////////////////
	// PUBLIC FUNCTIONS /////////////////////////////////////////////////////////////////////////////
	/////////////////////////////////////////////////////////////////////////////////////////////////
	//------------------------------------------------------------------------------
	// Creates a constant with no host-side value yet. _initialData is
	// value-initialized, so getData() never exposes indeterminate memory for
	// built-in types before setData() has been called.
	template< class DATATYPE >
	Constant<DATATYPE>::Constant()
		: osgCompute::Constant<DATATYPE>(),
		  _initialData(),
		  _isDataValid(false)
	{
		// Keep local-state reset in one place (shared with clear()/destructor).
		clearLocal();
	}

	//------------------------------------------------------------------------------
	// Marks the host-side value as unset. The stored _initialData bytes are left
	// untouched; only the validity flag is reset.
	template< class DATATYPE >
	void Constant<DATATYPE>::clearLocal()
	{
		this->_isDataValid = false;
	}

	//------------------------------------------------------------------------------
	// Resets this class's local state first, then delegates to the osgCompute
	// base implementation for the shared state.
	template< class DATATYPE >
	void Constant<DATATYPE>::clear()
	{
		this->clearLocal();
		osgCompute::Constant<DATATYPE>::clear();
	}

	//------------------------------------------------------------------------------
	// Returns the context-local copy of the constant for the given context.
	//
	// Returns NULL (after logging at FATAL level) if:
	//   - the constant is dirty,
	//   - the calling thread is not the thread assigned to the context, or
	//   - no per-context data could be looked up.
	//
	// If the context copy is not yet valid and a host-side value exists, the
	// copy is seeded from it. When a ConstantSubloadCallback is attached, its
	// load() is called if this call seeded the copy, otherwise subload().
	template< class DATATYPE >
	DATATYPE* Constant<DATATYPE>::data( const osgCompute::Context& context ) const
	{
		// Guard: the constant must not be dirty.
		if( osgCompute::Param::isDirty() )
		{
			osg::notify(osg::FATAL)
				<< "CUDA::Constant::data() for constant \""
				<< osg::Object::getName() <<"\": constant is dirty."
				<< std::endl;

			return NULL;
		}

		// Guard: only the context's assigned thread may access its data.
		const Context* cudaCtx = static_cast<const Context*>( &context );
		if( cudaCtx->getAssignedThread() != OpenThreads::Thread::CurrentThread() )
		{
			osg::notify(osg::FATAL)
				<< "CUDA::Constant::data() for Constant \""
				<< osg::Object::getName() <<"\": calling thread differs from the context assigned thread."
				<< std::endl;

			return NULL;
		}

		// Guard: the per-context slot must exist.
		ConstantData<DATATYPE>* slot =
			static_cast<ConstantData<DATATYPE>*>( osgCompute::Constant<DATATYPE>::lookupData(context) );
		if( slot == NULL )
		{
			osg::notify(osg::FATAL)
				<< "CUDA::Constant::data() for Constant \""
				<< osg::Object::getName() <<"\": Could not receive ConstantData for Context \""
				<< context.getId() << "\"."
				<< std::endl;

			return NULL;
		}

		// Seed the context copy from the host-side value if it has none yet.
		const bool seedFromHost = ( !slot->_isValid && _isDataValid );
		if( seedFromHost )
		{
			slot->_data = _initialData;
			slot->_isValid = true;
		}

		// Resolve the optional constant-specific subload callback.
		const osgCompute::ConstantSubloadCallback* loader = NULL;
		if( osgCompute::Param::getSubloadCallback() )
			loader = osgCompute::Param::getSubloadCallback()->asConstantSubloadCallback();

		if( loader != NULL )
		{
			// Full load right after seeding, incremental subload otherwise.
			if( seedFromHost )
				loader->load( &slot->_data, *this, context );
			else
				loader->subload( &slot->_data, *this, context );
		}

		return &slot->_data;
	}

	//------------------------------------------------------------------------------
	// Stores a new host-side value and marks it valid. If the constant is not
	// dirty, the value is also copied into every existing per-context slot, and
	// each such slot is marked valid.
	template< class DATATYPE >
	void Constant<DATATYPE>::setData( const DATATYPE& data )
	{
		_initialData = data;
		_isDataValid = true;

		// While dirty, only the host-side value is recorded.
		if( osgCompute::Param::isDirty() )
			return;

		for( unsigned int idx = 0; idx < osgCompute::Constant<DATATYPE>::_array.size(); ++idx )
		{
			ConstantData<DATATYPE>* target =
				static_cast<ConstantData<DATATYPE>*>( osgCompute::Constant<DATATYPE>::_array[idx] );
			if( NULL == target )
				continue;

			target->_data = _initialData;
			target->_isValid = true;
		}
	}

	//------------------------------------------------------------------------------
	// Returns a mutable pointer to the host-side value (not a per-context copy).
	template< class DATATYPE >
	DATATYPE* Constant<DATATYPE>::getData()
	{
		return &(this->_initialData);
	}

	//------------------------------------------------------------------------------
	// Returns a read-only pointer to the host-side value (not a per-context copy).
	template< class DATATYPE >
	const DATATYPE* Constant<DATATYPE>::getData() const
	{
		return &(this->_initialData);
	}

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // PROTECTED FUNCTIONS //////////////////////////////////////////////////////////////////////////
    /////////////////////////////////////////////////////////////////////////////////////////////////
    //------------------------------------------------------------------------------
    // Allocates a fresh, empty per-context slot. The context argument is not
    // needed to build the slot, so it is left unnamed to avoid
    // unused-parameter warnings.
    // NOTE(review): ownership of the returned heap object is presumably taken
    // by the osgCompute base class (lookupData()/_array) -- confirm.
    template< class DATATYPE >
    ConstantData<DATATYPE>* Constant<DATATYPE>::newData( const osgCompute::Context& /*context*/ ) const
    {
        return new ConstantData<DATATYPE>();
    }

}

#endif //OSGCUDA_CONSTANT
