#include <errno.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifdef __APPLE__
#define GL_SILENCE_DEPRECATION
#endif

#include <OpenGL/gl3.h>
#include <GLFW/glfw3.h>
#ifndef __APPLE__
#define GLEW_STATIC
#include <GL/glew.h>
#endif

#include "config.h"

#define STRINGIFY(X) QUOTESTRINGIFY(X)
#define QUOTESTRINGIFY(X) #X

#define GLFW_WINDOW_FRAMEBUFFER 0

GLint
compileprogram(const GLchar *const *vertsrcs, const GLint vertlen, const GLchar *const *fragsrcs, const GLint fraglen)
{
	GLint success;
	GLchar infolog[512];
	GLint program = glCreateProgram();
	if (program == 0) {
		exit(1);
	}
	
	GLuint vert = glCreateShader(GL_VERTEX_SHADER);
	if (vert == 0) {
		exit(1);
	}
	glShaderSource(vert, vertlen, vertsrcs, NULL);
	glCompileShader(vert);
	glGetShaderiv(vert, GL_COMPILE_STATUS, &success);
	if (success == GL_FALSE) {
		glGetShaderInfoLog(vert, sizeof(infolog), NULL, infolog);
		fputs(infolog, stderr);
		exit(1);
	}
	
	GLuint frag = glCreateShader(GL_FRAGMENT_SHADER);
	if (frag == 0) {
		exit(1);
	}
	glShaderSource(frag, fraglen, fragsrcs, NULL);
	glCompileShader(frag);
	glGetShaderiv(frag, GL_COMPILE_STATUS, &success);
	if (!success) {
		glGetShaderInfoLog(frag, sizeof(infolog), NULL, infolog);
		fputs(infolog, stderr);
		exit(1);
	}
	
	glAttachShader(program, vert);
	glAttachShader(program, frag);
	glLinkProgram(program);
	glDeleteShader(vert);
	glDeleteShader(frag);
	glGetProgramiv(program, GL_LINK_STATUS, &success);
	if (success == GL_FALSE) {
		glGetProgramInfoLog(program, sizeof(infolog), NULL, infolog);
		fputs(infolog, stderr);
		exit(1);
	}
	glValidateProgram(program);
	glGetProgramiv(program, GL_VALIDATE_STATUS, &success);
	if (success == GL_FALSE) {
		glGetProgramInfoLog(program, sizeof(infolog), NULL, infolog);
		fputs(infolog, stderr);
		exit(1);
	}
	return program;
}

#define GLSL_LINE(line) "#line "STRINGIFY(line)"\n\n" // Extra \n to line up GLSL with the current file.

const char *versionsrc = "#version 410 core\n";

const char *vertquadsrc =
	GLSL_LINE(__LINE__)
	"void\n"
	"main()\n"
	"{\n"
	"    // (-1,-1), (-1,1), (1,-1), (1,1)\n"
	"    gl_Position = vec4(2*(gl_VertexID/2) - 1, 2*(gl_VertexID%2) - 1, 0, 1);\n"
	"}\n";

