#include <u.h>
#include <libc.h>

/*
 * Bit masks applied to the first two bytes of a frame header.
 * Fin and Opcode are applied to byte 0, Mask and Len to byte 1.
 */
enum {
	Fin = 0x80,	/* byte 0: final-fragment flag */
	Opcode = 0x0f,	/* byte 0: frame opcode */
	Mask = 0x80,	/* byte 1: payload-is-masked flag */
	Len = 0x7f,	/* byte 1: short length, or 126/127 escape */
};

/*
 * One open connection: the number read from /mnt/web/clone and the
 * descriptor of the matching /mnt/web/N/socket file.
 */
typedef struct Socket {
	int conn;	/* connection number N in /mnt/web/N */
	int fd;		/* descriptor opened on /mnt/web/N/socket */
} Socket;

/* set by -d; when non-zero, a trace of every frame is printed on fd 2 */
static int debug;
/* sleep interval used by pingproc; set by -p msecs, or to -1 by -P */
static long delay = 1000;
/* first non-option argument; sent to the web ctl file by handshake */
static char *url;

/* opens the connection and returns its number and socket descriptor */
Socket handshake(void);
/* process bodies; each takes a Socket* passed as void* */
void readproc(void *);
void writeproc(void *);
void pingproc(void *);

/*
 * Print the command synopsis on standard error and exit with
 * status "usage".
 */
void
usage(void)
{
	fprint(2,
		"usage: %s [-d] [-p msecs | -P] url [cmd [args]...]\n",
		argv0);
	exits("usage");
}

/*
 * Parse options, open the connection, optionally start cmd with its
 * standard input and output connected to this program through two pipes,
 * then fork a reader (socket -> fd 1) and a writer (fd 0 -> socket).
 * The parent waits until cmd or the writer exits and then kills the
 * whole note group.
 */
void
main(int argc, char *argv[])
{
	Socket s;
	int pid, wpid, rpid, spid;
	int pin[2], pout[2];

	ARGBEGIN{
	case 'd':
		debug = 1;
		break;
	case 'p':
		delay = atol(EARGF(usage()));
		break;
	case 'P':
		delay = -1;
		break;
	default:
		usage();
	}ARGEND;
	if(argc-- == 0)
		usage();
	url = *argv++;

	s = handshake();

	/* new note group, so the final postnote only reaches our own processes */
	if(rfork(RFNOTEG) < 0)
		sysfatal("fork: %r");

	/* -1 never matches a real pid; keeps the wait loop defined when no cmd is given */
	wpid = -1;
	if(argc > 0){
		if(pipe(pin) < 0)
			sysfatal("pipe: %r");
		if(pipe(pout) < 0)
			sysfatal("pipe: %r");
		if((wpid = rfork(RFFDG|RFPROC)) < 0)
			sysfatal("fork: %r");
		if(wpid == 0){
			/* child: cmd reads what we receive and writes what we send */
			if(dup(pin[0], 0) < 0)
				sysfatal("dup: %r");
			close(pin[0]);
			close(pin[1]);
			if(dup(pout[0], 1) < 0)
				sysfatal("dup: %r");
			close(pout[0]);
			close(pout[1]);
			exec(argv[0], argv);
			/* retry under /bin unless the command name is an absolute path */
			if(argv[0][0] != '/')
				exec(smprint("/bin/%s", argv[0]), argv);
			sysfatal("exec %s: %r", argv[0]);
		}
		/* parent: fd 1 feeds cmd's input, fd 0 drains cmd's output */
		if(dup(pin[1], 1) < 0)
			sysfatal("dup: %r");
		close(pin[0]);
		close(pin[1]);
		if(dup(pout[1], 0) < 0)
			sysfatal("dup: %r");
		close(pout[0]);
		close(pout[1]);
	}

	if((rpid = rfork(RFPROC)) < 0)
		sysfatal("fork: %r");
	if(rpid == 0){
		readproc(&s);
		exits(nil);
	}

	if((spid = rfork(RFPROC)) < 0)
		sysfatal("fork: %r");
	if(spid == 0){
		writeproc(&s);
		/*
		 * writeproc returns at end of input; exit here so the parent's
		 * wait sees spid instead of this child falling into the wait
		 * loop below.
		 */
		exits(nil);
	}

	/* pingproc is currently disabled */
	//if((pid = rfork(RFPROC)) < 0)
	//	sysfatal("fork: %r");
	//if(pid == 0){
	//	pingproc(&s);
	//}

	/* stop when cmd or the writer is gone, or when there is nothing left to wait for */
	for(;;){
		pid = waitpid();
		if(pid < 0 || pid == wpid || pid == spid)
			break;
	}
	postnote(PNGROUP, getpid(), "kill");
	exits(nil);
}

