leevis.com

ngx_http_limit_req_module 代码分析

Open vislee opened this issue 8 years ago • 0 comments

nginx中该模块用来限制请求速率。例如下面配置中的zone=one限制单个IP每秒3个请求(rate=3r/s),超出速率的请求累积超过5个(burst=5)则返回错误码,没超过则延迟处理这些请求。配置如下:

http {
    # One zone per key: client address, Host header, and Host + address.
    limit_req_zone $binary_remote_addr zone=one:10m rate=3r/s;
    limit_req_zone $host zone=two:10m rate=2r/s;
    limit_req_zone ${host}_$binary_remote_addr zone=three:10m rate=1r/s;
    ...
    server {
        ...
        location /search/ {
            # All three rules apply: the request is rejected if any zone is
            # over its burst, otherwise delayed by the longest delay needed.
            limit_req zone=one burst=5;
            limit_req zone=two burst=3;
            # No burst (defaults to 0): excess requests are rejected, so
            # nodelay has no effect here.
            limit_req zone=three nodelay;
        }
    }
}

代码分析

启动阶段

nginx几乎所有模块都定义了一个ngx_module_t结构,该结构指定了创建存储配置的内存结构的函数、解析配置文件的指令,以及注册handler的初始化函数。

  • 首先调用ngx_http_limit_req_create_conf函数创建存储配置的结构体。
  • 接着会调用ngx_http_limit_req_commands数组中的指令解析配置文件。
    • limit_req_zone: 配置在http{}中,指定根据某个变量限制的单位时间的请求次数,保存请求次数的共享内存的大小,以及共享内存的名称。指令的解析会调用ngx_http_limit_req_zone函数。
    • limit_req: 配置在location{}中,指定用哪个zone(策略)限制请求速率。burst表示允许超出速率的请求数,超过这个值的请求将返回错误码。指令的解析会调用ngx_http_limit_req函数。
  • 最后会调用ngx_http_limit_req_init函数注册handler(ngx_http_limit_req_handler)。

结构体定义:

/*
 * State kept inside the shared memory zone: a red-black tree of nodes
 * indexed by the CRC32 hash of the key, plus a queue of the same nodes in
 * which lookups move hits to the head.  Together they form an LRU cache.
 */
typedef struct {
    ngx_rbtree_t                  rbtree;
    ngx_rbtree_node_t             sentinel;
    ngx_queue_t                   queue;
} ngx_http_limit_req_shctx_t;  // red-black tree with an LRU queue

/*
 * Per-zone context created by the limit_req_zone directive and stored in
 * shm_zone->data.  The struct itself comes from the configuration pool,
 * not from the shared zone; sh and shpool point into the shared zone.
 */
typedef struct {
    ngx_http_limit_req_shctx_t  *sh;      // tree + LRU queue inside the zone
    ngx_slab_pool_t             *shpool;  // the shared zone is managed by a slab allocator
    /* integer value, 1 corresponds to 0.001 r/s */
    ngx_uint_t                   rate;
    ngx_http_complex_value_t     key;     // compiled key expression, e.g. $binary_remote_addr
    ngx_http_limit_req_node_t   *node;    // node referenced (count++) by the current request's lookup, else NULL
} ngx_http_limit_req_ctx_t;


/* One limit_req directive of a location: the zone it uses and its options. */
typedef struct {
    ngx_shm_zone_t              *shm_zone;  // shared memory zone enforcing this limit
    /*
     * integer value, 1 corresponds to 0.001 r/s
     * (burst is stored as burst * 1000, i.e. in thousandths of a request,
     * the same unit as the excess counter it is compared with)
     */
    ngx_uint_t                   burst;
    ngx_uint_t                   nodelay; /* unsigned  nodelay:1 */
} ngx_http_limit_req_limit_t;

函数定义:

limit_req_zone指令处理函数(申请共享内存)

/*
 * Handler of the "limit_req_zone key zone=name:size rate=rate" directive.
 *
 * Compiles the key, parses zone=name:size (size must be at least 8 pages)
 * and rate=N with an optional "r/s" or "r/m" suffix.  It then registers a
 * shared memory zone whose data is a new ngx_http_limit_req_ctx_t.  The
 * zone's memory is set up later by ngx_http_limit_req_init_zone().
 *
 * Returns NGX_CONF_OK, or NGX_CONF_ERROR after logging the reason.
 */
