/*

Connection.cpp

The World's Smallest Web Server Connection implementation.  Handles HTTP 0.9 or
1.n GET requests for files with strict restrictions on the length and content
of the specified path.

See also http://www.w3.org/Protocols/rfc1945/rfc1945.

*/

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <unistd.h>

#include "logging.h"
#include "Connection.h"
#include "EntityType.h"

// Size of the stack buffer used by handle_request(), both to receive the
// request and, afterwards, to stream the requested file back in chunks.
#define BUFFER_LEN 1024
// Limit on the length of the requested path (the text following "GET /");
// longer requests are rejected with 400.
#define MAX_PATHNAME_LEN 256
// printf-style template for an HTTP/1.0 response header. Arguments, in order:
// status text (%s), Content-Length in bytes (%d), Content-Type (%s). The final
// empty line terminates the header section.
#define REPLY_HEADER "HTTP/1.0 %s\r\nServer: twsws/1.1\r\nContent-Length: %d\r\nContent-Type: %s\r\n\r\n"

// Status texts used in replies. The error texts are also sent as the
// plain-text body of the error reply (see error_reply()).
static const char OK[] = "200 OK";
static const char BAD_REQUEST_MSG[] = "400 Bad request";
static const char NOT_FOUND_MSG[] = "404 Not found";

/*
 * Wrap an accepted client socket. The socket is closed by handle_request()
 * (or error_reply()) once the reply has been sent.
 *
 * sock: connected socket descriptor for this client.
 */
Connection::Connection(int sock)
{
    m_sock = sock;
    // Give the protocol flag a defined value up front so error_reply() never
    // reads an uninitialised member should it run before handle_request() has
    // parsed the request line. Default to sending an HTTP/1.x header.
    m_http_0_9 = false;
}

void Connection::handle_request()
{
    char buffer[BUFFER_LEN];
    memset(buffer, 0, sizeof(buffer));

    // Read at least the first row of the request (probably get it all first
    // read). If first row is longer than buffer then only buffer length bytes
    // are read.
    int length = 0;
    int bytes_read = -1;
    do {
        bytes_read = read(m_sock, buffer + length, sizeof(buffer) - length);
        length += bytes_read;
    } while (!strchr(buffer, '\r') && bytes_read > 0);

    // In case we completely filled the buffer
    buffer[sizeof(buffer) - 1] = 0;

    // Dump the request out of curiosity
    log("HTTP request looks like this:\n%s", buffer);

    // If the buffer didn't include a CR because someone sent us a really big
    // request then reject the request.
    if (strchr(buffer,  '\r') == NULL) {
        log("HTTP request rejected: request line too long\n---");
        m_http_0_9 = false;  // Can't be sure but probably...
        error_reply(BAD_REQUEST_MSG);
        return;
    }

    // See if this is an HTTP 1.0 or greater request header and terminate the
    // request at the end of the path name.
    if (strstr(buffer,  " HTTP/") > 0) {
        strstr(buffer,  " HTTP/")[0] = 0;
        m_http_0_9 = false;
    } else {
        strchr(buffer,  '\r')[0] = 0;
        m_http_0_9 = true;
        log("HTTP 0.9 detected: cool!");
    }

    // Check the request is a GET and point to the path of the file which is
    // being requested.
    char* requested_pathname;
    if (strstr(buffer, "GET /") == buffer) {
        requested_pathname = buffer + 5;
    } else {
        log("HTTP request rejected: request didn't start with 'GET /'\n---");
        error_reply(BAD_REQUEST_MSG);
        return;
    }

    // Check for invalid requests that hackers try it on with such as
    // ../../scripts/ and so on (see httphacks).  Firstly the request length
    // must be < MAX_PATHNAME_LEN characters.  If you need a pathname longer
    // than this switch to Apache!
    if (strlen(requested_pathname) > MAX_PATHNAME_LEN) {
        log("HTTP request rejected: request too long: %d\n---", strlen(requested_pathname));
        error_reply(BAD_REQUEST_MSG);
        return;
    }

    // If the request starts with "/" then the request started with "GET //"
    // and is well not allowed.  Also don't allow "//" as this may map to some
    // local network file.
    if (requested_pathname[0] == '/' || strstr(requested_pathname, "//") > 0) {
        log("HTTP request rejected: found '//'\n---");
        error_reply(BAD_REQUEST_MSG);
        return;
    }

    // Now check each character in the request and make sure that it's one that
    // we find acceptable.  This is well strict but allows normal file serving.
    for (int i = 0; i < strlen(requested_pathname); i++) {
        if (!valid_request_char(requested_pathname[i])) {
            log("HTTP request rejected: invalid char found: %c\n---", requested_pathname[i]);
            error_reply(BAD_REQUEST_MSG);
            return;
        }
    }

    // If the root object has been requested translate to index.html.
    if (requested_pathname[0] == 0) {
        strcpy(requested_pathname, "index.html");
    }

    // There should be only one '.' in the requested pathname just before a
    // valid file extension.  Get the string starting and including the first
    // '.' up to the end of the pathname and check that that string is a valid
    // extension.
    EntityType* et = new EntityType(strchr(requested_pathname, '.'));
    if (!et->is_valid()) {
        log("HTTP request rejected: %s\n---", et->get_reason());
        error_reply(BAD_REQUEST_MSG);
        return;
    }

    // See if the file exists and how big it is.
    struct stat stat_info;
    if (stat(requested_pathname, &stat_info) < 0) {
        log("stat failed for %s\n---", requested_pathname);
        error_reply(NOT_FOUND_MSG);
        return;
    }

    if (!m_http_0_9) {
        // Send header
        send_reply_header(OK, stat_info.st_size, et->get_type());
    }

    // Send entity
    FILE* fp = fopen(requested_pathname, "rb");
    if (fp == NULL) {
        error("fopen failed for %s\n---", requested_pathname);
        error_reply(NOT_FOUND_MSG);
        return;
    }
    log("Sending file: %s", requested_pathname);
    bytes_read = -1;
    while (bytes_read) {
        bytes_read = fread(buffer, 1, sizeof(buffer), fp);
        if (!write_all(buffer, bytes_read)) {
            log("Connection was closed before tranfer completed.");
            break;
        }
    }

    // Cleanup
    fclose(fp);
    close(m_sock);
    log("Done!\n---");
}

