about summary refs log tree commit homepage
diff options
context:
space:
mode:
author Eric Wong <normalperson@yhbt.net> 2010-09-26 01:41:47 +0000
committer Eric Wong <normalperson@yhbt.net> 2010-09-26 02:00:10 +0000
commit 43153f218e14cad3a2c6f4056fcf02dc49dc4b36 (patch)
tree d363a76c58e106bc1f4eaf4cc7ad91bee21bf5df
parent 81d66a794338e241e00b9ffd66fc94b80064475d (diff)
download sleepy_penguin-43153f218e14cad3a2c6f4056fcf02dc49dc4b36.tar.gz
It's dangerous to preserve epoll descriptors across fork,
especially in Ruby where the GC can invalidate objects at any
time.  Installing pthread_atfork hooks prevents VALUE references
stored in the kernel from leaking across process boundaries,
making it far more difficult for a sanely written application to
leak invalid VALUEs to the user.
-rw-r--r-- ext/sleepy_penguin/epoll.c 106
-rw-r--r-- ext/sleepy_penguin/extconf.rb 3
-rw-r--r-- test/test_epoll.rb 32
3 files changed, 129 insertions, 12 deletions
diff --git a/ext/sleepy_penguin/epoll.c b/ext/sleepy_penguin/epoll.c
index d4610d8..14eb860 100644
--- a/ext/sleepy_penguin/epoll.c
+++ b/ext/sleepy_penguin/epoll.c
@@ -1,10 +1,23 @@
 #include "sleepy_penguin.h"
 #include <sys/epoll.h>
+#include <pthread.h>
 
 #ifndef EPOLL_CLOEXEC
 #  define EPOLL_CLOEXEC (int)(02000000)
 #endif
 
+#define EP_RECREATE (-2)
+
+#ifndef HAVE_RB_MEMERROR
+/*
+ * Fallback used only when the Ruby being built against does not provide
+ * rb_memerror() (HAVE_RB_MEMERROR is undefined): print a fixed message
+ * to file descriptor 2 and abort the process.  Never returns.
+ */
+static void rb_memerror(void)
+{
+        static const char e[] = "[FATAL] failed to allocate memory\n";
+
+        /* sizeof(e) - 1 drops the trailing NUL from the written bytes */
+        write(2, e, sizeof(e) - 1);
+        abort();
+}
+#endif
+
+static st_table *active;
 static const int step = 64; /* unlikely to grow unless you're huge */
 static VALUE cEpoll_IO;
 static ID id_for_fd;
@@ -26,6 +39,7 @@ struct rb_epoll {
         int capa;
         struct epoll_event *events;
         VALUE io;
+        int flags;
 };
 
 static struct rb_epoll *ep_get(VALUE self)
@@ -83,6 +97,10 @@ static void gcfree(void *ptr)
         struct rb_epoll *ep = ptr;
 
         xfree(ep->events);
+        if (ep->fd >= 0) {
+                st_data_t key = ep->fd;
+                st_delete(active, &key, NULL);
+        }
         if (NIL_P(ep->io) && ep->fd >= 0) {
                 /* can't raise during GC */
                 (void)close(ep->fd);
@@ -107,6 +125,29 @@ static VALUE alloc(VALUE klass)
         return self;
 }
 
