about summary refs log tree commit homepage
diff options
context:
space:
mode:
-rw-r--r--ext/kgio/extconf.rb3
-rw-r--r--ext/kgio/read_write.c306
-rw-r--r--test/lib_read_write.rb128
3 files changed, 437 insertions, 0 deletions
diff --git a/ext/kgio/extconf.rb b/ext/kgio/extconf.rb
index f6bd0cc..5fb15ac 100644
--- a/ext/kgio/extconf.rb
+++ b/ext/kgio/extconf.rb
@@ -23,6 +23,8 @@ have_type("struct sockaddr_storage", %w(sys/types.h sys/socket.h)) or
 have_func('accept4', %w(sys/socket.h))
 have_header("sys/select.h")
 
+have_func("writev", "sys/uio.h")
+
 if have_header('ruby/io.h')
   rubyio = %w(ruby.h ruby/io.h)
   have_struct_member("rb_io_t", "fd", rubyio)
@@ -50,5 +52,6 @@ have_func('rb_str_set_len')
 have_func('rb_time_interval')
 have_func('rb_wait_for_single_fd')
 have_func('rb_str_subseq')
+have_func('rb_ary_subseq')
 
 create_makefile('kgio_ext')
diff --git a/ext/kgio/read_write.c b/ext/kgio/read_write.c
index 9924743..cbc4799 100644
--- a/ext/kgio/read_write.c
+++ b/ext/kgio/read_write.c
@@ -1,6 +1,18 @@
 #include "kgio.h"
 #include "my_fileno.h"
 #include "nonblock.h"
+#ifdef HAVE_WRITEV
+#  include <sys/uio.h>
+#  define USE_WRITEV 1
+#else
+#  define USE_WRITEV 0
+static ssize_t assert_writev(int fd, void* iov, int len) /* stand-in for writev when HAVE_WRITEV is not defined */
+{
+        assert(0 && "you should not try to call writev"); /* always fails while assertions are enabled */
+        return -1; /* only reached when assert() is compiled out */
+}
+#  define writev assert_writev
+#endif
 static VALUE sym_wait_readable, sym_wait_writable;
 static VALUE eErrno_EPIPE, eErrno_ECONNRESET;
 static ID id_set_backtrace;
@@ -8,6 +20,15 @@ static ID id_set_backtrace;
 #define rb_str_subseq rb_str_substr
 #endif
 
