about summary refs log tree commit homepage
path: root/ext/unicorn/http11/http11.c
diff options
context:
space:
mode:
Diffstat (limited to 'ext/unicorn/http11/http11.c')
-rw-r--r--ext/unicorn/http11/http11.c117
1 files changed, 76 insertions, 41 deletions
diff --git a/ext/unicorn/http11/http11.c b/ext/unicorn/http11/http11.c
index 021c80b..cd7a8f7 100644
--- a/ext/unicorn/http11/http11.c
+++ b/ext/unicorn/http11/http11.c
@@ -9,6 +9,16 @@
 #include <string.h>
 #include "http11_parser.h"
 
+static http_parser *data_get(VALUE self)
+{
+  http_parser *http;
+
+  Data_Get_Struct(self, http_parser, http);
+  if (!http)
+    rb_raise(rb_eArgError, "NULL found for http when shouldn't be.");
+  return http;
+}
+
 #ifndef RSTRING_PTR
 #define RSTRING_PTR(s) (RSTRING(s)->ptr)
 #endif
@@ -30,9 +40,8 @@ static VALUE global_request_uri;
 static VALUE global_fragment;
 static VALUE global_query_string;
 static VALUE global_http_version;
-static VALUE global_content_length;
 static VALUE global_request_path;
-static VALUE global_content_type;
+static VALUE global_path_info;
 static VALUE global_server_name;
 static VALUE global_server_port;
 static VALUE global_server_protocol;
@@ -45,14 +54,25 @@ static VALUE global_localhost;
 static VALUE global_http;
 
 /** Defines common length and error messages for input length validation. */
-#define DEF_MAX_LENGTH(N,length) const size_t MAX_##N##_LENGTH = length; const char *MAX_##N##_LENGTH_ERR = "HTTP element " # N  " is longer than the " # length " allowed length."
+#define DEF_MAX_LENGTH(N, length) \
+  static const size_t MAX_##N##_LENGTH = length; \
+  static const char * const MAX_##N##_LENGTH_ERR = \
+    "HTTP element " # N  " is longer than the " # length " allowed length."
 
