from datatypes import Int, Float, Floats, Ref, SIZEOF_FLOAT, Str, Chars, Ints, Shorts, SIZEOF_SHORT
import gl
import gl2
from PIL import Image
import ctypes

class ShaderCompilationFailure(Exception):
	"""Raised when a GLSL shader fails to compile or a program fails to link.

	The message carries the GL info log for the failing stage.
	"""

class Shader(object):
	"""A linked GLSL program built from vertex and fragment shader sources.

	The names in ``attrs`` are bound, in order, to vertex attribute locations
	0, 1, 2, ... before linking, so geometry code can rely on fixed indices
	(e.g. position at 0).
	"""

	def __init__(self, gl, vert, frag, attrs=("position",)):
		"""Compile ``vert``/``frag`` source strings and link them with ``gl``.

		attrs -- sequence of attribute names; attrs[i] is bound to location i.
		         (The old default ``("position,")`` was a plain string, so its
		         characters were bound one by one; it is now a real 1-tuple.)

		Raises ShaderCompilationFailure if compilation or linking fails.
		"""
		print("loading shader")
		self.gl = gl
		self.vert = vert
		self.frag = frag
		self.attributes = attrs
		self.setup()
		self.build()
		print("loaded shader")

	@staticmethod
	def from_files(gl, vert, frag, attrs=("position",)):
		"""Build a Shader from the files ``shaders/<vert>`` and ``shaders/<frag>``."""
		with open("shaders/" + vert) as vert_file:
			vert_src = vert_file.read()
		with open("shaders/" + frag) as frag_file:
			frag_src = frag_file.read()
		return Shader(gl, vert_src, frag_src, attrs)

	def _compile(self, src, type_):
		"""Compile one shader stage and return its GL name (raises on failure)."""
		src = Str(src)
		shader = self.gl.glCreateShader(type_)
		self.gl.glShaderSource(shader, 1, Ref(src), 0)
		self.gl.glCompileShader(shader)
		self._verify_compile(shader, type_)
		return shader

	def _verify_compile(self, shader, type_):
		"""Raise ShaderCompilationFailure if ``shader`` failed to compile.

		On success, any non-empty info log is printed as warnings.
		"""
		typestr = "FRAGMENT" if type_ == gl2.GL_FRAGMENT_SHADER else "VERTEX"
		status = Int()
		self.gl.glGetShaderiv(shader, gl2.GL_COMPILE_STATUS, Ref(status))

		if status.value == 0:
			infolen = Int()
			self.gl.glGetShaderiv(shader, gl2.GL_INFO_LOG_LENGTH, Ref(infolen))

			output = Str(" " * infolen.value)
			self.gl.glGetShaderInfoLog(shader, infolen, 0, output)
			raise ShaderCompilationFailure("FAILED TO COMPILE %s SHADER: %s" % (typestr, output.value))

		# Successful compiles can still produce a log (warnings); read up to N bytes.
		N = 10240
		output = Str(" " * N)
		self.gl.glGetShaderInfoLog(shader, N, 0, output)
		if output.value != "":
			print("%s SHADER COMPILE WARNINGS:\n%s" % (typestr, output.value))

	def build(self):
		"""Compile both stages, bind attribute locations, and link the program."""
		vertprg = self._compile(self.vert, gl2.GL_VERTEX_SHADER)
		fragprg = self._compile(self.frag, gl2.GL_FRAGMENT_SHADER)

		self.glname = self.gl.glCreateProgram()
		self.gl.glAttachShader(self.glname, vertprg)
		self.gl.glAttachShader(self.glname, fragprg)

		# Bindings only take effect at link time, so they must precede glLinkProgram.
		for i, att in enumerate(self.attributes):
			self.gl.glBindAttribLocation(self.glname, i, att)

		self.gl.glLinkProgram(self.glname)
		self._verify_link()

	def _verify_link(self):
		"""Raise ShaderCompilationFailure with the program info log if linking failed."""
		status = Int()
		self.gl.glGetProgramiv(self.glname, gl2.GL_LINK_STATUS, Ref(status))

		if status.value == 0:
			infolen = Int()
			self.gl.glGetProgramiv(self.glname, gl2.GL_INFO_LOG_LENGTH, Ref(infolen))
			output = Str(" " * infolen.value)
			self.gl.glGetProgramInfoLog(self.glname, infolen, 0, output)
			raise ShaderCompilationFailure("FAILED TO LINK SHADER PROGRAM: " + output.value)

	def setup(self):
		"""Cache the glUniform* entry points, indexed by component count - 1."""
		self.unifsf = [
			self.gl.glUniform1f,
			self.gl.glUniform2f,
			self.gl.glUniform3f,
			self.gl.glUniform4f
		]
		self.unifsi = [
			self.gl.glUniform1i,
			self.gl.glUniform2i,
			self.gl.glUniform3i,
			self.gl.glUniform4i
		]
		self.unifsv = [
			self.gl.glUniform1fv,
			self.gl.glUniform2fv,
			self.gl.glUniform3fv,
			self.gl.glUniform4fv
		]
		# index 0 -> mat2, 1 -> mat3, 2 -> mat4
		self.unifsmatfv = [
			self.gl.glUniformMatrix2fv,
			self.gl.glUniformMatrix3fv,
			self.gl.glUniformMatrix4fv,
		]

	def use(self):
		"""Make this program current."""
		self.gl.glUseProgram(self.glname)

	def uniform(self, name, val):
		"""Set a 1-4 component float or int uniform.

		val -- a scalar, or a tuple/list of 1-4 scalars.  The float variant is
		       used when the first component is a float, otherwise the int one.

		Raises ValueError if val has no components or more than four.
		"""
		if not isinstance(val, (tuple, list)):
			val = (val,)
		if not 1 <= len(val) <= 4:
			raise ValueError("uniform %r needs 1-4 components, got %d" % (name, len(val)))

		if isinstance(val[0], float):
			uni = self.unifsf
			args = [Float(v) for v in val]
		else:
			uni = self.unifsi
			args = [Int(v) for v in val]

		loc = self.gl.glGetUniformLocation(self.glname, name)
		uni[len(args) - 1](loc, *args)

	def uniforms(self, name, val):
		raise Exception("TODO implement me, uniformXfv")

	def uniformmat(self, name, val):
		"""Upload a single untransposed square matrix uniform.

		val -- flat sequence of 4, 9 or 16 floats (mat2, mat3 or mat4).

		Raises ValueError for any other length.
		"""
		slots = {4: 0, 9: 1, 16: 2}
		if len(val) not in slots:
			raise ValueError("uniformmat %r needs 4, 9 or 16 floats, got %d" % (name, len(val)))
		uni = self.unifsmatfv[slots[len(val)]]

		val = Floats(val)

		loc = self.gl.glGetUniformLocation(self.glname, name)
		uni(loc, 1, gl2.GL_FALSE, val)

	def texture(self, name, texid, tex):
		"""Point sampler uniform ``name`` at texture unit ``texid`` and bind ``tex`` there."""
		self.uniform(name, texid)
		tex.bind(texid)