const char *fragsdfssrc =
	GLSL_LINE(__LINE__)
	"// Signed distance to a sphere centered at the origin.\n"
	"float\n"
	"spheresd(vec3 pos, float radius)\n"
	"{\n"
	"    return length(pos) - radius;\n"
	"}\n"
	"\n"
	"// Signed distance to an axis-aligned box centered at the origin.\n"
	"float\n"
	"boxsd(vec3 pos, vec3 bounds) {\n"
	"    vec3 q = abs(pos) - bounds;\n"
	"    return length(max(q, 0.0)) + min(max(q.x, max(q.y, q.z)), 0.0);\n"
	"}\n"
	"\n"
	"// Signed distance to a plane.\n"
	"// Normal must be normalized.\n"
	"float\n"
	"planesd(vec3 pos, vec3 normal, float height)\n"
	"{\n"
	"    return dot(pos, normal) + height;\n"
	"}\n"
	"\n"
	"// Signed distance to the entire scene.\n"
	"float\n"
	"subtract(float a, float b) {\n"
	"    return max(a, -b);\n"
	"}\n"
	"\n"
	"float\n"
	"scenesd(vec3 pos)\n"
	"{\n"
	"    return min(\n"
	"        planesd(pos, vec3(0, 1, 0), 0),\n"
	"        spheresd(pos - vec3(0, 0.5, 0), 0.5)\n"
	"    );\n"
	"}\n"
	"\n"
	"// Estimates a normal vector at the specified position.\n"
	"vec3\n"
	"normal(vec3 pos)\n"
	"{\n"
	"    const float epsilon = 0.0001;\n"
	"    vec2 e = vec2(1.0, -1.0)*0.5773;\n"
	"    return normalize(e.xyy*scenesd(pos + e.xyy*epsilon) +\n"
	"                     e.yyx*scenesd(pos + e.yyx*epsilon) +\n"
	"                     e.yxy*scenesd(pos + e.yxy*epsilon) +\n"
	"                     e.xxx*scenesd(pos + e.xxx*epsilon));\n"
	"}\n"
	"\n"
	"float\n"
	"castray(vec3 pos, vec3 dir)\n"
	"{\n"
	"#define RAYMAXITERS 64\n"
	"#define RAYEPSILON 0.00001\n"
	"    float dst = 0;\n"
	"    for(int i = 0; i < RAYMAXITERS; ++i) {\n"
	"        float scenedst = scenesd(pos + dir*dst);\n"
	"        dst += scenedst;\n"
	"        if (scenedst < RAYEPSILON) break;\n"
	"    }\n"
	"    return dst;\n"
	"}\n";

const char *fragpositionmapsrc =
	GLSL_LINE(__LINE__)
	"out vec4 fragcolor;\n"
	"\n"
	"uniform vec2 resolution;\n"
	"uniform mat4 transform;\n"
	"uniform float near;\n"
	"uniform vec3 minbound;\n"
	"uniform vec3 maxbound;\n"
	"\n"
	"void\n"
	"main()\n"
	"{\n"
	"    mat4 local2world = transform;\n"
	"    mat4 world2local = inverse(local2world);\n"
	"    vec3 camerapos = transform[3].xyz;\n"
	"    float a = resolution.x/resolution.y;\n"
	"    \n"
	"    vec3 screenlocal = vec3(\n"
	"        gl_FragCoord.x/resolution.x - 0.5,\n"
	"        (gl_FragCoord.y/resolution.y - 0.5)/a,\n"
	"        near\n"
	"    );\n"
	"    vec3 screenworld = (local2world*vec4(screenlocal, 1)).xyz;\n"
	"    vec3 screendir = normalize(screenworld - camerapos);\n"
	"    \n"
	"    float scenedst = castray(camerapos, screendir);\n"
	"    vec3 sceneworld = camerapos + screendir*scenedst;\n"
	"    vec3 scenelocal = (world2local*vec4(sceneworld, 1)).xyz;\n"
	"    vec3 scenelocalmapped = (scenelocal - minbound) / (maxbound - minbound);\n"
	"    fragcolor = vec4(scenelocalmapped, 1);\n"
	"}\n";