/*
 * Allocate a connection by opening /mnt/web/clone, write the url and the
 * upgrade request headers to it, and open /mnt/web/N/socket.
 * Returns the connection number and the socket descriptor; any failure
 * is fatal.
 *
 * TODO: the response is not validated; the connection, upgrade and
 * secwebsocketaccept files under /mnt/web/N are never read.
 */
Socket
handshake(void)
{
	/* request headers, one ctl write each; the key is a fixed constant */
	static char *hdrs[] = {
		"headers Upgrade: websocket\r\n",
		"headers Connection: keep-alive, Upgrade\r\n",
		"headers Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n",
		"headers Sec-WebSocket-Version: 13\r\n",
	};
	int ctlfd, i;
	long n;
	char buf[64];
	Socket s;

	if((ctlfd = open("/mnt/web/clone", ORDWR)) < 0)
		sysfatal("open /mnt/web/clone: %r");
	if((n = read(ctlfd, buf, sizeof(buf)-1)) < 0)
		sysfatal("read /mnt/web/clone: %r");
	buf[n] = '\0';
	s.conn = atoi(buf);

	/* a rejected ctl message means the request is not what we intended */
	if(fprint(ctlfd, "url %s", url) < 0)
		sysfatal("write /mnt/web/%d/ctl: %r", s.conn);
	for(i = 0; i < (int)(sizeof(hdrs)/sizeof(hdrs[0])); i++)
		if(fprint(ctlfd, "%s", hdrs[i]) < 0)
			sysfatal("write /mnt/web/%d/ctl: %r", s.conn);

	/*
	 * ctlfd is deliberately left open.
	 * NOTE(review): presumably the connection only lives while its ctl
	 * file is open; confirm before closing it here.
	 */

	snprint(buf, sizeof(buf), "/mnt/web/%d/socket", s.conn);
	if((s.fd = open(buf, ORDWR)) < 0)
		sysfatal("open %s: %r", buf);

	return s;
}

/*
 * Read exactly nbytes from fd into buf.
 * Returns nbytes on success and a negative value on failure.  End of
 * file before nbytes have arrived counts as a failure (errstr is set to
 * "unexpected eof"); without that check a closed connection would make
 * this loop spin forever on zero-length reads.
 */
int
nextn(int fd, void *buf, long nbytes)
{
	char *bufc = buf;
	long total = nbytes;
	long n;

	while(nbytes > 0){
		n = readn(fd, bufc, nbytes);
		if(n < 0)
			return n;
		if(n == 0){
			werrstr("unexpected eof");
			return -1;
		}
		bufc += n;
		nbytes -= n;
	}
	return total;
}

/*
 * Read one frame from s->fd and act on it.
 *
 * Data frames (text/binary, including continuation fragments) have their
 * payload unmasked if needed and copied to fd 1.  *opcode carries the
 * message type across fragments: the caller sets it to 0 before the
 * first fragment and it is filled in from that fragment.  *fin is set
 * non-zero when the frame is the last fragment of its message.
 *
 * Control frames (opcode bit 0x8) may arrive between the fragments of a
 * message, so they are handled on their own and leave *fin and *opcode
 * untouched: close ends the process, anything else is read and dropped.
 * Frames with an unknown opcode are also read and dropped.
 *
 * End of file on the socket closes fd 1 and exits; read or write errors
 * are fatal.
 */
void
readframe(Socket *s, int *fin, int *opcode)
{
	uchar frame[10];
	char buf[8 * 1024];
	long i, n, len, chunk;
	int mask, op;
	uchar key[4];

	if((n = readn(s->fd, frame, 2)) < 0)
		sysfatal("read /mnt/web/%d/socket: %r", s->conn);
	if(n < 2){
		/* eof, possibly in the middle of a header: nothing more will arrive */
		if(debug)
			fprint(2, "closing stdout\n");
		close(1);
		exits(0);
	}

	op = frame[0] & Opcode;
	if((op & 0x8) == 0){
		*fin = frame[0] & Fin;
		if(*opcode == 0){
			*opcode = op;
			if(*opcode == 0x0)
				fprint(2, "bogus continuation frame\n");
		}
		/* continuation fragments are treated like the first fragment */
		op = *opcode;
	}

	len = frame[1] & Len;
	if(len == 126){
		if(nextn(s->fd, frame+2, 2) < 0)
			sysfatal("read /mnt/web/%d/socket: %r", s->conn);
		len = frame[2]<<8 | frame[3];
	}else if(len == 127){
		if(nextn(s->fd, frame+2, 8) < 0)
			sysfatal("read /mnt/web/%d/socket: %r", s->conn);
		for(i = 0, len = 0; i < 8; ++i){
			/* refuse lengths that would overflow a long on the next shift */
			if(len > (0x7fffffffL >> 8))
				sysfatal("frame too large");
			len = len<<8 | frame[2+i];
		}
	}

	mask = frame[1] & Mask;
	if(mask)
		if(nextn(s->fd, key, 4) < 0)
			sysfatal("read /mnt/web/%d/socket: %r", s->conn);

	if(debug)
		fprint(2, "[recv frame fin=%d opcode=%02x mask=%d len=%ld]", !!(frame[0] & Fin), frame[0] & Opcode, !!mask, len);

	if(op == 0x8){	/* Connection close */
		close(1);
		exits(0);
	}

	/*
	 * Consume the payload in buffer-sized pieces whatever the opcode, so
	 * that no frame can write past buf.  Every piece but the last is a
	 * multiple of 4 bytes, which keeps the mask key aligned from one
	 * piece to the next.
	 */
	for(; len > 0; len -= chunk){
		chunk = len;
		if(chunk > (long)sizeof(buf))
			chunk = sizeof(buf);
		if(nextn(s->fd, buf, chunk) < 0)
			sysfatal("read /mnt/web/%d/socket: %r", s->conn);
		if(op != 0x1 && op != 0x2)	/* ping, pong, unknown: discard */
			continue;
		if(mask)
			for(i = 0; i < chunk; ++i)
				buf[i] ^= key[i%4];
		if(write(1, buf, chunk) < 0)
			sysfatal("write 1: %r");
	}
}