static char *
ngx_http_limit_req_zone(ngx_conf_t *cf, ngx_command_t *cmd, void *conf)
{
    u_char                            *p;
    size_t                             len;
    ssize_t                            size;
    ngx_str_t                         *value, name, s;
    ngx_int_t                          rate, scale;
    ngx_uint_t                         i;
    ngx_shm_zone_t                    *shm_zone;
    ngx_http_limit_req_ctx_t          *ctx;
    ngx_http_compile_complex_value_t   ccv;

    value = cf->args->elts;

    ctx = ngx_pcalloc(cf->pool, sizeof(ngx_http_limit_req_ctx_t));
    if (ctx == NULL) {
        return NGX_CONF_ERROR;
    }

    ngx_memzero(&ccv, sizeof(ngx_http_compile_complex_value_t));

    ccv.cf = cf;
    ccv.value = &value[1];
    ccv.complex_value = &ctx->key;

    // Compile the key (first argument, e.g. "$binary_remote_addr") into a
    // complex value that is evaluated for each request.
    if (ngx_http_compile_complex_value(&ccv) != NGX_OK) {
        return NGX_CONF_ERROR;
    }

    size = 0;
    rate = 1;
    scale = 1;
    name.len = 0;

    for (i = 2; i < cf->args->nelts; i++) {

        if (ngx_strncmp(value[i].data, "zone=", 5) == 0) {

            // zone=name:size
            name.data = value[i].data + 5;

            p = (u_char *) ngx_strchr(name.data, ':');

            if (p == NULL) {
                ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                                   "invalid zone size \"%V\"", &value[i]);
                return NGX_CONF_ERROR;
            }

            name.len = p - name.data;

            s.data = p + 1;
            s.len = value[i].data + value[i].len - s.data;

            size = ngx_parse_size(&s);

            if (size == NGX_ERROR) {
                ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                                   "invalid zone size \"%V\"", &value[i]);
                return NGX_CONF_ERROR;
            }

            if (size < (ssize_t) (8 * ngx_pagesize)) {
                ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                                   "zone \"%V\" is too small", &value[i]);
                return NGX_CONF_ERROR;
            }

            continue;
        }

        if (ngx_strncmp(value[i].data, "rate=", 5) == 0) {

            // rate=N[r/s|r/m]: strip the unit suffix, if present, before
            // parsing the number; without a suffix scale stays 1 (r/s).
            len = value[i].len;
            p = value[i].data + len - 3;

            if (ngx_strncmp(p, "r/s", 3) == 0) {
                scale = 1;
                len -= 3;

            } else if (ngx_strncmp(p, "r/m", 3) == 0) {
                scale = 60;
                len -= 3;
            }

            rate = ngx_atoi(value[i].data + 5, len - 5);
            if (rate <= 0) {
                ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                                   "invalid rate \"%V\"", &value[i]);
                return NGX_CONF_ERROR;
            }

            continue;
        }

        ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                           "invalid parameter \"%V\"", &value[i]);
        return NGX_CONF_ERROR;
    }

    // zone= is mandatory.
    if (name.len == 0) {
        ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                           "\"%V\" must have \"zone\" parameter",
                           &cmd->name);
        return NGX_CONF_ERROR;
    }

    // Store the rate multiplied by 1000 (units of 0.001 r/s) so it stays an
    // integer; a per-minute rate is converted to per-second by dividing by 60.
    ctx->rate = rate * 1000 / scale;

    // Register (or find) the shared memory zone "name" of "size" bytes,
    // owned by the limit_req module.
    shm_zone = ngx_shared_memory_add(cf, &name, size,
                                     &ngx_http_limit_req_module);
    if (shm_zone == NULL) {
        return NGX_CONF_ERROR;
    }

    // The zone is already bound by another limit_req_zone directive.
    if (shm_zone->data) {
        ctx = shm_zone->data;

        ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                           "%V \"%V\" is already bound to key \"%V\"",
                           &cmd->name, &name, &ctx->key.value);
        return NGX_CONF_ERROR;
    }

    // Callback that sets up the tree and slab pool once the zone is mapped.
    shm_zone->init = ngx_http_limit_req_init_zone;
    shm_zone->data = ctx;

    return NGX_CONF_OK;
}

共享内存初始化回调函数

/*
 * shm_zone->init callback of a limit_req zone.
 *
 * data is the context previously bound to the zone of the same name (octx),
 * or NULL.  If it is present, the key expression must be unchanged and the
 * existing shared state (sh, shpool) is reused.  Otherwise, if the zone
 * memory already holds state (shm.exists), that state is taken from
 * shpool->data.  Failing both, the tree/queue header and the log context
 * string are allocated from the zone's slab pool.
 *
 * Returns NGX_OK, or NGX_ERROR on a key mismatch or an allocation failure.
 */
static ngx_int_t
ngx_http_limit_req_init_zone(ngx_shm_zone_t *shm_zone, void *data)
{
    ngx_http_limit_req_ctx_t  *octx = data;

    size_t                     len;
    ngx_http_limit_req_ctx_t  *ctx;

    ctx = shm_zone->data;

    if (octx) {
        // The zone already existed: it may only be reused with the same key.
        if (ctx->key.value.len != octx->key.value.len
            || ngx_strncmp(ctx->key.value.data, octx->key.value.data,
                           ctx->key.value.len)
               != 0)
        {
            ngx_log_error(NGX_LOG_EMERG, shm_zone->shm.log, 0,
                          "limit_req \"%V\" uses the \"%V\" key "
                          "while previously it used the \"%V\" key",
                          &shm_zone->shm.name, &ctx->key.value,
                          &octx->key.value);
            return NGX_ERROR;
        }

        ctx->sh = octx->sh;
        ctx->shpool = octx->shpool;

        return NGX_OK;
    }

    // The zone memory starts with a slab pool that manages the rest of it.
    ctx->shpool = (ngx_slab_pool_t *) shm_zone->shm.addr;

    if (shm_zone->shm.exists) {
        ctx->sh = ctx->shpool->data;

        return NGX_OK;
    }

    // Allocate the ngx_http_limit_req_shctx_t header inside the zone.
    ctx->sh = ngx_slab_alloc(ctx->shpool, sizeof(ngx_http_limit_req_shctx_t));
    if (ctx->sh == NULL) {
        return NGX_ERROR;
    }

    // Remember the header in the pool; the shm.exists branch above reads it.
    ctx->shpool->data = ctx->sh;
    // Initialize the red-black tree: root ctx->sh->rbtree, sentinel
    // ctx->sh->sentinel, insert function ngx_http_limit_req_rbtree_insert_value.
    ngx_rbtree_init(&ctx->sh->rbtree, &ctx->sh->sentinel,
                    ngx_http_limit_req_rbtree_insert_value);

    // Queue of the tree nodes ordered by recent use; together with the tree
    // it forms an LRU structure.
    ngx_queue_init(&ctx->sh->queue);

    // Suffix appended to allocation-failure messages for this zone.
    len = sizeof(" in limit_req zone \"\"") + shm_zone->shm.name.len;

    ctx->shpool->log_ctx = ngx_slab_alloc(ctx->shpool, len);
    if (ctx->shpool->log_ctx == NULL) {
        return NGX_ERROR;
    }

    ngx_sprintf(ctx->shpool->log_ctx, " in limit_req zone \"%V\"%Z",
                &shm_zone->shm.name);

    // NOTE(review): presumably silences the slab pool's own "no memory"
    // messages, since lookup logs its own alert on failure -- confirm.
    ctx->shpool->log_nomem = 0;

    return NGX_OK;
}