const char *fraglightmapsrc =
	GLSL_LINE(__LINE__)
	"out vec4 fragcolor;\n"
	"\n"
	"uniform vec2 resolution;\n"
	"uniform mat4 transform;\n"
	"uniform float near;\n"
	"uniform uint rays;\n"
	"\n"
	"float\n"
	"rand(inout uint state)\n"
	"{\n"
	"    state = state*747796405 + 2891336453;\n"
	"    uint result = ((state >> ((state >> 28) + 4)) ^ state)*277803737;\n"
	"    result = (result >> 22) ^ result;\n"
	"    return result/4294967295.0;\n"
	"}\n"
	"\n"
	"float\n"
	"randnormal(inout uint state)\n"
	"{\n"
	"    float theta = 2*3.1415926*rand(state);\n"
	"    float rho = sqrt(-2*log(rand(state)));\n"
	"    return rho*cos(theta);\n"
	"}\n"
	"\n"
	"vec3\n"
	"randdir(inout uint state)\n"
	"{\n"
	"    float x = randnormal(state);\n"
	"    float y = randnormal(state);\n"
	"    float z = randnormal(state);\n"
	"    return normalize(vec3(x, y, z));\n"
	"}\n"
	"\n"
	"vec3\n"
	"randhemidir(vec3 normal, inout uint rngstate)\n"
	"{\n"
	"#if 1\n"
	"    // Cosine-weighted distribution.\n"
	"    return normalize(normal + randdir(rngstate));\n"
	"#else\n"
	"    // Uniform distribution.\n"
	"    vec3 dir = randdir(rngstate);\n"
	"    return dir * sign(dot(normal, dir));\n"
	"#endif\n"
	"}\n"
	"\n"
	"float\n"
	"traceray(vec3 pos, vec3 dir, inout uint rngstate)\n"
	"{\n"
	"#define NORMALEPSILON 0.01\n"
	"    float dst = castray(pos, dir);\n"
	"    if (dst > 5) return 1;\n"
	"    pos += dir*dst;\n"
	"    vec3 normal = normal(pos);\n"
	"    dir = randhemidir(normal, rngstate);\n"
	"    // Nudging in the direction of the normal instead of dir helps with the rays\n"
	"    // that run almost parallel to the surface.\n"
	"    vec3 nudge = normal*NORMALEPSILON;\n"
	"    dst = castray(pos + nudge, dir);\n"
	"    if (dst > 5) return 1;\n"
	"    return 0;\n"
	"}\n"
	"\n"
	"void\n"
	"main()\n"
	"{\n"
	"    mat4 local2world = transform;\n"
	"    vec3 camerapos = transform[3].xyz;\n"
	"    uint rngstate = uint(gl_FragCoord.y*resolution.x + gl_FragCoord.x);\n"
	"    float a = resolution.x/resolution.y;\n"
	"    \n"
	"    vec3 screenlocal = vec3(\n"
	"        gl_FragCoord.x/resolution.x - 0.5,\n"
	"        (gl_FragCoord.y/resolution.y - 0.5)/a,\n"
	"        near\n"
	"    );\n"
	"    vec3 screenworld = (local2world*vec4(screenlocal, 1)).xyz;\n"
	"    vec3 screendir = normalize(screenworld - camerapos);\n"
	"    \n"
	"    float acc = 0;\n"
	"    for (int i = 0; i < rays; ++i) {\n"
	"        acc += traceray(camerapos, screendir, rngstate);\n"
	"    }\n"
	"    float lum = acc/rays;\n"
	"    fragcolor = vec4(vec3(lum), 1);\n"
	"}\n";

struct {
	GLint handle;
	struct {
		GLint resolution;
		GLint transform;
		GLint near;
		GLint minbound;
		GLint maxbound;
	} uniforms;
} positionmapprogram;

void
compilepositionmapprogram(void)
{
	const char *vertsrcs[] = { versionsrc, vertquadsrc };
	const char *fragsrcs[] = { versionsrc, fragsdfssrc, fragpositionmapsrc };
	GLint program = compileprogram(vertsrcs, sizeof(vertsrcs)/sizeof(*vertsrcs), fragsrcs, sizeof(fragsrcs)/sizeof(*fragsrcs));
	glUseProgram(program);
	{
		positionmapprogram.uniforms.resolution = glGetUniformLocation(program, "resolution");
		positionmapprogram.uniforms.transform = glGetUniformLocation(program, "transform");
		positionmapprogram.uniforms.near = glGetUniformLocation(program, "near");
		positionmapprogram.uniforms.minbound = glGetUniformLocation(program, "minbound");
		positionmapprogram.uniforms.maxbound = glGetUniformLocation(program, "maxbound");
	}
	positionmapprogram.handle = program;
}

