fix(limit): remove limit record when error occur

This commit is contained in:
LinkinStars 2023-10-17 10:30:51 +08:00
parent 5ca5120cf7
commit f1ac3c4820
5 changed files with 51 additions and 10 deletions

View File

@ -25,20 +25,29 @@ func NewRateLimitMiddleware(limitRepo *limit.LimitRepo) *RateLimitMiddleware {
// DuplicateRequestRejection detects and rejects duplicate requests
// It only works for the requests that post content. Such as add question, add answer, comment etc.
func (rm *RateLimitMiddleware) DuplicateRequestRejection(ctx *gin.Context, req any) bool {
func (rm *RateLimitMiddleware) DuplicateRequestRejection(ctx *gin.Context, req any) (reject bool, key string) {
userID := GetLoginUserIDFromContext(ctx)
fullPath := ctx.FullPath()
reqJson, _ := json.Marshal(req)
key := encryption.MD5(fmt.Sprintf("%s:%s:%s", userID, fullPath, string(reqJson)))
reject, err := rm.limitRepo.CheckAndRecord(ctx, key)
key = encryption.MD5(fmt.Sprintf("%s:%s:%s", userID, fullPath, string(reqJson)))
var err error
reject, err = rm.limitRepo.CheckAndRecord(ctx, key)
if err != nil {
log.Errorf("check and record rate limit error: %s", err.Error())
return false
return false, key
}
if !reject {
return false
return false, key
}
log.Debugf("duplicate request: [%s] %s", fullPath, string(reqJson))
handler.HandleResponse(ctx, errors.BadRequest(reason.DuplicateRequestError), nil)
return true
return true, key
}
// DuplicateRequestClear clear duplicate request record
func (rm *RateLimitMiddleware) DuplicateRequestClear(ctx *gin.Context, key string) {
err := rm.limitRepo.ClearRecord(ctx, key)
if err != nil {
log.Errorf("clear rate limit error: %s", err.Error())
}
}

View File

@ -2,6 +2,7 @@ package controller
import (
"fmt"
"net/http"
"github.com/answerdev/answer/internal/base/handler"
"github.com/answerdev/answer/internal/base/middleware"
@ -171,9 +172,16 @@ func (ac *AnswerController) Add(ctx *gin.Context) {
if handler.BindAndCheck(ctx, req) {
return
}
if ac.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) {
reject, rejectKey := ac.rateLimitMiddleware.DuplicateRequestRejection(ctx, req)
if reject {
return
}
defer func() {
// If status is not 200 means that the bad request has been returned, so the record should be cleared
if ctx.Writer.Status() != http.StatusOK {
ac.rateLimitMiddleware.DuplicateRequestClear(ctx, rejectKey)
}
}()
req.QuestionID = uid.DeShortID(req.QuestionID)
req.UserID = middleware.GetLoginUserIDFromContext(ctx)

View File

@ -15,6 +15,7 @@ import (
"github.com/answerdev/answer/pkg/uid"
"github.com/gin-gonic/gin"
"github.com/segmentfault/pacman/errors"
"net/http"
)
// CommentController comment controller
@ -55,9 +56,16 @@ func (cc *CommentController) AddComment(ctx *gin.Context) {
if handler.BindAndCheck(ctx, req) {
return
}
if cc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) {
reject, rejectKey := cc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req)
if reject {
return
}
defer func() {
// If status is not 200 means that the bad request has been returned, so the record should be cleared
if ctx.Writer.Status() != http.StatusOK {
cc.rateLimitMiddleware.DuplicateRequestClear(ctx, rejectKey)
}
}()
req.ObjectID = uid.DeShortID(req.ObjectID)
req.UserID = middleware.GetLoginUserIDFromContext(ctx)

View File

@ -18,6 +18,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"github.com/segmentfault/pacman/errors"
"net/http"
)
// QuestionController question controller
@ -335,9 +336,16 @@ func (qc *QuestionController) AddQuestion(ctx *gin.Context) {
if ctx.IsAborted() {
return
}
if qc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) {
reject, rejectKey := qc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req)
if reject {
return
}
defer func() {
// If status is not 200 means that the bad request has been returned, so the record should be cleared
if ctx.Writer.Status() != http.StatusOK {
qc.rateLimitMiddleware.DuplicateRequestClear(ctx, rejectKey)
}
}()
req.UserID = middleware.GetLoginUserIDFromContext(ctx)
canList, requireRanks, err := qc.rankService.CheckOperationPermissionsForRanks(ctx, req.UserID, []string{

View File

@ -2,10 +2,12 @@ package limit
import (
"context"
"fmt"
"github.com/answerdev/answer/internal/base/constant"
"github.com/answerdev/answer/internal/base/data"
"github.com/answerdev/answer/internal/base/reason"
"github.com/segmentfault/pacman/errors"
"time"
)
// LimitRepo auth repository
@ -29,9 +31,15 @@ func (lr *LimitRepo) CheckAndRecord(ctx context.Context, key string) (limit bool
if exist {
return true, nil
}
err = lr.data.Cache.SetString(ctx, constant.RateLimitCacheKeyPrefix+key, "1", constant.RateLimitCacheTime)
err = lr.data.Cache.SetString(ctx, constant.RateLimitCacheKeyPrefix+key,
fmt.Sprintf("%d", time.Now().Unix()), constant.RateLimitCacheTime)
if err != nil {
return false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack()
}
return false, nil
}
// ClearRecord clear
func (lr *LimitRepo) ClearRecord(ctx context.Context, key string) error {
return lr.data.Cache.Del(ctx, constant.RateLimitCacheKeyPrefix+key)
}