#include "ShaderCommon.fx"
#include "noise.fx"

// Thread-group width used by InitParticlesCS / UpdateParticlesCS; both use a 1D
// [numthreads(N, 1, 1)] layout and index particles by SV_DispatchThreadID.x.
#define NUM_THREADS_PER_GROUP		64

// Full particle lifetime, in the same units as the per-frame delta time.z that
// UpdateParticlesCS subtracts from Particle::ttl, and its reciprocal.
#define PARTICLES_TTL				4.0f
#define PARTICLES_TTL_INV			(1.0f/PARTICLES_TTL)
// Scale of the random spawn radius around emitterPosition (see initializePos).
#define PARTICLES_EMITTER_OFFSET	30.0f

// STRUCTURES /////////////////////////////////////////////////////////////////
// Per-particle simulation state stored in RWParticleBuffer (compute) and read
// through ParticleBuffer (rendering).
// NOTE(review): the structured-buffer stride is presumably created from a matching
// host-side struct; keep field order/sizes in sync with it (confirm against host code).
struct Particle
{
	float4 position;	// xyz: world-space position; w is never written/read here (stays 0)
	float4 velocity;	// xyz: velocity, re-scaled every update; w unused (stays 0)
	float4 color;		// rgb: tint; a: decays over time and also scales the rendered position
	float ttl;			// remaining lifetime; the particle respawns once it drops below 0
	float3 seed;		// noise seed assigned at init and never modified afterwards
};

// Vertex-buffer input layout for particle quads.
// NOTE(review): not referenced by any shader in this file section (RenderParticlesVS
// pulls from ParticleBuffer via SV_VertexID instead); possibly legacy — confirm before removing.
struct ParticleVertexInput
{
	float4 position : POSITION0;
	float2 uv : TEXCOORD0;
	float4 color : COLOR0;
};

// Particle-specific VS -> PS interface.
// NOTE(review): the render shaders in this section use DefaultVertexOut (from
// ShaderCommon.fx) instead, so this struct appears unused here — confirm before removing.
struct ParticleVertexOut
{
	float4 position			: SV_POSITION;	// clip-space position
	float3 wpos				: WPOS;			// world-space position
	float2 uv				: TEXCOORD0;
	float4 color			: COLOR0;
};

// CONSTANT BUFFERS ///////////////////////////////////////////////////////////
// Per-dispatch parameters for the particle compute shaders (InitParticlesCS /
// UpdateParticlesCS). Field packing must match the host-side constant buffer.
cbuffer UpdateCB
{
	float4 emitterPosition;		// xyz: centre of the spawn region
	float4 emitterDirection;	// xyz: direction of the constant external force (scaled by extraParams.z)
	float4 emitterParams;		// x is emission rate, y is inverse emission rate, z is particle TTL (not read by the shaders in this section)
	float4 time;				// x is current time, y is prev time, z is dt (only z is read here)

	float4 colorScale;			// rgb: base tint, modulated per particle by noise
	float4 extraParams;			// xy: noise bias (only used by now-dead code), z: speed, w: added to the spawn radius (spread)
};

// Parameters for RenderParticlesVS.
cbuffer RenderCB
{
	float4 billboardHalfSize;	// x: half-width along the billboard's right axis, y: length along the heading
};

// RESOURCES //////////////////////////////////////////////////////////////////
// Sprite texture. NOTE(review): not sampled by RenderParticlesPS in this section.
Texture2D<float4>					ParticleTexture;

// Particle state: written by the compute shaders, read by RenderParticlesVS.
RWStructuredBuffer<Particle>		RWParticleBuffer;
StructuredBuffer<Particle>			ParticleBuffer;

// Index buffers. NOTE(review): not referenced by any shader in this section.
RWBuffer<uint>						RWParticleIndexBuffer;
Buffer<uint>						ParticleIndexBuffer;

// HELPER FUNCTIONS ///////////////////////////////////////////////////////////
// Picks a spawn position around emitterPosition from a noise seed.
// Three fbm samples are remapped to [-1, 1] (assuming fbm returns [0, 1]) and used
// as spherical coordinates: a radius (scaled by PARTICLES_EMITTER_OFFSET and biased by
// the spread parameter extraParams.w), a polar angle in [-PI, PI] and an azimuth in [-2PI, 2PI].
float3 initializePos(float3 seed)
{
	float3 rnd = float3(fbm(seed), fbm(seed * 13.0f), fbm(seed * 129.0f)) * 2.0f - 1.0f;

	float radius = rnd.x * PARTICLES_EMITTER_OFFSET + extraParams.w;
	float theta = rnd.y * PI;
	float phi = rnd.z * TWOPI;

	float sinTheta = sin(theta);
	float3 shellOffset = float3(radius * sinTheta * cos(phi),
								radius * cos(theta),
								radius * sinTheta * sin(phi));

	return emitterPosition.xyz + shellOffset;
}