limit_req指令处理函数

/*
 * Handler of the "limit_req zone=name [burst=N] [nodelay]" directive.
 *
 * Looks up the named zone and appends a ngx_http_limit_req_limit_t entry
 * (zone, burst * 1000, nodelay) to this configuration's limits array.
 * zone= is mandatory, burst must be a positive integer, and the same zone
 * may not be listed twice.
 *
 * Returns NGX_CONF_OK, or an error string / NGX_CONF_ERROR.
 */
static char *
ngx_http_limit_req(ngx_conf_t *cf, ngx_command_t *cmd, void *conf)
{
    ngx_http_limit_req_conf_t  *lrcf = conf;

    ngx_int_t                    burst;
    ngx_str_t                   *value, s;
    ngx_uint_t                   i, nodelay;
    ngx_shm_zone_t              *shm_zone;
    ngx_http_limit_req_limit_t  *limit, *limits;

    value = cf->args->elts;

    shm_zone = NULL;
    burst = 0;
    nodelay = 0;

    for (i = 1; i < cf->args->nelts; i++) {

        if (ngx_strncmp(value[i].data, "zone=", 5) == 0) {

            s.len = value[i].len - 5;
            s.data = value[i].data + 5;

            // Look up the shared zone by name; the size argument is 0
            // because the size is given by limit_req_zone.
            shm_zone = ngx_shared_memory_add(cf, &s, 0,
                                             &ngx_http_limit_req_module);
            if (shm_zone == NULL) {
                return NGX_CONF_ERROR;
            }

            continue;
        }

        if (ngx_strncmp(value[i].data, "burst=", 6) == 0) {

            burst = ngx_atoi(value[i].data + 6, value[i].len - 6);
            if (burst <= 0) {
                ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                                   "invalid burst rate \"%V\"", &value[i]);
                return NGX_CONF_ERROR;
            }

            continue;
        }

        if (ngx_strcmp(value[i].data, "nodelay") == 0) {
            nodelay = 1;
            continue;
        }

        ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                           "invalid parameter \"%V\"", &value[i]);
        return NGX_CONF_ERROR;
    }

    if (shm_zone == NULL) {
        ngx_conf_log_error(NGX_LOG_EMERG, cf, 0,
                           "\"%V\" must have \"zone\" parameter",
                           &cmd->name);
        return NGX_CONF_ERROR;
    }

    limits = lrcf->limits.elts;

    // First limit_req here: create the array.
    if (limits == NULL) {
        if (ngx_array_init(&lrcf->limits, cf->pool, 1,
                           sizeof(ngx_http_limit_req_limit_t))
            != NGX_OK)
        {
            return NGX_CONF_ERROR;
        }
    }

    for (i = 0; i < lrcf->limits.nelts; i++) {
        if (shm_zone == limits[i].shm_zone) {
            return "is duplicate";
        }
    }

    // Append the zone together with its burst and nodelay options.
    limit = ngx_array_push(&lrcf->limits);
    if (limit == NULL) {
        return NGX_CONF_ERROR;
    }
    // The zone's data is the ngx_http_limit_req_ctx_t with the rate and the per-key state.
    limit->shm_zone = shm_zone;
    // Scale burst by 1000 as well: the same unit as the excess counter.
    limit->burst = burst * 1000;
    limit->nodelay = nodelay;

    return NGX_CONF_OK;
}

请求执行

当请求到达,解析完header后,通过http的11阶段处理。当执行到NGX_HTTP_PREACCESS_PHASE阶段时,会被ngx_http_limit_req_init函数注册的handler处理,该handler为ngx_http_limit_req_handler函数。

/*
 * NGX_HTTP_PREACCESS_PHASE handler of the limit_req module.
 *
 * Applies the configured limit_req rules in order, once per main request.
 * For each rule the key is evaluated and looked up in the rule's zone under
 * the zone mutex.  Every rule except the last only takes a reference on its
 * node (NGX_AGAIN); the last one commits its state (NGX_OK).
 *
 *   - NGX_BUSY (over burst) or NGX_ERROR: the references taken so far are
 *     released and lrcf->status_code is returned;
 *   - otherwise ngx_http_limit_req_account() commits the referenced nodes
 *     and returns the longest delay needed by a rule without nodelay.  With
 *     no delay the request goes on (NGX_DECLINED); with a delay it is parked
 *     on a write-event timer and NGX_AGAIN is returned.
 *
 * NGX_DECLINED is also returned when no rule produced a usable key.
 */