class VboStrip(object):
	"""A triangle strip whose x,y,z positions are stored in a GL array buffer.

	Positions are fed to vertex attribute 0, so the shader drawing it should
	have its position attribute bound to location 0.
	"""
	def __init__(self, gl_, coords):
		"""Upload ``coords`` (flat x,y,z,x,y,z,... floats) into a new VBO."""
		self.gl = gl_
		self.coords = Floats(coords)
		self.indices = len(coords)  # number of floats (kept for existing users)
		# glDrawArrays counts vertices, and each vertex has 3 floats.
		self.vertex_count = len(coords) // 3
		self.glname = Int()

		self.gl.glGenBuffers(1, Ref(self.glname))
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self.glname)
		self.gl.glEnableVertexAttribArray(0)
		self.gl.glVertexAttribPointer(0, 3, gl.GL_FLOAT, gl2.GL_FALSE, 0, 0)
		self.gl.glBufferData(gl.GL_ARRAY_BUFFER, SIZEOF_FLOAT * len(self.coords), self.coords, gl.GL_STATIC_DRAW)
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)

	def draw(self):
		"""Draw the strip from its VBO."""
		# GLES2 has no vertex array objects: the attribute pointer is global
		# state, and other draws in this module (quad, Mesh.render2) replace it
		# with client-side arrays.  Binding the buffer alone is therefore not
		# enough -- attribute 0 must be re-pointed at offset 0 of this VBO.
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self.glname)
		self.gl.glEnableVertexAttribArray(0)
		self.gl.glVertexAttribPointer(0, 3, gl.GL_FLOAT, gl2.GL_FALSE, 0, 0)
		self.gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self.vertex_count)
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)