struct {
	GLint handle;
	struct {
		GLint resolution;
		GLint transform;
		GLint near;
		GLint rays;
	} uniforms;
} lightmapprogram;

void
compilelightmapprogram(void)
{
	const char *vertsrcs[] = { versionsrc, vertquadsrc };
	const char *fragsrcs[] = { versionsrc, fragsdfssrc, fraglightmapsrc };
	GLint program = compileprogram(vertsrcs, sizeof(vertsrcs)/sizeof(*vertsrcs), fragsrcs, sizeof(fragsrcs)/sizeof(*fragsrcs));
	glUseProgram(program);
	{
		lightmapprogram.uniforms.resolution = glGetUniformLocation(program, "resolution");
		lightmapprogram.uniforms.transform = glGetUniformLocation(program, "transform");
		lightmapprogram.uniforms.near = glGetUniformLocation(program, "near");
		lightmapprogram.uniforms.rays = glGetUniformLocation(program, "rays");
	}
	lightmapprogram.handle = program;
}

void
compileprograms(void)
{
	compilepositionmapprogram();
	compilelightmapprogram();
}

typedef struct {
	float x, y, z;
} vec3;

vec3
vec3subtract(vec3 a, vec3 b)
{
	return (vec3) {a.x - b.x, a.y - b.y, a.z - b.z};
}

float
vec3dot(vec3 a, vec3 b)
{
	return a.x*b.x + a.y*b.y + a.z*b.z;
}

vec3
vec3normalize(vec3 v)
{
	float len = sqrtf(vec3dot(v, v));
	if (len == 0) return (vec3) {0};
	return (vec3) {v.x/len, v.y/len, v.z/len};
}

vec3
vec3cross(vec3 a, vec3 b)
{
	return (vec3) {
		a.y*b.z - a.z*b.y,
		a.z*b.x - a.x*b.z,
		a.x*b.y - a.y*b.x,
	};
}

typedef struct {
	float m0, m4, m8, m12;
	float m1, m5, m9, m13;
	float m2, m6, m10, m14;
	float m3, m7, m11, m15;
} mat4x4;

mat4x4
lookat(vec3 pos, vec3 target, vec3 up)
{
	vec3 forward = vec3normalize(vec3subtract(target, pos));
	vec3 right = vec3normalize(vec3cross(up, forward));
	up = vec3cross(forward, right);
	return (mat4x4) {
		right.x,   right.y,   right.z,   0,
		up.x,      up.y,      up.z,      0,
		forward.x, forward.y, forward.z, 0,
		pos.x,     pos.y,     pos.z,     1,
	};
}

vec3 pos = {-2, 1.5, 0};
vec3 target = {0, 0.5, 0};
vec3 up = {0, 1, 0};
float near = 0.5;
vec3 positionminbound = {-1, -1, 0.5};
vec3 positionmaxbound = {1, 1, 3.5};

static GLuint outtexture = 0;
static GLuint outframebuffer = 0;

void
initoutframebuffer(void)
{
	glGenTextures(1, &outtexture);
	glActiveTexture(GL_TEXTURE0);
	glBindTexture(GL_TEXTURE_2D, outtexture);
	{
		glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
		glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
		glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, WIDTH, HEIGHT, 0, GL_RGBA, GL_UNSIGNED_BYTE, NULL);
	}
	
	glGenFramebuffers(1, &outframebuffer);
	glBindFramebuffer(GL_FRAMEBUFFER, outframebuffer);
	{
		glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, outtexture, 0);
		
		if (glCheckFramebufferStatus(GL_FRAMEBUFFER) != GL_FRAMEBUFFER_COMPLETE) {
			fprintf(stderr, "ERROR: could not complete the frame buffer\n");
			exit(1);
		}
	}
}