static ngx_int_t
ngx_http_limit_req_handler(ngx_http_request_t *r)
{
    uint32_t                     hash;
    ngx_str_t                    key;
    ngx_int_t                    rc;
    ngx_uint_t                   n, excess;
    ngx_msec_t                   delay;
    ngx_http_limit_req_ctx_t    *ctx;
    ngx_http_limit_req_conf_t   *lrcf;
    ngx_http_limit_req_limit_t  *limit, *limits;

    // The limits have already been applied to this main request.
    if (r->main->limit_req_set) {
        return NGX_DECLINED;
    }

    lrcf = ngx_http_get_module_loc_conf(r, ngx_http_limit_req_module);
    limits = lrcf->limits.elts;

    excess = 0;

    rc = NGX_DECLINED;

#if (NGX_SUPPRESS_WARN)
    limit = NULL;
#endif

    for (n = 0; n < lrcf->limits.nelts; n++) {

        limit = &limits[n];

        ctx = limit->shm_zone->data;

        // Evaluate the zone's key (e.g. $binary_remote_addr) for this request.
        if (ngx_http_complex_value(r, &ctx->key, &key) != NGX_OK) {
            return NGX_HTTP_INTERNAL_SERVER_ERROR;
        }

        // Requests with an empty key are not limited by this rule.
        if (key.len == 0) {
            continue;
        }

        // Stored key lengths are u_short, so longer keys are skipped.
        if (key.len > 65535) {
            ngx_log_error(NGX_LOG_ERR, r->connection->log, 0,
                          "the value of the \"%V\" key "
                          "is more than 65535 bytes: \"%V\"",
                          &ctx->key.value, &key);
            continue;
        }

        hash = ngx_crc32_short(key.data, key.len);

        ngx_shmtx_lock(&ctx->shpool->mutex);

        // The last rule (account != 0) commits its state immediately.
        rc = ngx_http_limit_req_lookup(limit, hash, &key, &excess,
                                       (n == lrcf->limits.nelts - 1));

        ngx_shmtx_unlock(&ctx->shpool->mutex);

        ngx_log_debug4(NGX_LOG_DEBUG_HTTP, r->connection->log, 0,
                       "limit_req[%ui]: %i %ui.%03ui",
                       n, rc, excess / 1000, excess % 1000);

        // NGX_OK (last rule), NGX_BUSY or NGX_ERROR ends the scan.
        if (rc != NGX_AGAIN) {
            break;
        }
    }

    // No rule produced a usable key: nothing to limit.
    if (rc == NGX_DECLINED) {
        return NGX_DECLINED;
    }

    r->main->limit_req_set = 1;

    if (rc == NGX_BUSY || rc == NGX_ERROR) {

        if (rc == NGX_BUSY) {
            ngx_log_error(lrcf->limit_log_level, r->connection->log, 0,
                          "limiting requests, excess: %ui.%03ui by zone \"%V\"",
                          excess / 1000, excess % 1000,
                          &limit->shm_zone->shm.name);
        }

        // Release the node references taken by rules 0..n-1.
        while (n--) {
            ctx = limits[n].shm_zone->data;

            if (ctx->node == NULL) {
                continue;
            }

            ngx_shmtx_lock(&ctx->shpool->mutex);

            ctx->node->count--;

            ngx_shmtx_unlock(&ctx->shpool->mutex);

            ctx->node = NULL;
        }

        return lrcf->status_code;
    }

    /* rc == NGX_AGAIN || rc == NGX_OK */

    /*
     * The last lookup never returns NGX_AGAIN, so NGX_AGAIN here means the
     * trailing rule(s) were skipped (empty or too long key).  No rule has
     * committed an excess yet, so start from 0.
     */
    if (rc == NGX_AGAIN) {
        excess = 0;
    }

    // Commit the referenced nodes; the result is the delay in milliseconds.
    delay = ngx_http_limit_req_account(limits, n, &excess, &limit);

    if (!delay) {
        return NGX_DECLINED;
    }

    ngx_log_error(lrcf->delay_log_level, r->connection->log, 0,
                  "delaying request, excess: %ui.%03ui, by zone \"%V\"",
                  excess / 1000, excess % 1000, &limit->shm_zone->shm.name);

    if (ngx_handle_read_event(r->connection->read, 0) != NGX_OK) {
        return NGX_HTTP_INTERNAL_SERVER_ERROR;
    }

    /*
     * Park the request.  ngx_http_test_reading presumably only watches the
     * client connection (e.g. for a premature close) while we wait.
     *
     * The write event is used purely as a timer, so it must be marked
     * delayed before the timer is armed:
     *   - ngx_http_limit_req_delay() keeps waiting while wev->delayed is set;
     *   - ngx_http_request_handler() clears delayed together with timedout
     *     only when the timer has fired.
     * Without the flag the delay would end on the first write event, and the
     * expired timer's timedout flag would never be cleared.
     */
    r->read_event_handler = ngx_http_test_reading;
    r->write_event_handler = ngx_http_limit_req_delay;

    r->connection->write->delayed = 1;
    ngx_add_timer(r->connection->write, delay);

    return NGX_AGAIN;
}
/*
 * Look up (or create) the node for key/hash in the limit's zone and compute
 * the request's excess.  Called with the zone's shpool mutex held.
 *
 * The excess is kept in thousandths of a request.  Each request adds 1000,
 * and ctx->rate * elapsed_ms / 1000 drains away over time.  The resulting
 * excess (clamped at 0) is stored in *ep; a key seen for the first time
 * starts with 0.
 *
 * account != 0 (last rule): the new excess and timestamp are committed.
 * account == 0: the node is referenced (count++) and remembered in
 * ctx->node; ngx_http_limit_req_account() commits it later.
 *
 * Returns:
 *   NGX_BUSY   excess exceeds limit->burst: the request must be rejected
 *              (nothing is committed or referenced)
 *   NGX_OK     account != 0 and the request is within burst
 *   NGX_AGAIN  account == 0 and the request is within burst: check the
 *              next rule
 *   NGX_ERROR  a new node could not be allocated
 */