// Returns the initial velocity of a newly spawned particle.
// Particles currently always spawn at rest: UpdateParticlesCS builds their velocity
// from the external force and curl noise on the first update. The previous body
// evaluated three fbm samples and scaled them by extraParams.xyz, but then discarded
// the result and returned zero anyway. That dead work has been removed.
// `seed` is kept so the signature stays stable if a noise-driven initial velocity is
// reintroduced deliberately.
float3 initializeVel(float3 seed)
{
	return float3(0.0f, 0.0f, 0.0f);
}

// Builds a freshly spawned particle.
//   seed  - noise seed; stored in the particle so respawns reuse it
//   color - initial color, copied unchanged (including alpha)
// Position comes from initializePos and velocity from initializeVel. ttl is set to the
// full PARTICLES_TTL. Every other field (position.w, velocity.w) is zero.
// (An unused `noise(seed)` local from the original body was removed.)
Particle initializeParticle(float3 seed, float4 color)
{
	Particle p = (Particle)0;
	p.position.xyz = initializePos(seed);
	p.velocity.xyz = initializeVel(seed);
	p.color = color;
	p.ttl = PARTICLES_TTL;
	p.seed = seed;

	return p;
}

// Size factor for a particle: alpha times its remaining-life fraction (nominally
// 1 at spawn and falling toward 0), scaled down by 0.01.
float calcParticleScale(Particle p)
{
	const float lifeFraction = p.ttl * PARTICLES_TTL_INV;
	return p.color.a * lifeFraction * 0.01f;
}

// Central-difference gradient of fbm at `seed`, one axis at a time.
// The differences are NOT divided by the step width (2 * 0.001), so the result is
// the gradient scaled by 0.002. Callers' tuning constants rely on that scale.
float3 calcNoiseDerivative(float3 seed)
{
	const float h = 0.001f;
	const float3 stepX = float3(h, 0.0f, 0.0f);
	const float3 stepY = float3(0.0f, h, 0.0f);
	const float3 stepZ = float3(0.0f, 0.0f, h);

	return float3(fbm(seed + stepX) - fbm(seed - stepX),
				  fbm(seed + stepY) - fbm(seed - stepY),
				  fbm(seed + stepZ) - fbm(seed - stepZ));
}

// Curl of a vector potential psi = (psi0, psi1, psi2), given each component's gradient:
//   deriv0 = grad(psi0), deriv1 = grad(psi1), deriv2 = grad(psi2).
// curl(psi) = (d psi2/dy - d psi1/dz, d psi0/dz - d psi2/dx, d psi1/dx - d psi0/dy)
float3 computeCurl(float3 deriv0, float3 deriv1, float3 deriv2)
{
	float3 curl;
	curl.x = deriv2.y - deriv1.z;
	curl.y = deriv0.z - deriv2.x;
	curl.z = deriv1.x - deriv0.y;
	return curl;
}

// INITIALIZE PARTICLES ///////////////////////////////////////////////////////
// Initializes every particle in RWParticleBuffer; one thread per particle (dtid.x).
// Lifetimes are staggered linearly across the buffer so particles do not all
// expire and respawn on the same frame.
[numthreads(NUM_THREADS_PER_GROUP, 1, 1)]
void InitParticlesCS(uint3 dtid : SV_DispatchThreadID)
{
	uint idx = dtid.x;
	uint count = 0;
	uint stride = 0;
	RWParticleBuffer.GetDimensions(count, stride);

	// The dispatch may cover more threads than particles (e.g. when the count is not a
	// multiple of NUM_THREADS_PER_GROUP). Skip those threads. This also avoids the
	// division by zero below when the buffer is empty.
	if (idx >= count)
		return;

	float3 seed = float3(dtid.x * 0.125f, dtid.y * 0.25f, 0.0f);

	// initializeParticle copies the given color, so alpha starts at 1.
	Particle p = initializeParticle(seed, float4(colorScale.xyz, 1.0f));
	p.ttl = PARTICLES_TTL - (((float)dtid.x / (float)count) * PARTICLES_TTL);
	RWParticleBuffer[idx] = p;
}

// UPDATE PARTICLES ///////////////////////////////////////////////////////////