/*
 * Send an error reply and close the client socket.
 *
 * msg: status text such as "404 Not found". It is used as the status in the
 *      header and is also sent as the plain-text body.
 */
void Connection::error_reply(const char* msg)
{
    const size_t msg_len = strlen(msg);

    // HTTP 0.9 replies carry no status line or headers, only the body.
    const bool wants_header = !m_http_0_9;
    if (wants_header) {
        send_reply_header(msg, msg_len, TYPE_TEXT_PLAIN);
    }

    write_all(msg, msg_len);
    close(m_sock);
}

void Connection::send_reply_header(const char* msg, int length, char* type)
{
    char header[256];
    sprintf(header, REPLY_HEADER, msg, length, type);
    log("Sending header: %s", header);
    write_all(header, strlen(header));
}

/*
 * Write exactly `length` bytes from `buffer` to the client socket, looping
 * over short writes.
 *
 * Returns true when everything was written. Returns false (after logging) if
 * write() reports an error other than EINTR, or makes no progress.
 */
bool Connection::write_all(const char* buffer, const size_t length)
{
    size_t pos = 0;
    while (pos < length) {
        const ssize_t bytes_written = write(m_sock, buffer + pos, length - pos);
        if (bytes_written < 0 && errno == EINTR) {
            // Interrupted by a signal before any data was written: retry.
            continue;
        }
        if (bytes_written <= 0) {
            error("write failed");
            return false; // Terminate thread if mt...
        }
        pos += static_cast<size_t>(bytes_written);
    }
    return true;
}

/*
 * Whitelist check for one character of a requested path.
 *
 * Accepts ASCII letters, ASCII digits, '/', '_' and '.'. Everything else is
 * refused, including '%', '~', '-', spaces, NUL and any byte outside those
 * ranges.
 */
bool Connection::valid_request_char(char ch)
{
    const bool is_lower = ('a' <= ch) && (ch <= 'z');
    const bool is_upper = ('A' <= ch) && (ch <= 'Z');
    const bool is_digit = ('0' <= ch) && (ch <= '9');
    const bool is_allowed_punct = (ch == '/') || (ch == '_') || (ch == '.');
    return is_lower || is_upper || is_digit || is_allowed_punct;
}