static ngx_int_t
ngx_http_limit_req_lookup(ngx_http_limit_req_limit_t *limit, ngx_uint_t hash,
    ngx_str_t *key, ngx_uint_t *ep, ngx_uint_t account)
{
    size_t                      size;
    ngx_int_t                   rc, excess;
    ngx_msec_t                  now;
    ngx_msec_int_t              ms;
    ngx_rbtree_node_t          *node, *sentinel;
    ngx_http_limit_req_ctx_t   *ctx;
    ngx_http_limit_req_node_t  *lr;

    now = ngx_current_msec;

    ctx = limit->shm_zone->data;

    node = ctx->sh->rbtree.root;
    sentinel = ctx->sh->rbtree.sentinel;

    // Tree search by hash; nodes with equal hashes are ordered by key bytes.
    while (node != sentinel) {

        if (hash < node->key) {
            node = node->left;
            continue;
        }

        if (hash > node->key) {
            node = node->right;
            continue;
        }

        /* hash == node->key */

        // The limit_req node is overlaid on the rbtree node from "color" on.
        lr = (ngx_http_limit_req_node_t *) &node->color;

        rc = ngx_memn2cmp(key->data, lr->data, key->len, (size_t) lr->len);

        if (rc == 0) {
            // Found: move the node to the head of the LRU queue.
            ngx_queue_remove(&lr->queue);
            ngx_queue_insert_head(&ctx->sh->queue, &lr->queue);

            ms = (ngx_msec_int_t) (now - lr->last);

            // New excess = stored excess - amount drained since lr->last
            // + 1000 for this request.  rate is in 0.001 r/s, so
            // rate * ms / 1000 is the drain in thousandths of a request.
            excess = lr->excess - ctx->rate * ngx_abs(ms) / 1000 + 1000;
            // More was drained than was queued: no excess.
            if (excess < 0) {
                excess = 0;
            }

            *ep = excess;
            // Over the burst threshold: the request will be rejected.
            if ((ngx_uint_t) excess > limit->burst) {
                return NGX_BUSY;
            }

            // Last rule: commit the new state now.
            if (account) {
                lr->excess = excess;
                lr->last = now;
                return NGX_OK;
            }

            // Not the last rule: reference the node and remember it for
            // ngx_http_limit_req_account(), which commits the state.
            lr->count++;

            ctx->node = lr;

            return NGX_AGAIN;
        }

        node = (rc < 0) ? node->left : node->right;
    }
    // Not found: create a node for this key, starting with zero excess.
    *ep = 0;

    // rbtree node header up to "color", the limit_req node, then the key.
    size = offsetof(ngx_rbtree_node_t, color)
           + offsetof(ngx_http_limit_req_node_t, data)
           + key->len;

    // Let ngx_http_limit_req_expire() reclaim stale entries from the LRU
    // tail first.  NOTE(review): the criteria live in that function, which
    // is not shown here.  The earlier note said "idle > 6 s"; upstream
    // nginx uses 60000 ms -- confirm.
    ngx_http_limit_req_expire(ctx, 1);

    node = ngx_slab_alloc_locked(ctx->shpool, size);

    if (node == NULL) {
        // Out of memory: expire again with n == 0 (presumably forcing out
        // the oldest unreferenced entry -- confirm), then retry once.
        ngx_http_limit_req_expire(ctx, 0);

        node = ngx_slab_alloc_locked(ctx->shpool, size);
        if (node == NULL) {
            ngx_log_error(NGX_LOG_ALERT, ngx_cycle->log, 0,
                          "could not allocate node%s", ctx->shpool->log_ctx);
            return NGX_ERROR;
        }
    }

    node->key = hash;

    lr = (ngx_http_limit_req_node_t *) &node->color;

    lr->len = (u_short) key->len;
    lr->excess = 0;

    ngx_memcpy(lr->data, key->data, key->len);

    ngx_rbtree_insert(&ctx->sh->rbtree, node);

    ngx_queue_insert_head(&ctx->sh->queue, &lr->queue);

    if (account) {
        lr->last = now;
        lr->count = 0;
        return NGX_OK;
    }

    // Not the last rule: keep a reference; ngx_http_limit_req_account()
    // sets the timestamp and excess.
    lr->last = 0;
    lr->count = 1;

    ctx->node = lr;

    return NGX_AGAIN;
}
/*
 * Commit the state of every node referenced by the first n rules (rules
 * whose lookup returned NGX_AGAIN), drop those references, and compute how
 * long the request must be delayed.
 *
 * On entry, *ep and *limit describe the rule that ended the lookup loop;
 * *ep == 0 means that rule needs no delay.  The result is the largest delay
 * in milliseconds over all rules without nodelay.  When an earlier rule
 * needs a longer delay, *ep and *limit are updated to that rule (they are
 * used for logging).
 */