void
exporttransform(FILE *file)
{
	mat4x4 transform = lookat(pos, target, up);
	fprintf(
		file,
		"mat4x4 transform = {\n"
		"    %f, %f, %f, %f,\n"
		"    %f, %f, %f, %f,\n"
		"    %f, %f, %f, %f,\n"
		"    %f, %f, %f, %f,\n"
		"};\n",
		transform.m0, transform.m4, transform.m8, transform.m12,
		transform.m1, transform.m5, transform.m9, transform.m13,
		transform.m2, transform.m6, transform.m10, transform.m14,
		transform.m3, transform.m7, transform.m11, transform.m15
	);
}

void
exportpositionbounds(FILE *file)
{
	fprintf(file, "vec3 positionminbound = {%f, %f, %f};\n", positionminbound.x, positionminbound.y, positionminbound.z);
	fprintf(file, "vec3 positionmaxbound = {%f, %f, %f};\n", positionmaxbound.x, positionmaxbound.y, positionmaxbound.z);
}

void
exportpositionmap(FILE *file)
{
	fprintf(file, "unsigned char positionmap[WIDTH*HEIGHT*3] = {\n");
	glBindFramebuffer(GL_FRAMEBUFFER, outframebuffer);
	{
		glClear(GL_COLOR_BUFFER_BIT);
		glUseProgram(positionmapprogram.handle);
		{
			glUniform2f(positionmapprogram.uniforms.resolution, WIDTH, HEIGHT);
			mat4x4 transform = lookat(pos, target, up);
			glUniformMatrix4fv(positionmapprogram.uniforms.transform, 1, GL_FALSE, (const GLfloat *)&transform);
			glUniform1f(positionmapprogram.uniforms.near, near);
			glUniform3fv(positionmapprogram.uniforms.minbound, 1, (const GLfloat *)&positionminbound);
			glUniform3fv(positionmapprogram.uniforms.maxbound, 1, (const GLfloat *)&positionmaxbound);
			
			glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
		}
		
		glBindTexture(GL_TEXTURE_2D, outtexture);
		{
			unsigned char pixels[WIDTH*HEIGHT*3];
			glGetTexImage(GL_TEXTURE_2D, 0, GL_RGB, GL_UNSIGNED_BYTE, &pixels);
			for (int y = 0; y < HEIGHT; ++y) {
				fprintf(file, "   ");
				for (int x = 0; x < WIDTH; ++x) {
					for (int comp = 0; comp < 3; ++comp) {
						fprintf(file, " %d,", (unsigned char)pixels[(y*WIDTH + x)*3 + comp]);
					}
				}
				fprintf(file, "\n");
			}
		}
	}
	fprintf(file, "};\n");
}

void
exportlightmap(FILE *file)
{
	fprintf(file, "unsigned char lightmap[WIDTH*HEIGHT] = {\n");
	glBindFramebuffer(GL_FRAMEBUFFER, outframebuffer);
	{
		glClear(GL_COLOR_BUFFER_BIT);
		glUseProgram(lightmapprogram.handle);
		{
			glUniform2f(lightmapprogram.uniforms.resolution, WIDTH, HEIGHT);
			mat4x4 transform = lookat(pos, target, up);
			glUniformMatrix4fv(lightmapprogram.uniforms.transform, 1, GL_FALSE, (const GLfloat *)&transform);
			glUniform1f(lightmapprogram.uniforms.near, near);
			glUniform1ui(lightmapprogram.uniforms.rays, 8192);
			
			glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
		}
		
		glBindTexture(GL_TEXTURE_2D, outtexture);
		{
			unsigned char pixels[WIDTH*HEIGHT];
			glGetTexImage(GL_TEXTURE_2D, 0, GL_RED, GL_UNSIGNED_BYTE, &pixels);
			for (int y = 0; y < HEIGHT; ++y) {
				fprintf(file, "   ");
				for (int x = 0; x < WIDTH; ++x) {
					fprintf(file, " %d,", (unsigned char)pixels[y*WIDTH + x]);
				}
				fprintf(file, "\n");
			}
		}
	}
	fprintf(file, "};\n");
}