-/** Validates the max length of given input and throws an HttpParserError exception if over. */
-#define VALIDATE_MAX_LENGTH(len, N) if(len > MAX_##N##_LENGTH) { rb_raise(eHttpParserError, MAX_##N##_LENGTH_ERR); }
+/**
+ * Validates the max length of given input and throws an HttpParserError
+ * exception if over.
+ */
+#define VALIDATE_MAX_LENGTH(len, N) do { \
+  if (len > MAX_##N##_LENGTH) \
+    rb_raise(eHttpParserError, MAX_##N##_LENGTH_ERR); \
+} while (0)
 
 /** Defines global strings in the init method. */
-#define DEF_GLOBAL(N, val)   global_##N = rb_obj_freeze(rb_str_new2(val)); rb_global_variable(&global_##N)
-
+#define DEF_GLOBAL(N, val) do { \
+  global_##N = rb_obj_freeze(rb_str_new(val, sizeof(val) - 1)); \
+  rb_global_variable(&global_##N); \
+} while (0)
 
 /* Defines the maximum allowed lengths for various input elements.*/
 DEF_MAX_LENGTH(FIELD_NAME, 256);
@@ -153,14 +173,11 @@ static void http_field(void *data, const char *field,
                        size_t flen, const char *value, size_t vlen)
 {
   VALUE req = (VALUE)data;
-  VALUE v = Qnil;
   VALUE f = Qnil;
 
   VALIDATE_MAX_LENGTH(flen, FIELD_NAME);
   VALIDATE_MAX_LENGTH(vlen, FIELD_VALUE);
 
-  v = rb_str_new(value, vlen);
-
   f = find_common_field_value(field, flen);
 
   if (f == Qnil) {
@@ -179,9 +196,11 @@ static void http_field(void *data, const char *field,
     memcpy(RSTRING_PTR(f) + HTTP_PREFIX_LEN, field, flen);
     assert(*(RSTRING_PTR(f) + RSTRING_LEN(f)) == '\0'); /* paranoia */
     /* fprintf(stderr, "UNKNOWN HEADER <%s>\n", RSTRING_PTR(f)); */
+  } else if (f == global_http_host && rb_hash_aref(req, f) != Qnil) {
+    return;
   }
 
-  rb_hash_aset(req, f, v);
+  rb_hash_aset(req, f, rb_str_new(value, vlen));
 }
 
 static void request_method(void *data, const char *at, size_t length)
@@ -193,6 +212,16 @@ static void request_method(void *data, const char *at, size_t length)
   rb_hash_aset(req, global_request_method, val);
 }
 
+static void scheme(void *data, const char *at, size_t length)
+{
+  rb_hash_aset((VALUE)data, global_rack_url_scheme, rb_str_new(at, length));
+}
+
+static void host(void *data, const char *at, size_t length)
+{
+  rb_hash_aset((VALUE)data, global_http_host, rb_str_new(at, length));
+}
+
 static void request_uri(void *data, const char *at, size_t length)
 {
   VALUE req = (VALUE)data;
@@ -202,6 +231,13 @@ static void request_uri(void *data, const char *at, size_t length)
 
   val = rb_str_new(at, length);
   rb_hash_aset(req, global_request_uri, val);
+
+  /* "OPTIONS * HTTP/1.1\r\n" is a valid request */
+  if (length == 1 && *at == '*') {
+    val = rb_str_new(NULL, 0);
+    rb_hash_aset(req, global_request_path, val);
+    rb_hash_aset(req, global_path_info, val);
+  }
 }
 
 static void fragment(void *data, const char *at, size_t length)
@@ -224,6 +260,10 @@ static void request_path(void *data, const char *at, size_t length)
 
   val = rb_str_new(at, length);
   rb_hash_aset(req, global_request_path, val);
+
+  /* rack says PATH_INFO must start with "/" or be empty */
+  if (!(length == 1 && *at == '*'))
+    rb_hash_aset(req, global_path_info, val);
 }
 
 static void query_string(void *data, const char *at, size_t length)
@@ -252,22 +292,32 @@ static void header_done(void *data, const char *at, size_t length)
   VALUE server_port = global_port_80;
   VALUE temp;
 
+  /* rack requires QUERY_STRING */
+  if (rb_hash_aref(req, global_query_string) == Qnil)
+    rb_hash_aset(req, global_query_string, rb_str_new(NULL, 0));
+
   /* set rack.url_scheme to "https" or "http", no others are allowed by Rack */