static ngx_msec_t
ngx_http_limit_req_account(ngx_http_limit_req_limit_t *limits, ngx_uint_t n,
    ngx_uint_t *ep, ngx_http_limit_req_limit_t **limit)
{
    ngx_int_t                   excess;
    ngx_msec_t                  now, delay, max_delay;
    ngx_msec_int_t              ms;
    ngx_http_limit_req_ctx_t   *ctx;
    ngx_http_limit_req_node_t  *lr;

    excess = *ep;
    // No excess, or the rule is nodelay: this rule requires no delay.
    if (excess == 0 || (*limit)->nodelay) {
        max_delay = 0;

    } else {
        ctx = (*limit)->shm_zone->data;
        // Time to drain the excess at the rule's rate.  excess is in 0.001
        // requests and rate in 0.001 r/s, so the result is milliseconds.
        max_delay = excess * 1000 / ctx->rate;
    }

    // Commit every node referenced by rules 0..n-1 and keep the longest
    // delay among the rules without nodelay.
    while (n--) {
        ctx = limits[n].shm_zone->data;
        lr = ctx->node;

        // No node was referenced for this rule (its key was skipped).
        if (lr == NULL) {
            continue;
        }

        ngx_shmtx_lock(&ctx->shpool->mutex);

        // Same excess formula as in ngx_http_limit_req_lookup(); this time
        // it is committed with the timestamp and the reference is dropped.
        now = ngx_current_msec;
        ms = (ngx_msec_int_t) (now - lr->last);

        excess = lr->excess - ctx->rate * ngx_abs(ms) / 1000 + 1000;

        if (excess < 0) {
            excess = 0;
        }

        lr->last = now;
        lr->excess = excess;
        lr->count--;

        ngx_shmtx_unlock(&ctx->shpool->mutex);

        ctx->node = NULL;

        // nodelay rules are charged but never delay the request.
        if (limits[n].nodelay) {
            continue;
        }

        delay = excess * 1000 / ctx->rate;

        if (delay > max_delay) {
            max_delay = delay;
            *ep = excess;
            *limit = &limits[n];
        }
    }

    return max_delay;
}

/*
 * Connection event handler of an HTTP request.  It is installed as
 * c->write->handler and, as the ev->write test shows, serves read events
 * too.
 *
 * It attaches the request to the connection log.  When the timer of a
 * deliberately delayed event has fired, it clears both flags so the event
 * looks like a normal one.  It then runs the request's read or write event
 * handler and finally calls ngx_http_run_posted_requests().
 */
static void
ngx_http_request_handler(ngx_event_t *ev)
{
    ngx_connection_t    *conn;
    ngx_http_request_t  *req;

    conn = ev->data;
    req = conn->data;

    ngx_http_set_log_request(conn->log, req);

    ngx_log_debug2(NGX_LOG_DEBUG_HTTP, conn->log, 0,
                   "http run request: \"%V?%V\"", &req->uri, &req->args);

    /* timer of a deliberately delayed event has expired */
    if (ev->timedout && ev->delayed) {
        ev->timedout = 0;
        ev->delayed = 0;
    }

    if (!ev->write) {
        req->read_event_handler(req);

    } else {
        req->write_event_handler(req);
    }

    ngx_http_run_posted_requests(conn);
}


/*
 * Write event handler installed while limit_req delays a request.
 *
 * If wev->delayed is still set, the delay has not elapsed: the event was
 * triggered by something other than the timer, so the write event is
 * simply re-armed and we keep waiting.  Otherwise the delay is over:
 * reading is blocked again and the request resumes the phase engine.  If
 * an event cannot be re-armed, the request is finalized with 500.
 *
 * This relies on the limit_req handler setting wev->delayed when it arms
 * the timer; ngx_http_request_handler() clears the flag when the timer
 * expires.
 */
static void
ngx_http_limit_req_delay(ngx_http_request_t *r)
{
    ngx_event_t  *wev;

    ngx_log_debug0(NGX_LOG_DEBUG_HTTP, r->connection->log, 0,
                   "limit_req delay");

    wev = r->connection->write;

    if (!wev->delayed) {

        if (ngx_handle_read_event(r->connection->read, 0) != NGX_OK) {
            ngx_http_finalize_request(r, NGX_HTTP_INTERNAL_SERVER_ERROR);
            return;
        }

        /* delay elapsed: return to normal request processing */
        r->read_event_handler = ngx_http_block_reading;
        r->write_event_handler = ngx_http_core_run_phases;

        ngx_http_core_run_phases(r);
        return;
    }

    /* still delayed: keep waiting for the timer */
    if (ngx_handle_write_event(wev, 0) != NGX_OK) {
        ngx_http_finalize_request(r, NGX_HTTP_INTERNAL_SERVER_ERROR);
    }
}

总结

limit_req_zone 以某个(或某几个)变量的值为键进行限速,限制每秒或每分钟的请求数,超过以后延迟处理或直接返回错误码。limit_req 设置具体的限速规则:请求速率超过配置的速率且超出部分超过burst时,直接返回错误码;超过速率但没有超过burst时,如果没有配置nodelay则延迟处理,配置了nodelay则立即正常处理。如果没有配置burst(默认为0),配置nodelay也没有意义,因为任何超出速率的请求都会直接返回错误码。

旁路限qps

package main

import (
	"bytes"
	"container/list"
	"encoding/json"
	"errors"
	"fmt"
	"hash/crc32"
	"strconv"
	"strings"
	"sync"
	"time"
	"net"
)

// ccRateLimitOpt is one per-host entry of the JSON configuration: the host
// name and the limit rules applied to it.
type ccRateLimitOpt struct {
	Host     string               `json:"host"`
	LimitOpt []*ccRateLimitConfig `json:"limit"`
}

// ccRateLimitConfig is one limit rule as read from JSON.
//
// Key and Val are parallel comma-separated lists with the same number of
// items.  A key is a request field name, optionally suffixed with "_prefix"
// or "_suffix", or the special key "ip", whose value may be a CIDR.  For
// plain keys a value of "*" matches anything.
type ccRateLimitConfig struct {
	Key    string `json:"key"`
	Val    string `json:"val"`
	Rate   string `json:"rate"`   // e.g. "3r/s" or "200r/m"
	Burst  int64  `json:"burst"`  // excess requests tolerated before rejecting
	Action string `json:"action"` // not interpreted by NewRateLimiter
	Cap    uint64 `json:"cap"`    // bound on the number of tracked keys
}

// ccRateLimitItem is the leaky-bucket state of one key.
type ccRateLimitItem struct {
	Excess int64 // queued excess, in thousandths of a request
	Last   int64 // time of the last update, Unix milliseconds
}