+/*
+ * Opens a new epoll descriptor for +ep+ using ep->flags and records it
+ * in the +active+ table, keyed by the descriptor number.
+ *
+ * If the first attempt fails with EMFILE, ENFILE or ENOMEM, a GC run is
+ * forced and the call is retried once.  Raises via rb_sys_fail() when no
+ * descriptor could be obtained; ep->fd is -1 in that case.
+ */
+static void my_epoll_create(struct rb_epoll *ep)
+{
+        ep->fd = epoll_create1(ep->flags);
+
+        if (ep->fd == -1) {
+                switch (errno) {
+                case EMFILE:
+                case ENFILE:
+                case ENOMEM:
+                        /* resource exhaustion: let GC release what it can */
+                        rb_gc();
+                        ep->fd = epoll_create1(ep->flags);
+                        break;
+                default:
+                        break;
+                }
+
+                if (ep->fd == -1)
+                        rb_sys_fail("epoll_create1");
+        }
+
+        st_insert(active, (st_data_t)ep->fd, (st_data_t)ep);
+}
+
+/*
+ * Ensures +ep+ holds a usable descriptor before an operation proceeds.
+ * A descriptor marked EP_RECREATE is re-opened through
+ * my_epoll_create(); a descriptor of -1 raises IOError ("closed").
+ */
+static void ep_check(struct rb_epoll *ep)
+{
+        if (ep->fd == EP_RECREATE) {
+                my_epoll_create(ep);
+        }
+
+        if (ep->fd != -1) {
+                return;
+        }
+
+        rb_raise(rb_eIOError, "closed");
+}
+
 /*
  * creates a new Epoll object with an optional +flags+ argument.
  * +flags+ may currently be +Epoll::CLOEXEC+ or 0 (or nil)
@@ -130,16 +171,8 @@ static VALUE init(int argc, VALUE *argv, VALUE self)
                         rb_raise(rb_eArgError, "flags must be an integer");
                 }
         }
-
-        ep->fd = epoll_create1(flags);
-        if (ep->fd == -1) {
-                if (errno == EMFILE || errno == ENFILE || errno == ENOMEM) {
-                        rb_gc();
-                        ep->fd = epoll_create1(flags);
-                }
-                if (ep->fd == -1)
-                        rb_sys_fail("epoll_create1");
-        }
+        ep->flags = flags;
+        my_epoll_create(ep);
 
         return self;
 }
@@ -151,6 +184,7 @@ static VALUE ctl(VALUE self, VALUE io, VALUE flags, int op)
         int fd = my_fileno(io);
         int rv;
 
+        ep_check(ep);
         event.events = NUM2UINT(flags);
         pack_event_data(&event, io);
 
@@ -176,6 +210,7 @@ static VALUE set(VALUE self, VALUE io, VALUE flags)
         int fd = my_fileno(io);
         int rv;
 
+        ep_check(ep);
         event.events = NUM2UINT(flags);
         pack_event_data(&event, io);
 
@@ -341,6 +376,7 @@ static VALUE epwait(int argc, VALUE *argv, VALUE self)
         VALUE timeout, maxevents;
         struct rb_epoll *ep = ep_get(self);
 
+        ep_check(ep);
         rb_need_block();
         rb_scan_args(argc, argv, "02", &maxevents, &timeout);
         ep->timeout = NIL_P(timeout) ? -1 : NUM2INT(timeout);
@@ -376,6 +412,8 @@ static VALUE to_io(VALUE self)
 {
         struct rb_epoll *ep = ep_get(self);
 
+        ep_check(ep);
+
         if (NIL_P(ep->io))
                 ep->io = rb_funcall(cEpoll_IO, id_for_fd, 1, INT2NUM(ep->fd));
 
@@ -386,8 +424,15 @@ static VALUE epclose(VALUE self)
 {
         struct rb_epoll *ep = ep_get(self);
 
+        if (ep->fd >= 0) {
+                st_data_t key = ep->fd;
+                st_delete(active, &key, NULL);
+        }
+
         if (NIL_P(ep->io)) {
-                if (ep->fd < 0) {
+                if (ep->fd == EP_RECREATE) {
+                        ep->fd = -1;
+                } else if (ep->fd == -1) {
                         rb_raise(rb_eIOError, "closed");
                 } else {
                         int e = close(ep->fd);
@@ -408,7 +453,7 @@ static VALUE epclosed(VALUE self)
 {
         struct rb_epoll *ep = ep_get(self);
 
-        return ep->fd < 0 ? Qtrue : Qfalse;
+        return ep->fd == -1 ? Qtrue : Qfalse;
 }
 
 static VALUE init_copy(VALUE copy, VALUE orig)
@@ -419,6 +464,7 @@ static VALUE init_copy(VALUE copy, VALUE orig)
         assert(a->events && b->events && a->events != b->events &&
                NIL_P(b->io) && "Ruby broken?");
 
+        ep_check(a);
         b->fd = dup(a->fd);
         if (b->fd == -1) {
                 if (errno == ENFILE || errno == EMFILE) {
@@ -428,10 +474,39 @@ static VALUE init_copy(VALUE copy, VALUE orig)
                 if (b->fd == -1)
                         rb_sys_fail("dup");
         }
+        st_insert(active, (st_data_t)b->fd, (st_data_t)b);
 
         return copy;
 }
 
+/*
+ * st_foreach() callback run over every tracked rb_epoll: the descriptor
+ * is either closed here or left for the GC to deal with, and the
+ * structure is marked EP_RECREATE.  +key+ and +ignored+ are unused.
+ * Always returns ST_CONTINUE so the whole table is visited.
+ */
+static int ep_atfork(st_data_t key, st_data_t value, void *ignored)
+{
+        struct rb_epoll *inherited = (struct rb_epoll *)value;
+
+        if (!NIL_P(inherited->io)) {
+                /* an IO object exists; drop our reference, GC handles it */
+                inherited->io = Qnil;
+        } else if (inherited->fd >= 0) {
+                (void)close(inherited->fd);
+        }
+
+        inherited->fd = EP_RECREATE;
+
+        return ST_CONTINUE;
+}
+
+/*
+ * Fork handler for the child side: swaps in a fresh, empty +active+
+ * table, runs ep_atfork() over every entry of the previous table, and
+ * then frees that previous table.
+ */
+static void atfork_child(void)
+{
+        st_table *inherited = active;
+
+        /* install the replacement before walking the old entries */
+        active = st_init_numtable();
+
+        st_foreach(inherited, ep_atfork, (st_data_t)NULL);
+        st_free_table(inherited);
+}
+
 void sleepy_penguin_init_epoll(void)
 {
         VALUE mSleepyPenguin, cEpoll;
@@ -460,4 +535,11 @@ void sleepy_penguin_init_epoll(void)
         rb_define_const(cEpoll, "ET", INT2NUM(EPOLLET));
         rb_define_const(cEpoll, "ONESHOT", INT2NUM(EPOLLONESHOT));
         id_for_fd = rb_intern("for_fd");
+        active = st_init_numtable();
+
+        if (pthread_atfork(NULL, NULL, atfork_child) != 0) {
+                rb_gc();
+                if (pthread_atfork(NULL, NULL, atfork_child) != 0)
+                        rb_memerror();
+        }
 }