-  if ((temp = rb_hash_aref(req, global_http_x_forwarded_proto)) != Qnil &&
-      RSTRING_LEN(temp) == 5 &&
-      !memcmp("https", RSTRING_PTR(temp), 5))
+  if ((temp = rb_hash_aref(req, global_rack_url_scheme)) == Qnil) {
+    if ((temp = rb_hash_aref(req, global_http_x_forwarded_proto)) != Qnil &&
+        RSTRING_LEN(temp) == 5 &&
+        !memcmp("https", RSTRING_PTR(temp), 5))
+      server_port = global_port_443;
+    else
+      temp = global_http;
+    rb_hash_aset(req, global_rack_url_scheme, temp);
+  } else if (RSTRING_LEN(temp) == 5 && !memcmp("https", RSTRING_PTR(temp), 5)) {
     server_port = global_port_443;
-  else
-    temp = global_http;
-  rb_hash_aset(req, global_rack_url_scheme, temp);
+  }
 
   /* parse and set the SERVER_NAME and SERVER_PORT variables */
   if ((temp = rb_hash_aref(req, global_http_host)) != Qnil) {
     char *colon = memchr(RSTRING_PTR(temp), ':', RSTRING_LEN(temp));
     if (colon) {
+      long port_start = colon - RSTRING_PTR(temp) + 1;
+
       server_name = rb_str_substr(temp, 0, colon - RSTRING_PTR(temp));
-      server_port = rb_str_substr(temp, colon - RSTRING_PTR(temp)+1,
-                                  RSTRING_LEN(temp));
+      if ((RSTRING_LEN(temp) - port_start) > 0)
+        server_port = rb_str_substr(temp, port_start, RSTRING_LEN(temp));
     } else {
       server_name = temp;
     }
@@ -294,14 +344,6 @@ static VALUE HttpParser_alloc(VALUE klass)
   VALUE obj;
   http_parser *hp = ALLOC_N(http_parser, 1);
   TRACE();
-  hp->http_field = http_field;
-  hp->request_method = request_method;
-  hp->request_uri = request_uri;
-  hp->fragment = fragment;
-  hp->request_path = request_path;
-  hp->query_string = query_string;
-  hp->http_version = http_version;
-  hp->header_done = header_done;
   http_parser_init(hp);
 
   obj = Data_Wrap_Struct(klass, NULL, HttpParser_free, hp);
@@ -318,9 +360,7 @@ static VALUE HttpParser_alloc(VALUE klass)
  */
 static VALUE HttpParser_init(VALUE self)
 {
-  http_parser *http = NULL;
-  DATA_GET(self, http_parser, http);
-  http_parser_init(http);
+  http_parser_init(data_get(self));
 
   return self;
 }
@@ -335,9 +375,7 @@ static VALUE HttpParser_init(VALUE self)
  */
 static VALUE HttpParser_reset(VALUE self)
 {
-  http_parser *http = NULL;
-  DATA_GET(self, http_parser, http);
-  http_parser_init(http);
+  http_parser_init(data_get(self));
 
   return Qnil;
 }
@@ -358,12 +396,10 @@ static VALUE HttpParser_reset(VALUE self)
 
 static VALUE HttpParser_execute(VALUE self, VALUE req_hash, VALUE data)
 {
-  http_parser *http;
+  http_parser *http = data_get(self);
   char *dptr = RSTRING_PTR(data);
   long dlen = RSTRING_LEN(data);
 
-  DATA_GET(self, http_parser, http);
-
   if (http->nread < dlen) {
     http->data = (void *)req_hash;
     http_parser_execute(http, dptr, dlen);
@@ -378,9 +414,8 @@ static VALUE HttpParser_execute(VALUE self, VALUE req_hash, VALUE data)
   rb_raise(eHttpParserError, "Requested start is after data buffer end.");
 }
 
-void Init_http11()
+void Init_http11(void)
 {
-
   mUnicorn = rb_define_module("Unicorn");
 
   DEF_GLOBAL(rack_url_scheme, "rack.url_scheme");
@@ -390,13 +425,11 @@ void Init_http11()
   DEF_GLOBAL(query_string, "QUERY_STRING");
   DEF_GLOBAL(http_version, "HTTP_VERSION");
   DEF_GLOBAL(request_path, "REQUEST_PATH");
-  DEF_GLOBAL(content_length, "CONTENT_LENGTH");
-  DEF_GLOBAL(content_type, "CONTENT_TYPE");
+  DEF_GLOBAL(path_info, "PATH_INFO");
   DEF_GLOBAL(server_name, "SERVER_NAME");
   DEF_GLOBAL(server_port, "SERVER_PORT");
   DEF_GLOBAL(server_protocol, "SERVER_PROTOCOL");
   DEF_GLOBAL(server_protocol_value, "HTTP/1.1");
-  DEF_GLOBAL(http_host, "HTTP_HOST");
   DEF_GLOBAL(http_x_forwarded_proto, "HTTP_X_FORWARDED_PROTO");
   DEF_GLOBAL(port_80, "80");
   DEF_GLOBAL(port_443, "443");
@@ -412,4 +445,6 @@ void Init_http11()
   rb_define_method(cHttpParser, "execute", HttpParser_execute,2);
   sym_http_body = ID2SYM(rb_intern("http_body"));
   init_common_fields();
+  global_http_host = find_common_field_value("HOST", 4);
+  assert(global_http_host != Qnil);
 }