// ccRateLimiter enforces one rule, modelled on nginx's limit_req: each
// request adds 1000 to its key's excess, which drains at Rate.  Update and
// Lookup hold the embedded mutex.
type ccRateLimiter struct {
	sync.Mutex

	ConfCrc32 uint32                      // CRC32 of the JSON-encoded config, to detect changes
	Key       []string                    // key names, parallel to Val
	Val       []string                    // values to match, parallel to Key
	Rate      int64                       // thousandths of a request per second
	Burst     int64                       // thousandths of a request
	Action    int                         // always 0 at present
	Cap       uint64                      // conf.Cap + 1; the oldest key is evicted when the list reaches Cap + 1 entries
	Dimension map[string]*ccRateLimitItem // per-key state
	l         *list.List                  // tracked keys; hits move to the front, eviction takes the back
	size      uint64                      // unused
}

// NewRateLimiter builds a limiter from one rule of the JSON configuration.
//
// conf.Key and conf.Val are split on "," and must have the same number of
// items.  conf.Rate is "<n>r/s", "<n>r/m" or a bare "<n>" (per second).  It
// is stored multiplied by 1000, i.e. in thousandths of a request per
// second; Burst is likewise stored in thousandths of a request.  The CRC32
// of the JSON-encoded conf is kept so Update can skip unchanged configs.
//
// Errors: mismatched key/val lists, an unparsable rate, or a rate that is
// not positive (Lookup divides by the rate, so 0 would panic later).
func NewRateLimiter(conf *ccRateLimitConfig) (*ccRateLimiter, error) {
	keys := strings.Split(conf.Key, ",")
	vals := strings.Split(conf.Val, ",")
	if len(keys) != len(vals) {
		return nil, errors.New("Parse key and val error. mismatches.")
	}

	var scale int64 = 1
	rateStr := conf.Rate
	if strings.HasSuffix(rateStr, "r/m") {
		scale = 60
		rateStr = strings.TrimSuffix(rateStr, "r/m")
	} else if strings.HasSuffix(rateStr, "r/s") {
		rateStr = strings.TrimSuffix(rateStr, "r/s")
	}

	rate, err := strconv.ParseInt(rateStr, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("Parse rate error. %s", err.Error())
	}
	if rate <= 0 {
		return nil, fmt.Errorf("Parse rate error. rate must be positive: %q", conf.Rate)
	}
	rate = rate * 1000 / scale

	tmp, err := json.Marshal(&conf)
	if err != nil {
		return nil, fmt.Errorf("Marshal error. %s", err.Error())
	}

	return &ccRateLimiter{
		ConfCrc32: crc32.ChecksumIEEE(tmp),
		Key:       keys,
		Val:       vals,
		Rate:      rate,
		Burst:     conf.Burst * 1000,
		Action:    0, // conf.Action is not interpreted yet
		Cap:       conf.Cap + 1,
		Dimension: make(map[string]*ccRateLimitItem, conf.Cap+2),
		l:         list.New(),
		size:      0,
	}, nil
}


// Update replaces the limiter's rule with conf, in place.
//
// If conf encodes to the same JSON (same CRC32) as the current rule, the
// limiter is left untouched.  Otherwise a fresh limiter is built from conf
// and its fields are copied into l, which also discards all per-key state.
// If conf is invalid, the error is returned and l is unchanged.
func (l *ccRateLimiter) Update(conf *ccRateLimitConfig) error {
	tmp, err := json.Marshal(&conf)
	if err != nil {
		return fmt.Errorf("Marshal error. %s", err.Error())
	}
	ccrc32 := crc32.ChecksumIEEE(tmp)

	l.Lock()
	defer l.Unlock()

	if ccrc32 == l.ConfCrc32 {
		return nil
	}

	fresh, err := NewRateLimiter(conf)
	if err != nil {
		return err
	}

	// Copy field by field into *l.  Rebinding the receiver (l = fresh) would
	// not affect the caller, and "*l = *fresh" would overwrite the mutex we
	// are holding.
	l.ConfCrc32 = fresh.ConfCrc32
	l.Key = fresh.Key
	l.Val = fresh.Val
	l.Rate = fresh.Rate
	l.Burst = fresh.Burst
	l.Action = fresh.Action
	l.Cap = fresh.Cap
	l.Dimension = fresh.Dimension
	l.l = fresh.l
	l.size = fresh.size

	return nil
}