diff --git a/ext/sleepy_penguin/extconf.rb b/ext/sleepy_penguin/extconf.rb
index c29b35d..ad24aba 100644
--- a/ext/sleepy_penguin/extconf.rb
+++ b/ext/sleepy_penguin/extconf.rb
@@ -1,9 +1,12 @@
 require 'mkmf'
 have_header('sys/epoll.h') or abort 'sys/epoll.h not found'
+have_header("pthread.h") or abort 'pthread.h not found'
 have_header('sys/eventfd.h')
 have_header('sys/signalfd.h')
 have_header('sys/timerfd.h')
+have_func('rb_memerror')
 have_func('epoll_create1', %w(sys/epoll.h))
 have_func('rb_thread_blocking_region')
+have_library('pthread')
 dir_config('sleepy_penguin')
 create_makefile('sleepy_penguin_ext')
diff --git a/test/test_epoll.rb b/test/test_epoll.rb
index ea9bddf..0acf08d 100644
--- a/test/test_epoll.rb
+++ b/test/test_epoll.rb
@@ -13,6 +13,38 @@ class TestEpoll < Test::Unit::TestCase
     @ep = Epoll.new
   end
 
+  def test_fork_safe
+    tmp = []
+    @ep.add @rd, Epoll::IN
+    pid = fork do
+      @ep.wait(nil, 100) { |flags,obj| tmp << [ flags, obj ] }
+      exit!(tmp.empty?)
+    end
+    @wr.syswrite "HI"
+    _, status = Process.waitpid2(pid)
+    assert status.success?
+    @ep.wait(nil, 0) { |flags,obj| tmp << [ flags, obj ] }
+    assert_equal [[Epoll::IN, @rd]], tmp
+  end
+
+  def test_after_fork_usability
+    fork { @ep.add(@rd, Epoll::IN); exit!(0) }
+    fork { @ep.set(@rd, Epoll::IN); exit!(0) }
+    fork { @ep.to_io; exit!(0) }
+    fork { @ep.close; exit!(0) }
+    fork { @ep.closed?; exit!(0) }
+    fork {
+      begin
+        @ep.del(@rd)
+      rescue Errno::ENOENT
+        exit!(0)
+      end
+      exit!(1)
+    }
+    res = Process.waitall
+    res.each { |(pid,status)| assert status.success? }
+  end
+
   def test_tcp_connect_nonblock_edge
     epflags = Epoll::OUT | Epoll::ET
     host = '127.0.0.1'