// Advances every particle by one step of time.z.
// Velocity is the particle's previous velocity plus a constant external force and a
// curl-noise term, then re-scaled to a per-particle speed. Expired particles are
// reinitialized from their (unchanged) seed.
[numthreads(NUM_THREADS_PER_GROUP, 1, 1)]
void UpdateParticlesCS(uint3 dtid : SV_DispatchThreadID)
{
	uint idx = dtid.x;
	uint count = 0;
	uint stride = 0;
	RWParticleBuffer.GetDimensions(count, stride);

	// Threads past the end of the buffer (dispatch rounded up to whole groups) do nothing.
	if (idx >= count)
		return;

	Particle p = RWParticleBuffer[idx];
	float dt = time.z;
	float3 externalForce = emitterDirection.xyz * extraParams.z;
	float3 seed = p.seed;

	// Gradients of three decorrelated noise potentials (decorrelated by offsetting
	// the seed), accumulated over several octaves.
	float3 p0 = float3(0.0f, 0.0f, 0.0f);
	float3 p1 = float3(0.0f, 0.0f, 0.0f);
	float3 p2 = float3(0.0f, 0.0f, 0.0f);

	const float3 offset0 = float3(123.4f, 129845.0f, -1239.1f);
	const float3 offset1 = float3(-9519.0f, 9051.0f, -123.0f);
	const int NUM_OCTAVES = 8;
	for(int i = 0; i < NUM_OCTAVES; ++i)
	{
		float gain = 0.5f * pow(2.0f, (float)i);
		float lacunarity = pow(1.98f, (float)i);
		p0 += calcNoiseDerivative(seed * lacunarity) * gain;
		p1 += calcNoiseDerivative((seed + offset0) * lacunarity) * gain;
		p2 += calcNoiseDerivative((seed + offset1) * lacunarity) * gain;
	}

	// The curl of a potential field is divergence-free, giving swirling motion.
	float3 curlVel = computeCurl(p0, p1, p2);
	curlVel *= 0.3f;

	p.velocity.xyz += externalForce * dt;
	p.velocity.xyz += curlVel;

	// Keep only the direction and apply a per-particle speed. normalize() of a zero
	// vector yields NaN, which would otherwise propagate into the position until the
	// particle respawns. A zero velocity is simply left at zero.
	float3 v = p.velocity.xyz;
	if (dot(v, v) > 0.0f)
		p.velocity.xyz = normalize(v) * extraParams.z * fbm(seed);

	p.position.xyz += p.velocity.xyz * dt;

	// Per-particle brightness from noise (the raw noise value is used, not remapped).
	float n = noise(seed);
	p.color.xyz = colorScale.xyz * n;

	// Age the particle by this frame's delta and slowly fade its alpha.
	p.ttl -= dt;
	p.color.a -= 0.01f * dt;

	// Expired: respawn with full alpha. The seed never changes, so the particle
	// restarts from the same spawn point each cycle (for fixed emitter parameters).
	if(p.ttl < 0.0f)
	{
		p.color.a = 1.0f;
		p = initializeParticle(p.seed, p.color);
	}

	RWParticleBuffer[idx] = p;
}

// RENDER PARTICLES ///////////////////////////////////////////////////////////
// Triangle corners for one particle, indexed by SV_VertexID % 3 in RenderParticlesVS.
// x multiplies billboardHalfSize.x along the right axis, y multiplies billboardHalfSize.y
// along the heading: one tip ahead of the particle and two base corners beside it.
static const float2 posMultiplier[] = 
{
	float2(	 0.0f,	1.0f),
	float2(	 1.0f,	0.0f),
	float2(	-1.0f,	0.0f),
};

// Expands each particle into a camera-facing triangle elongated along its velocity.
// Non-indexed draw: three vertices per particle. id / 3 selects the particle and
// id % 3 selects the corner from posMultiplier.
// NOTE(review): a particle whose velocity is exactly zero (e.g. one freshly
// (re)initialized, because initializeVel returns zero) gets a NaN heading. Its triangle
// is therefore degenerate for that frame. The same happens when the heading is
// parallel to CameraDir.
DefaultVertexOut RenderParticlesVS(uint id : SV_VertexID)
{
	DefaultVertexOut output = (DefaultVertexOut)0;
	Particle p = ParticleBuffer[id/3];
	uint idx = id % 3;

	float3 heading = normalize(p.velocity.xyz);
	float3 up = normalize(-CameraDir.xyz);
	float3 right = normalize(cross(heading, up));

	float3 peak = billboardHalfSize.y * posMultiplier[idx].y * heading;
	float3 side = billboardHalfSize.x * posMultiplier[idx].x * right;

	float3 pos = p.position.xyz + peak + side;

	// Uniform scale of xyz by alpha. This is equivalent to the original multiply by a
	// diagonal diag(a, a, a, 1) matrix (whose transpose is itself), without building it.
	// NOTE(review): this scales the world-space position about the world origin, not
	// about the particle, so fading particles drift toward the origin. If the intent was
	// to shrink the billboard, scale `peak + side` instead. Confirm before changing.
	output.wpos = float4(pos * p.color.a, 1.0f);
	output.position = mul(mul(output.wpos, View), Projection);
	output.color = float4(p.color.xyz, 1.0f);
	output.normal = heading;
	return output;
}

// Opaque particle color: the interpolated vertex RGB with alpha forced to 1.
float4 RenderParticlesPS(DefaultVertexOut input) : SV_Target0
{
	return float4(input.color.rgb, 1.0f);
}