/*
 * Reader process body: receive messages forever.  A message is read one
 * frame at a time, starting with the fin flag and opcode cleared, until
 * readframe reports a frame with the fin flag set.  Never returns;
 * readframe exits the process when the connection ends.
 */
void
readproc(void *a)
{
	Socket *sock;
	int done, op;

	sock = a;
	for(;;){
		done = 0;
		op = 0;
		do
			readframe(sock, &done, &op);
		while(!done);
	}
}

/*
 * Send one final (unfragmented), masked frame with the given opcode.
 *
 * payload holds len bytes and is masked IN PLACE, so the caller's buffer
 * is scrambled on return.  A nil payload or a non-positive len sends a
 * frame with an empty payload.  Write errors are fatal.
 */
void
writeframe(Socket *s, int opcode, char *payload, long len)
{
	uchar frame[14];
	int n, shift;
	long i;
	ulong key;

	/* never announce bytes that cannot be sent */
	if(!payload || len < 0)
		len = 0;

	if(debug)
		fprint(2, "[send frame fin=1 opcode=%02x mask=1 len=%ld]", opcode, len);

	frame[0] = Fin | opcode;
	frame[1] = Mask;
	n = 2;
	if(len < 126){
		frame[1] |= len;
	}else if(len < 65536){
		frame[1] |= 126;
		frame[2] = len>>8 & 0xff;
		frame[3] = len>>0 & 0xff;
		n += 2;
	}else{
		frame[1] |= 127;
		for(i = 0; i < 8; ++i){
			/*
			 * Bytes above the width of a long are zero; shifting by
			 * the full width or more would be undefined.
			 */
			shift = 8*(7-i);
			if(shift < 8*(int)sizeof(len))
				frame[n+i] = len>>shift & 0xff;
			else
				frame[n+i] = 0;
		}
		n += 8;
	}

	/*
	 * The 4-byte masking key follows the length.
	 * NOTE(review): lrand is a pseudo-random generator and is not seeded
	 * anywhere in this file, so the keys are likely predictable.
	 */
	key = lrand();
	for(i = 0; i < 4; ++i)
		frame[n+i] = key>>(8*(3-i)) & 0xff;
	n += 4;

	for(i = 0; i < len; ++i)
		payload[i] ^= frame[n-4+i%4];

	if(write(s->fd, frame, n) < 0)
		sysfatal("write /mnt/web/%d/socket: %r", s->conn);
	if(len > 0)
		if(write(s->fd, payload, len) < 0)
			sysfatal("write /mnt/web/%d/socket: %r", s->conn);
}

/*
 * Writer process body: forward everything read from fd 0 to the socket,
 * one binary frame (opcode 0x02) per read.  Returns when read reports
 * end of file or an error.
 */
void
writeproc(void *a)
{
	Socket *sock;
	char chunk[8192];
	long got;

	sock = a;
	for(;;){
		got = read(0, chunk, sizeof(chunk));
		if(got <= 0)
			break;
		writeframe(sock, 0x02, chunk, got);
	}
}

/*
 * Ping process body: forever sleep for the configured delay, then send
 * an empty ping frame (opcode 0x09).  Never returns.
 */
void
pingproc(void *a)
{
	Socket *sock;

	sock = a;
	do{
		sleep(delay);
		writeframe(sock, 0x09, nil, 0);
	}while(1);
}