+#ifndef HAVE_RB_ARY_SUBSEQ
+static inline VALUE my_ary_subseq(VALUE ary, long idx, long len) /* fallback used when rb_ary_subseq is unavailable */
+{
+        VALUE args[2] = {LONG2FIX(idx), LONG2FIX(len)}; /* (idx, len) converted with LONG2FIX */
+        return rb_ary_aref(2, args, ary); /* delegates to rb_ary_aref with both arguments */
+}
+#define rb_ary_subseq my_ary_subseq
+#endif
+
 /*
  * we know MSG_DONTWAIT works properly on all stream sockets under Linux
  * we can define this macro for other platforms as people care and
@@ -406,6 +427,254 @@ static VALUE kgio_trywrite(VALUE io, VALUE str)
         return my_write(io, str, 0);
 }
 
+#ifndef HAVE_WRITEV
+#define iovec my_iovec
+struct my_iovec { /* substitute for struct iovec when HAVE_WRITEV is not defined */
+        void  *iov_base; /* start of the buffer */
+        size_t iov_len; /* number of bytes at iov_base */
+};
+#endif
+
+/* the tests for choosing the following constants were done on Linux 3.0
+ * x86_64 (Ubuntu 12.04), Core i3 i3-2330M slowed to 1600MHz
+ * testing script: https://gist.github.com/2850641
+ * feel free to do more thorough testing and choose better values
+ */
+
+/* tests show that it is meaningless to set WRITEV_MEMLIMIT to more than 1M
+ * even when tcp_wmem is set to a relatively high value (2M) (in fact, it
+ * becomes even slower). 512K performs a bit better in the average case. */
+#define WRITEV_MEMLIMIT (512*1024)
+/* the same tests show that custom_writev is faster than glibc writev when
+ * the average string is smaller than ~500 bytes and slower when the average
+ * string is greater than ~600 bytes. 512 bytes was chosen because current
+ * compilers turn x/512 into x>>9 */
+#define WRITEV_IMPL_THRESHOLD 512
+
+static unsigned int iov_max = 1024; /* cap on iovecs per batch; init_kgio_read_write() lowers it if the system limit is smaller */
+
+struct io_args_v { /* state shared by the writev helpers during one call */
+        VALUE io; /* the object being written to */
+        VALUE buf; /* array of data still to write; ends up as the return value */
+        VALUE vec_buf; /* Ruby string whose memory backs vec */
+        struct iovec *vec; /* iovec array stored inside vec_buf */
+        unsigned long iov_cnt; /* entries of vec used by the current batch */
+        size_t batch_len; /* total bytes described by the current batch */
+        int something_written; /* set to 1 once any write has returned > 0 */
+        int fd; /* descriptor obtained from my_fileno(io) */
+};
+
+/* Writes iov_cnt buffers from vec (total_len bytes in all) to fd with a
+ * single write(2) by coalescing them into one temporary heap buffer.
+ * Returns what write(2) returned, or -1 if the buffer cannot be allocated;
+ * errno from write(2) is preserved across the free(). */
+static ssize_t custom_writev(int fd, const struct iovec *vec, unsigned int iov_cnt, size_t total_len)
+{
+        unsigned int i;
+        int saved_errno;
+        ssize_t result;
+        char *buf, *curbuf;
+        const struct iovec *curvec = vec;
+        /* we do not want to use ruby's xmalloc because it can fire GC, and
+         * we'll free buffer shortly anyway; never request 0 bytes because
+         * malloc(0) may return NULL, which would look like a failure */
+        curbuf = buf = malloc(total_len ? total_len : 1);
+        if (buf == NULL) return -1;
+        for (i = 0; i < iov_cnt; i++, curvec++) {
+                memcpy(curbuf, curvec->iov_base, curvec->iov_len);
+                curbuf += curvec->iov_len;
+        }
+        result = write(fd, buf, total_len);
+        /* keep write's errno in an int in case free() clobbers it */
+        saved_errno = errno;
+        free(buf);
+        errno = saved_errno;
+        return result;
+}
+
+/* Sets up *a to write ary to io; non-arrays are converted via rb_Array(). */
+static void prepare_writev(struct io_args_v *a, VALUE io, VALUE ary)
+{
+        a->io = io;
+        a->fd = my_fileno(io);
+        a->something_written = 0;
+
+        if (TYPE(ary) == T_ARRAY)
+                /* rb_ary_subseq will not copy array unless it is modified */
+                a->buf = rb_ary_subseq(ary, 0, RARRAY_LEN(ary));
+        else
+                a->buf = rb_Array(ary);
+
+        a->vec_buf = rb_str_new(0, 0); /* empty for now; fill_iovec sizes it */
+        a->vec = NULL;
+}
+
+static void fill_iovec(struct io_args_v *a) /* builds the next batch of iovecs from the front of a->buf */
+{
+        unsigned long i;
+        struct iovec *curvec;
+
+        a->iov_cnt = RARRAY_LEN(a->buf);
+        a->batch_len = 0;
+        if (a->iov_cnt == 0) return; /* nothing to write */
+        if (a->iov_cnt > iov_max) a->iov_cnt = iov_max; /* cap entries per batch */
+        rb_str_resize(a->vec_buf, sizeof(struct iovec) * a->iov_cnt);
+        curvec = a->vec = (struct iovec*)RSTRING_PTR(a->vec_buf); /* iovecs live in vec_buf's memory */
+
+        for (i=0; i < a->iov_cnt; i++, curvec++) {
+                /* rb_ary_store below could reallocate the array,
+                 * so RARRAY_PTR is re-read on every iteration */
+                VALUE str = RARRAY_PTR(a->buf)[i];
+                long str_len, next_len;
+
+                if (TYPE(str) != T_STRING) { /* convert non-String elements and store them back */
+                        str = rb_obj_as_string(str);
+                        rb_ary_store(a->buf, i, str);
+                }
+
+                str_len = RSTRING_LEN(str);
+
+                /* limit the total memory to write,
+                 * but always take the first string */
+                next_len = a->batch_len + str_len;
+                if (i && next_len > WRITEV_MEMLIMIT) {
+                        a->iov_cnt = i; /* batch ends just before this element */
+                        break;
+                }
+                a->batch_len = next_len;
+
+                curvec->iov_base = RSTRING_PTR(str);
+                curvec->iov_len = str_len;
+        }
+}
+
+static long trim_writev_buffer(struct io_args_v *a, long n) /* removes the first n (written) bytes from a->buf */
+{
+        long i;
+        long ary_len = RARRAY_LEN(a->buf);
+        VALUE *elem = RARRAY_PTR(a->buf);
+
+        if (n == (long)a->batch_len) { /* the whole batch was written */
+                i = a->iov_cnt;
+                n = 0;
+        } else {
+                for (i = 0; n && i < ary_len; i++, elem++) { /* count fully written elements */
+                        n -= RSTRING_LEN(*elem);
+                        if (n < 0) break; /* element i was only partly written; -n bytes remain */
+                }
+        }
+
+        /* all done */
+        if (i == ary_len) {
+                assert(n == 0 && "writev system call is broken");
+                a->buf = Qnil;
+                return 0;
+        }
+
+        /* partially done, remove fully-written buffers */
+        if (i > 0)
+                a->buf = rb_ary_subseq(a->buf, i, ary_len - i);
+
+        /* setup+replace partially written buffer */
+        if (n < 0) {
+                VALUE str = RARRAY_PTR(a->buf)[0];
+                long str_len = RSTRING_LEN(str);
+                str = rb_str_subseq(str, str_len + n, -n); /* keep only the last -n bytes */
+                rb_ary_store(a->buf, 0, str);
+        }
+        return RARRAY_LEN(a->buf); /* number of elements still to write */
+}
+
+static int writev_check(struct io_args_v *a, long n, const char *msg, int io_wait) /* nonzero result: caller loops again */
+{
+        if (n >= 0) {
+                if (n > 0) a->something_written = 1;
+                return trim_writev_buffer(a, n); /* 0 once a->buf is fully written */
+        } else if (n == -1) {
+                if (errno == EINTR) {
+                        a->fd = my_fileno(a->io); /* look the descriptor up again before retrying */
+                        return -1;
+                }
+                if (errno == EAGAIN) {
+                        if (io_wait) {
+                                (void)kgio_call_wait_writable(a->io);
+                                return -1; /* retry after waiting */
+                        } else if (!a->something_written) {
+                                a->buf = sym_wait_writable; /* nothing was written at all */
+                        }
+                        return 0; /* stop; a->buf holds the result */
+                }
+                wr_sys_fail(msg); /* NOTE(review): presumably raises and does not return - confirm */
+        }
+        return 0;
+}
+
+static VALUE my_writev(VALUE io, VALUE str, int io_wait) /* shared body of kgio_writev and kgio_trywritev */
+{
+        struct io_args_v a;
+        long n;
+
+        prepare_writev(&a, io, str);
+        set_nonblocking(a.fd);
+
+        do {
+                fill_iovec(&a);
+                if (a.iov_cnt == 0)
+                        n = 0; /* empty array: no system call is made */
+                else if (a.iov_cnt == 1)
+                        n = (long)write(a.fd, a.vec[0].iov_base, a.vec[0].iov_len);
+                /* for big strings use the library function */
+                else if (USE_WRITEV && a.batch_len / WRITEV_IMPL_THRESHOLD > a.iov_cnt)
+                        n = (long)writev(a.fd, a.vec, a.iov_cnt);
+                else
+                        n = (long)custom_writev(a.fd, a.vec, a.iov_cnt, a.batch_len); /* copy into one buffer, then write */
+        } while (writev_check(&a, n, "writev", io_wait) != 0);
+        rb_str_resize(a.vec_buf, 0); /* shrink the iovec storage back to empty */
+
+        if (TYPE(a.buf) != T_SYMBOL)
+                kgio_autopush_write(io);
+        return a.buf; /* nil, the remaining Array, or :wait_writable */
+}
+
+/*
+ * call-seq:
+ *
+ *        io.kgio_writev(array)        -> nil
+ *
+ * Returns nil when the write completes.
+ *
+ * This may block and call the +kgio_wait_writable+ method defined
+ * for the class.
+ *
+ * Note: it uses +Array()+ semantics for converting the argument, so
+ * it will also succeed if you pass something other than an Array.
+ */
+static VALUE kgio_writev(VALUE io, VALUE ary)
+{
+        return my_writev(io, ary, 1); /* io_wait=1: wait and retry on EAGAIN */
+}
+
+/*
+ * call-seq:
+ *
+ *        io.kgio_trywritev(array)        -> nil, Array or :wait_writable
+ *
+ * Returns nil if the write was completed in full.
+ *
+ * Returns an Array of strings containing the unwritten portion
+ * if EAGAIN was encountered, but some portion was successfully written.
+ *
+ * Returns :wait_writable if EAGAIN is encountered and nothing
+ * was written.
+ *
+ * Note: it uses +Array()+ semantics for converting the argument, so
+ * it will also succeed if you pass something other than an Array.
+ */
+static VALUE kgio_trywritev(VALUE io, VALUE ary)
+{
+        return my_writev(io, ary, 0); /* io_wait=0: return instead of waiting on EAGAIN */
+}
+
 #ifdef USE_MSG_DONTWAIT
 /*
  * This method behaves like Kgio::PipeMethods#kgio_write, except
@@ -489,6 +758,26 @@ static VALUE s_trywrite(VALUE mod, VALUE io, VALUE str)
         return my_write(io, str, 0);
 }
 
+/*
+ * call-seq:
+ *
+ *        Kgio.trywritev(io, array)    -> nil, Array or :wait_writable
+ *
+ * Returns nil if the write was completed in full.
+ *
+ * Returns an Array of strings containing the unwritten portion if EAGAIN
+ * was encountered, but some portion was successfully written.
+ *
+ * Returns :wait_writable if EAGAIN is encountered and nothing
+ * was written.
+ *
+ * May be used in place of PipeMethods#kgio_trywritev for non-Kgio objects
+ */
+static VALUE s_trywritev(VALUE mod, VALUE io, VALUE ary)
+{
+        return kgio_trywritev(io, ary); /* mod is not used */
+}
+
 void init_kgio_read_write(void)
 {
         VALUE mPipeMethods, mSocketMethods;
@@ -500,6 +789,7 @@ void init_kgio_read_write(void)
 
         rb_define_singleton_method(mKgio, "tryread", s_tryread, -1);
         rb_define_singleton_method(mKgio, "trywrite", s_trywrite, 2);
+        rb_define_singleton_method(mKgio, "trywritev", s_trywritev, 2);
         rb_define_singleton_method(mKgio, "trypeek", s_trypeek, -1);
 
         /*
@@ -513,8 +803,10 @@ void init_kgio_read_write(void)
         rb_define_method(mPipeMethods, "kgio_read", kgio_read, -1);
         rb_define_method(mPipeMethods, "kgio_read!", kgio_read_bang, -1);
         rb_define_method(mPipeMethods, "kgio_write", kgio_write, 1);
+        rb_define_method(mPipeMethods, "kgio_writev", kgio_writev, 1);
         rb_define_method(mPipeMethods, "kgio_tryread", kgio_tryread, -1);
         rb_define_method(mPipeMethods, "kgio_trywrite", kgio_trywrite, 1);
+        rb_define_method(mPipeMethods, "kgio_trywritev", kgio_trywritev, 1);
 
         /*
          * Document-module: Kgio::SocketMethods
@@ -527,8 +819,10 @@ void init_kgio_read_write(void)
         rb_define_method(mSocketMethods, "kgio_read", kgio_recv, -1);
         rb_define_method(mSocketMethods, "kgio_read!", kgio_recv_bang, -1);
         rb_define_method(mSocketMethods, "kgio_write", kgio_send, 1);
+        rb_define_method(mSocketMethods, "kgio_writev", kgio_writev, 1);
         rb_define_method(mSocketMethods, "kgio_tryread", kgio_tryrecv, -1);
         rb_define_method(mSocketMethods, "kgio_trywrite", kgio_trysend, 1);
+        rb_define_method(mSocketMethods, "kgio_trywritev", kgio_trywritev, 1);
         rb_define_method(mSocketMethods, "kgio_trypeek", kgio_trypeek, -1);
         rb_define_method(mSocketMethods, "kgio_peek", kgio_peek, -1);
 
@@ -544,4 +838,16 @@ void init_kgio_read_write(void)
         eErrno_ECONNRESET = rb_const_get(rb_mErrno, rb_intern("ECONNRESET"));
         rb_include_module(mPipeMethods, mWaiters);
         rb_include_module(mSocketMethods, mWaiters);
+
+#ifdef HAVE_WRITEV
+        {
+#  ifdef IOV_MAX
+                unsigned int sys_iov_max = IOV_MAX;
+#  else
+                unsigned int sys_iov_max = sysconf(_SC_IOV_MAX);
+#  endif
+                if (sys_iov_max < iov_max)
+                        iov_max = sys_iov_max;
+        }
+#endif
 }
diff --git a/test/lib_read_write.rb b/test/lib_read_write.rb
index 6f345cb..04d6fc6 100644
--- a/test/lib_read_write.rb
+++ b/test/lib_read_write.rb
@@ -21,6 +21,14 @@ module LibReadWriteTest
     assert_nil @wr.kgio_trywrite("")
   end
 
+  def test_writev_empty # writing an empty array must return nil
+    assert_nil @wr.kgio_writev([])
+  end
+
+  def test_trywritev_empty # writing an empty array must return nil
+    assert_nil @wr.kgio_trywritev([])
+  end
+
   def test_read_zero
     assert_equal "", @rd.kgio_read(0)
     buf = "foo"
@@ -116,6 +124,28 @@ module LibReadWriteTest
     assert false, "should never get here (line:#{__LINE__})"
   end
 
+  def test_writev_closed
+    @rd.close # with the reader closed, writing must eventually raise
+    begin
+      loop { @wr.kgio_writev ["HI"] }
+    rescue Errno::EPIPE, Errno::ECONNRESET => e
+      assert_equal [], e.backtrace # the raised error must carry an empty backtrace
+      return
+    end
+    assert false, "should never get here (line:#{__LINE__})"
+  end
+
+  def test_trywritev_closed
+    @rd.close # with the reader closed, writing must eventually raise
+    begin
+      loop { @wr.kgio_trywritev ["HI"] }
+    rescue Errno::EPIPE, Errno::ECONNRESET => e
+      assert_equal [], e.backtrace # the raised error must carry an empty backtrace
+      return
+    end
+    assert false, "should never get here (line:#{__LINE__})"
+  end
+
   def test_trywrite_full
     buf = "\302\251" * 1024 * 1024
     buf2 = ""
@@ -153,6 +183,43 @@ module LibReadWriteTest
     assert_equal '8ff79d8115f9fe38d18be858c66aa08a1cc27a66', t.value
   end
 
+  def test_trywritev_full
+    buf = ["\302\251" * 128] * 8 * 1024 # 8192 references to the same string
+    buf2 = ""
+    dig = Digest::SHA1.new
+    t = Thread.new do # reader thread: digests everything read from @rd
+      sleep 1
+      nr = 0
+      begin
+        dig.update(@rd.readpartial(4096, buf2))
+        nr += buf2.size
+      rescue EOFError
+        break
+      rescue => e
+      end while true
+      dig.hexdigest
+    end
+    50.times do
+      wr = buf
+      begin
+        rv = @wr.kgio_trywritev(wr)
+        case rv
+        when Array
+          wr = rv # partial write: continue with the unwritten remainder
+        when :wait_readable
+          assert false, "should never get here line=#{__LINE__}"
+        when :wait_writable
+          IO.select(nil, [ @wr ]) # wait until writable, then retry the same data
+        else
+          wr = false # nil: this round is fully written
+        end
+      end while wr
+    end
+    @wr.close
+    t.join
+    assert_equal '8ff79d8115f9fe38d18be858c66aa08a1cc27a66', t.value
+  end
+
   def test_write_conv
     assert_equal nil, @wr.kgio_write(10)
     assert_equal "10", @rd.kgio_read(2)
@@ -214,6 +281,19 @@ module LibReadWriteTest
     tmp.each { |count| assert_equal nil, count }
   end
 
+  def test_trywritev_return_wait_writable
+    tmp = []
+    tmp << @wr.kgio_trywritev(["HI"]) until tmp[-1] == :wait_writable
+    assert :wait_writable === tmp[-1]
+    assert(!(:wait_readable === tmp[-1]))
+    assert_equal :wait_writable, tmp.pop
+    assert tmp.size > 0
+    penultimate = tmp.pop
+    assert(penultimate == "I" || penultimate == nil)
+    assert tmp.size > 0
+    tmp.each { |count| assert_equal nil, count }
+  end
+
   def test_tryread_extra_buf_eagain_clears_buffer
     tmp = "hello world"
     rv = @rd.kgio_tryread(2, tmp)
@@ -248,6 +328,36 @@ module LibReadWriteTest
     assert_equal buf, readed
   end
 
+  def test_monster_trywritev
+    buf, start = [], 0
+    while start < RANDOM_BLOB.size # split RANDOM_BLOB into chunks of up to 1000
+      s = RANDOM_BLOB[start, 1000]
+      start += s.size
+      buf << s
+    end
+    rv = @wr.kgio_trywritev(buf)
+    assert_kind_of Array, rv # a partial write is expected here
+    rv = rv.join
+    assert rv.size < RANDOM_BLOB.size
+    @rd.nonblock = false
+    assert_equal(RANDOM_BLOB, @rd.read(RANDOM_BLOB.size - rv.size) + rv) # written part + remainder == original
+  end
+
+  def test_monster_writev
+    buf, start = [], 0
+    while start < RANDOM_BLOB.size # split RANDOM_BLOB into chunks of up to 10000
+      s = RANDOM_BLOB[start, 10000]
+      start += s.size
+      buf << s
+    end
+    thr = Thread.new { @wr.kgio_writev(buf) } # write in a thread while this one reads
+    @rd.nonblock = false
+    readed = @rd.read(RANDOM_BLOB.size)
+    thr.join
+    assert_nil thr.value
+    assert_equal RANDOM_BLOB, readed
+  end
+
   def test_monster_write_wait_writable
     @wr.instance_variable_set :@nr, 0
     def @wr.kgio_wait_writable
@@ -256,6 +366,7 @@ module LibReadWriteTest
     end
     buf = "." * 1024 * 1024 * 10
     thr = Thread.new { @wr.kgio_write(buf) }
+    Thread.pass until thr.stop?
     readed = @rd.read(buf.size)
     thr.join
     assert_nil thr.value
@@ -263,6 +374,23 @@ module LibReadWriteTest
     assert @wr.instance_variable_get(:@nr) > 0
   end
 
+  def test_monster_writev_wait_writable
+    @wr.instance_variable_set :@nr, 0
+    def @wr.kgio_wait_writable # counts its calls, then waits for writability
+      @nr += 1
+      IO.select(nil, [self])
+    end
+    buf = ["." * 1024] * 1024 * 10
+    buf_size = buf.inject(0){|c, s| c + s.size}
+    thr = Thread.new { @wr.kgio_writev(buf) }
+    Thread.pass until thr.stop? # yield until the writer thread has stopped
+    readed = @rd.read(buf_size)
+    thr.join
+    assert_nil thr.value
+    assert_equal buf.join, readed
+    assert @wr.instance_variable_get(:@nr) > 0 # the wait hook must have been called
+  end
+
   def test_wait_readable_ruby_default
     elapsed = 0
     foo = nil