class quad(object):
	"""Two triangles covering an axis-aligned rectangle at z = 1.0.

	Positions go to vertex attribute 0 and uvs (spanning -1..1 on each axis)
	to attribute 1.  Both are drawn straight from client memory, not a VBO.
	"""
	def __init__(self, gl_, x0, y0, x1, y1):
		"""x0,y0 bottom left, x1,y1 top right"""
		self.gl = gl_
		self._setup(x0, y0, x1, y1)

	def _setup(self, x0, y0, x1, y1):
		"""Build the 6-vertex position (xyz) and uv (st) arrays."""
		self.coords = Floats([
			x0, y0, 1.0,
			x1, y0, 1.0,
			x1, y1, 1.0,

			x0, y0, 1.0,
			x1, y1, 1.0,
			x0, y1, 1.0,
		])
		self.uvs = Floats([
			-1.0, -1.0,
			1.0, -1.0,
			1.0, 1.0,

			-1.0, -1.0,
			1.0, 1.0,
			-1.0, 1.0,
		])

	def draw(self):
		"""Draw the rectangle with the currently bound shader."""
		# Unbind buffers so the attribute pointers below are read as client memory.
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)
		self.gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, 0)
		self.gl.glEnableVertexAttribArray(0)
		self.gl.glEnableVertexAttribArray(1)
		self.gl.glVertexAttribPointer(0, 3, gl.GL_FLOAT, gl2.GL_FALSE, 0, self.coords)
		self.gl.glVertexAttribPointer(1, 2, gl.GL_FLOAT, gl2.GL_FALSE, 0, self.uvs)
		# glDrawArrays takes a vertex count: 3 floats per vertex -> 6 vertices,
		# not len(self.coords) == 18, which would read past both arrays.
		self.gl.glDrawArrays(gl.GL_TRIANGLES, 0, len(self.coords) // 3)
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)

class Texture(object):
	"""A 2D GL texture with linear filtering and repeat wrapping on both axes.

	fmt is the internal format; pixel data is always described to GL as
	GL_RGBA / GL_UNSIGNED_BYTE.  rgba_data defaults to 0 (presumably a NULL
	pointer, i.e. allocate storage without uploading pixels -- confirm against
	the gl binding).
	"""
	def __init__(self, gl_, fmt, size, rgba_data=0):
		self.gl = gl_
		self.internal_format = fmt
		self.size = size
		self.rgba_data = rgba_data
		self._allocate()

	def _allocate(self):
		"""Create the GL texture object, upload/allocate level 0, set sampling."""
		target = gl.GL_TEXTURE_2D
		width = self.size[0]
		height = self.size[1]

		self.glname = Int()
		self.gl.glGenTextures(1, Ref(self.glname))
		self.gl.glBindTexture(target, self.glname)
		self.gl.glTexImage2D(target, 0, self.internal_format, width, height, 0,
				gl2.GL_RGBA, gl.GL_UNSIGNED_BYTE, self.rgba_data)

		sampling = (
			(gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR),
			(gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR),
			(gl.GL_TEXTURE_WRAP_S, gl.GL_REPEAT),
			(gl.GL_TEXTURE_WRAP_T, gl.GL_REPEAT),
		)
		for pname, setting in sampling:
			self.gl.glTexParameteri(target, pname, setting)

	def blit(self, pos, scale):
		raise Exception("TODO: this would be nice")

	def bind(self, i):
		"""Bind this texture to texture unit ``i``."""
		unit = gl.GL_TEXTURE0 + i
		self.gl.glActiveTexture(unit)
		self.gl.glBindTexture(gl.GL_TEXTURE_2D, self.glname)

class Picture(Texture):
	"""A Texture loaded from ``pix/<filename>.png`` and stored as RGBA."""
	def __init__(self, gl_, filename):
		print("loading texture")
		path = "pix/" + filename + ".png"
		size, channels = self._file_to_bytes(path)
		Texture.__init__(self, gl_, gl2.GL_RGBA, size, Chars(channels))
		print("loaded texture")

	def _file_to_bytes(self, filename):
		"""Return ``(size, [r, g, b, a, r, g, b, a, ...])`` for an image file.

		The image is converted to RGBA and flipped vertically, so the first
		row in the list is the image's bottom row.
		"""
		image = Image.open(filename).convert("RGBA")
		# GL treats the first row of texture data as the bottom of the image.
		image = image.transpose(Image.FLIP_TOP_BOTTOM)
		channels = [component for pixel in image.getdata() for component in pixel]
		return (image.size, channels)