void
exportscene(void)
{
	const char *filename = "scene.h";
	FILE *file = fopen(filename, "wb");
	if (file == NULL) {
		printf("ERROR: could not open file %s for writing: %s", filename, strerror(errno));
		exit(1);
	}
	
	fprintf(file, "#ifndef SCENE_H\n");
	fprintf(file, "#define SCENE_H\n");
	fprintf(file, "\n");
	fprintf(file, "#include \"config.h\"\n");
	fprintf(file, "\n");
	fprintf(
		file,
		"typedef struct {\n"
		"    float x, y, z;\n"
		"} vec3;\n"
	);
	fprintf(file, "\n");
	fprintf(
		file,
		"typedef struct {\n"
		"    float m0, m4, m8, m12;\n"
		"    float m1, m5, m9, m13;\n"
		"    float m2, m6, m10, m14;\n"
		"    float m3, m7, m11, m15;\n"
		"} mat4x4;\n"
	);
	fprintf(file, "\n");
	exporttransform(file);
	exportpositionbounds(file);
	exportpositionmap(file);
	exportlightmap(file);
	fprintf(file, "\n");
	fprintf(file, "#endif // SCENE_H\n");
	
	fclose(file);
	
	fprintf(stderr, "Successfully exported %s\n", filename);
}

void
errorcallback(int error, const char *description)
{
	(void) error;
	fputs(description, stderr);
}

void
keycallback(GLFWwindow *window, int key, int scancode, int action, int mods)
{
	(void) window;
	(void) scancode;
	(void) mods;
	if (key == GLFW_KEY_E && action == GLFW_PRESS) {
		exportscene();
	}
}

void
framebuffersizecallback(GLFWwindow *window, int width, int height)
{
	(void) window;
	glViewport(0, 0, width, height);
}

int
main(void)
{
	glfwSetErrorCallback(errorcallback);
	
	if (!glfwInit()) {
		exit(1);
	}
	
#ifndef __APPLE__
	if (glewInit() != GLEW_OK) {
		exit(1);
	}
#endif
	
	glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 4);
	glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 1);
	glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GLFW_TRUE);
	glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);
	
	GLFWwindow *window = glfwCreateWindow(WIDTH*2, HEIGHT*2, "Editor", NULL, NULL);
	if (window == NULL) {
		glfwTerminate();
		exit(1);
	}
	
	glfwSetKeyCallback(window, keycallback);
	glfwSetFramebufferSizeCallback(window, framebuffersizecallback);
	
	glfwMakeContextCurrent(window);
	glfwSwapInterval(1); // Used to avoid screen tearing.
	
	GLuint vao;
	glGenVertexArrays(1, &vao);
	glBindVertexArray(vao);
	{
		compileprograms();
		initoutframebuffer();
		
		while (!glfwWindowShouldClose(window)) {
			glBindFramebuffer(GL_FRAMEBUFFER, GLFW_WINDOW_FRAMEBUFFER);
			{
				glClear(GL_COLOR_BUFFER_BIT);
				glUseProgram(lightmapprogram.handle);
				{
					int width, height;
					glfwGetFramebufferSize(window, &width, &height);
					glUniform2f(lightmapprogram.uniforms.resolution, width, height);
					mat4x4 transform = lookat(pos, target, up);
					glUniformMatrix4fv(lightmapprogram.uniforms.transform, 1, GL_FALSE, (const GLfloat *)&transform);
					glUniform1f(lightmapprogram.uniforms.near, near);
					glUniform1ui(lightmapprogram.uniforms.rays, 4);
					
					glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
				}
			}
			
			// TODO: Limit FPS.
			glfwSwapBuffers(window);
			glfwPollEvents();
		}
	}
	glBindVertexArray(0);
	
	glfwDestroyWindow(window);
	glfwTerminate();
}