// Lookup charges one request, described by item (request field name ->
// value), against the limiter.
//
// The rule's Key/Val pairs select the bucket key:
//   - "<field>_prefix" / "<field>_suffix" must match, or the request is
//     not limited (0 is returned);
//   - "ip" contributes the matching CIDR (or the exact IP if Val is not a
//     CIDR);
//   - plain fields contribute their value when Val is "*" or equal.
// If nothing was contributed, the request is not limited.
//
// Returns -1 if the request exceeds the burst and should be rejected, 0 if
// it passes immediately, or a positive delay in milliseconds.
//
// NOTE(review): a non-matching "ip" or plain field is skipped rather than
// exempting the request, unlike _prefix/_suffix -- confirm this is intended.
func (self *ccRateLimiter) Lookup(item map[string]string) int64 {
	self.Lock()
	defer self.Unlock()

	var buf bytes.Buffer
	buf.WriteString("limit")
	for i, k := range self.Key {
		want := self.Val[i]
		switch {
		case k == "ip":
			// TODO parse ip
			reqVal, ok := item[k]
			if !ok {
				break
			}
			if _, ipnet, err := net.ParseCIDR(want); err == nil {
				if ipnet.Contains(net.ParseIP(reqVal)) {
					buf.WriteString("_")
					buf.WriteString(ipnet.String())
				}
			} else if reqVal == want {
				buf.WriteString("_")
				buf.WriteString(reqVal)
			}

		case strings.HasSuffix(k, "_prefix"):
			reqVal, ok := item[strings.TrimSuffix(k, "_prefix")]
			if !ok || !strings.HasPrefix(reqVal, want) {
				return 0
			}
			buf.WriteString("_")
			buf.WriteString(want)

		case strings.HasSuffix(k, "_suffix"):
			reqVal, ok := item[strings.TrimSuffix(k, "_suffix")]
			if !ok || !strings.HasSuffix(reqVal, want) {
				return 0
			}
			buf.WriteString("_")
			buf.WriteString(want)

		default:
			if reqVal, ok := item[k]; ok && (want == "*" || want == reqVal) {
				buf.WriteString("_")
				buf.WriteString(reqVal)
			}
		}
	}

	if buf.Len() < len("limit_") {
		return 0
	}

	key := buf.String()

	now := time.Now().UnixNano() / 1000000
	if dms, ok := self.Dimension[key]; ok {
		ms := now - dms.Last
		if ms < -60000 {
			// Clock stepped back by more than a minute: count 1 ms and
			// resynchronise Last below.
			ms = 1
		} else if ms < 0 {
			ms = 0
		}

		// Same leaky-bucket update as nginx: drain Rate*ms/1000 and add
		// 1000 for this request (all in thousandths of a request).
		excess := dms.Excess - self.Rate*ms/1000 + 1000
		if excess < 0 {
			excess = 0
		}

		if excess > self.Burst {
			return -1
		}

		dms.Excess = excess

		if ms > 0 {
			dms.Last = now
			// Linear scan of the key list; O(number of tracked keys).
			for e := self.l.Front(); e != nil; e = e.Next() {
				if e.Value.(string) == key {
					self.l.MoveToFront(e)
					break
				}
			}
		}

		// Time to drain the excess at Rate, in milliseconds.
		return dms.Excess * 1000 / self.Rate
	}

	// New key: track it as the most recently used one.  (Inserting it after
	// the front meant that with cap 0 the new key itself was evicted at once.)
	self.Dimension[key] = &ccRateLimitItem{Excess: 0, Last: now}
	self.l.PushFront(key)

	if uint64(self.l.Len()) == self.Cap+1 {
		oldest := self.l.Back()
		delete(self.Dimension, oldest.Value.(string))
		self.l.Remove(oldest)
	}
	return 0
}

// main parses a sample configuration, builds the limiters of each host and
// runs a simple throughput benchmark against the "test.com" rules.  It
// prints the number of admitted requests, the elapsed seconds and the QPS.
func main() {

	s := `
[
    {
        "host":"test.com",
        "limit":[
            {
                "key":"host_suffix,ip,url_prefix",
                "val":"test.com,10.10.10.1/24,/test/",
                "rate":"200r/s",
                "burst":0,
                "cap":0,
                "action":"block"
            }
        ]
    },
    {
        "host":"aa.test.com",
        "limit":[
            {
                "key":"host,ip,url_prefix",
                "val":"www.test.com,10.10.10.10/24,/test/",
                "rate":"3r/s",
                "burst":300,
                "cap":0,
                "action":"block"
            },
            {
                "key":"host_suffix,ip,url",
                "val":"test.com,10.10.10.10/24,/test/index0.html",
                "rate":"3r/s",
                "burst":300,
                "cap":0,
                "action":"block"
            },
            {
                "key":"host,ip,url",
                "val":"*,*,*",
                "rate":"3r/s",
                "burst":300,
                "cap":100,
                "action":"block"
            }
        ]
    }
]
`
	var cc []*ccRateLimitOpt
	err := json.Unmarshal([]byte(s), &cc)
	if err != nil {
		fmt.Println("Unmarshal:", err.Error())
		return
	}

	limiters := make(map[string][]*ccRateLimiter, 3)
	fmt.Println(len(cc))
	for _, ll := range cc {
		fmt.Println(ll.Host)
		limits := make([]*ccRateLimiter, 0, 3)
		for _, val := range ll.LimitOpt {
			confSer, err := json.Marshal(&val)
			if err != nil {
				fmt.Println(err.Error())
			} else {
				fmt.Println(crc32.ChecksumIEEE(confSer))
			}
			fmt.Println(val.Key, val.Val, val.Rate, val.Burst, val.Action, val.Cap)

			l, err := NewRateLimiter(val)
			if err != nil {
				fmt.Println(err.Error())
				continue
			}
			limits = append(limits, l)
		}

		limiters[ll.Host] = limits
	}

	ll := limiters["test.com"]
	it := map[string]string{"host": "www.test.com", "ip": "10.10.10.10"}
	start := time.Now()
	var count float64 = 0.0
	for i := 0; i < 3000000; i++ {
		var delay int64 = 0
		it["url"] = fmt.Sprintf("/test/index%d.html", i)
		it["ip"] = fmt.Sprintf("10.10.10.%d", i%3)
		// The first rule that rejects (-1) or delays (>0) decides.
		for _, l := range ll {
			delay = l.Lookup(it)
			if delay != 0 {
				break
			}
		}
		if delay > 0 {
			// Lookup returns the delay in milliseconds.
			time.Sleep(time.Duration(delay) * time.Millisecond)
		}
		if delay >= 0 {
			count += 1
		}
	}

	con := time.Since(start).Seconds()
	fmt.Println(count, con)
	fmt.Printf("qps: %.2f\n", count/con)

}

vislee avatar Mar 16 '17 11:03 vislee