class Rtt(object):
	"""Render-to-texture target: a framebuffer with one RGBA colour texture.

	attach() redirects rendering into ``diffusetex`` (and sets the viewport to
	``resolution``); detach() rebinds the default framebuffer.  No depth
	attachment is created -- only colour is rendered into this target.
	"""
	def __init__(self, gl_, resolution):
		self.gl = gl_
		self.resolution = resolution
		self._build()

	def _build(self):
		"""Create the FBO and its colour texture; raise if the FBO is incomplete."""
		self.fboname = Int()
		self.gl.glGenFramebuffers(1, Ref(self.fboname))
		self.gl.glBindFramebuffer(gl2.GL_FRAMEBUFFER, self.fboname)

		# Back the texture with a 4*w*h byte buffer grown via ctypes.resize:
		# building (ctypes.whatever*lots)(big list) is *slow* on raspi, and the
		# name suggests passing no data buffer here segfaulted.
		segfault_preventer = Chars((0,))
		ctypes.resize(segfault_preventer, 4 * self.resolution[0] * self.resolution[1])
		# TODO: GL_NEAREST? (GL_RGBA32F, GL_FLOAT) for positions, (GL_RGBA16F, GL_FLOAT) for normals?
		self.diffusetex = Texture(self.gl, gl2.GL_RGBA, self.resolution, segfault_preventer)

		# GLES2 has no glDrawBuffers; COLOR_ATTACHMENT0 is the only colour target.
		self.gl.glFramebufferTexture2D(gl2.GL_FRAMEBUFFER, gl2.GL_COLOR_ATTACHMENT0, gl2.GL_TEXTURE_2D, self.diffusetex.glname, 0)

		# Sanity check in addition to glGetError.  Unbind first so a failed
		# build does not leave an incomplete framebuffer bound for later draws.
		status = self.gl.glCheckFramebufferStatus(gl2.GL_FRAMEBUFFER)
		self.gl.glBindFramebuffer(gl2.GL_FRAMEBUFFER, 0)
		if status != gl2.GL_FRAMEBUFFER_COMPLETE:
			raise Exception("wat fbo fail (status %r)" % (status,))

	def attach(self):
		"""Render into this target from now on, with a full-resolution viewport."""
		self.gl.glBindFramebuffer(gl2.GL_FRAMEBUFFER, self.fboname)
		self.gl.glViewport(0, 0, self.resolution[0], self.resolution[1])

	def detach(self):
		"""Return to the default framebuffer (the viewport is not restored)."""
		self.gl.glBindFramebuffer(gl2.GL_FRAMEBUFFER, 0)

	def attach_sample(self):
		"""Placeholder; currently does nothing.

		TODO: bind ``diffusetex`` for sampling like any other texture
		(active unit, bind, set uniform) -- maybe share a base texture class.
		"""
		pass


class Mesh(object):
	"""An indexed triangle mesh with per-vertex positions and normals.

	Positions feed vertex attribute 0 and normals attribute 1; faces are
	16-bit vertex indices, three per triangle.
	"""
	def __init__(self, gl_, vertices, normals, faces):
		"""Store the mesh and upload it to an array buffer and an index buffer.

		vertices, normals -- flat x,y,z float sequences, one triple per vertex
		faces -- flat sequence of zero-based vertex indices, three per triangle

		Raises ValueError if an index does not fit GL_UNSIGNED_SHORT.
		"""
		print("Loading mesh")
		vertices = list(vertices)
		normals = list(normals)
		faces = list(faces)
		# Both render paths draw with GL_UNSIGNED_SHORT, so indices must be 16-bit.
		if faces and (min(faces) < 0 or max(faces) > 0xFFFF):
			raise ValueError("mesh face indices must be in 0..65535 for GL_UNSIGNED_SHORT")
		self.gl = gl_
		self.vertices = Floats(vertices)
		self.normals = Floats(normals)
		# positions followed by normals, uploaded as one VBO by _setup()
		self.both = Floats(vertices + normals)
		self.faces = Shorts(faces)
		self._setup()
		print("Loaded mesh")

	def render2(self):
		"""Draw from client-side arrays (no VBOs)."""
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)

		self.gl.glEnableVertexAttribArray(0)
		self.gl.glVertexAttribPointer(0, 3, gl.GL_FLOAT, gl2.GL_FALSE, 0, self.vertices)

		self.gl.glEnableVertexAttribArray(1)
		self.gl.glVertexAttribPointer(1, 3, gl.GL_FLOAT, gl2.GL_FALSE, 0, self.normals)
		self.gl.glDrawElements(gl.GL_TRIANGLES, len(self.faces), gl2.GL_UNSIGNED_SHORT, self.faces)

	def _setup(self):
		"""Create the vertex (positions+normals) and index buffers."""
		self.glname = Int()
		self.gl.glGenBuffers(1, Ref(self.glname))
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self.glname)
		self.gl.glBufferData(gl.GL_ARRAY_BUFFER, SIZEOF_FLOAT * len(self.both), self.both, gl.GL_STATIC_DRAW)
		self.gl.glEnableVertexAttribArray(0)
		self.gl.glVertexAttribPointer(0, 3, gl.GL_FLOAT, gl2.GL_FALSE, 0, 0)
		self.gl.glEnableVertexAttribArray(1)
		# normals start right after the positions in the shared buffer
		self.gl.glVertexAttribPointer(1, 3, gl.GL_FLOAT, gl2.GL_FALSE, 0, SIZEOF_FLOAT*len(self.vertices))

		self.glname2 = Int()
		self.gl.glGenBuffers(1, Ref(self.glname2))
		self.gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, self.glname2)
		self.gl.glBufferData(gl.GL_ELEMENT_ARRAY_BUFFER, SIZEOF_SHORT * len(self.faces), self.faces, gl.GL_STATIC_DRAW)

		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)
		self.gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, 0)

	def render(self):
		"""Draw from this mesh's GL buffers."""
		# GLES2 has no VAOs: attribute pointers are global state that other
		# draws overwrite, so re-point attributes 0 and 1 into our VBO each time.
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self.glname)
		self.gl.glEnableVertexAttribArray(0)
		self.gl.glVertexAttribPointer(0, 3, gl.GL_FLOAT, gl2.GL_FALSE, 0, 0)
		self.gl.glEnableVertexAttribArray(1)
		self.gl.glVertexAttribPointer(1, 3, gl.GL_FLOAT, gl2.GL_FALSE, 0, SIZEOF_FLOAT*len(self.vertices))
		self.gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, self.glname2)
		self.gl.glDrawElements(gl.GL_TRIANGLES, len(self.faces), gl2.GL_UNSIGNED_SHORT, 0)
		self.gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, 0)
		self.gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)

	@staticmethod
	def _obj_index(token, count):
		"""Convert a 1-based (or negative, relative) OBJ index to a 0-based one.

		count -- number of elements of that kind defined so far in the file.
		"""
		idx = int(token)
		return count + idx if idx < 0 else idx - 1

	@staticmethod
	def _parse_wavefront(filelike):
		"""Parse Wavefront OBJ lines into ``(verts, norms, faces)`` flat lists.

		verts -- x,y,z floats per ``v`` line, in file order (extra w/colour
		         components are ignored)
		norms -- x,y,z floats per vertex, aligned with verts; a vertex never
		         given a normal by a face gets 0,0,0.  If faces give one vertex
		         several normals, the last one read wins.
		faces -- zero-based vertex indices, three per triangle; polygons with
		         more corners are fan-triangulated.
		Texture coordinates and all other statements are ignored.
		"""
		verts = []
		norms = []
		faces = []
		norm_per_vert = {}

		for raw in filelike:
			line = raw.split("#", 1)[0].strip()
			if not line:
				continue
			parts = line.split(None, 1)
			what = parts[0]
			fields = parts[1].split() if len(parts) > 1 else []

			if what == "v":
				verts.extend(float(x) for x in fields[:3])
			elif what == "vn":
				norms.extend(float(x) for x in fields[:3])
			elif what == "f":
				corners = []
				for node in fields:
					# node is v, v/t, v/t/n or v//n
					refs = node.split("/")
					v = Mesh._obj_index(refs[0], len(verts) // 3)
					if len(refs) >= 3 and refs[2]:
						norm_per_vert[v] = Mesh._obj_index(refs[2], len(norms) // 3)
					corners.append(v)
				# fan: (0,1,2), (0,2,3), ... -- a triangle comes out unchanged
				for i in range(1, len(corners) - 1):
					faces.extend((corners[0], corners[i], corners[i + 1]))

		# GL uses one index for both attributes, so reorder normals per vertex.
		norms_ordered = []
		for v in range(len(verts) // 3):
			n = norm_per_vert.get(v)
			if n is None:
				norms_ordered.extend((0.0, 0.0, 0.0))
			else:
				norms_ordered.extend((norms[3*n], norms[3*n+1], norms[3*n+2]))
		return verts, norms_ordered, faces

	@staticmethod
	def load_wavefront(gl, filelike):
		"""Build a Mesh from an iterable of Wavefront OBJ lines."""
		print("Loading obj")
		verts, norms, faces = Mesh._parse_wavefront(filelike)
		print("Loaded obj")
		return Mesh(gl, verts, norms, faces)

