diff --git a/internal/base/middleware/rate_limit.go b/internal/base/middleware/rate_limit.go index 837ad98e..fc46ee11 100644 --- a/internal/base/middleware/rate_limit.go +++ b/internal/base/middleware/rate_limit.go @@ -25,20 +25,29 @@ func NewRateLimitMiddleware(limitRepo *limit.LimitRepo) *RateLimitMiddleware { // DuplicateRequestRejection detects and rejects duplicate requests // It only works for the requests that post content. Such as add question, add answer, comment etc. -func (rm *RateLimitMiddleware) DuplicateRequestRejection(ctx *gin.Context, req any) bool { +func (rm *RateLimitMiddleware) DuplicateRequestRejection(ctx *gin.Context, req any) (reject bool, key string) { userID := GetLoginUserIDFromContext(ctx) fullPath := ctx.FullPath() reqJson, _ := json.Marshal(req) - key := encryption.MD5(fmt.Sprintf("%s:%s:%s", userID, fullPath, string(reqJson))) - reject, err := rm.limitRepo.CheckAndRecord(ctx, key) + key = encryption.MD5(fmt.Sprintf("%s:%s:%s", userID, fullPath, string(reqJson))) + var err error + reject, err = rm.limitRepo.CheckAndRecord(ctx, key) if err != nil { log.Errorf("check and record rate limit error: %s", err.Error()) - return false + return false, key } if !reject { - return false + return false, key } log.Debugf("duplicate request: [%s] %s", fullPath, string(reqJson)) handler.HandleResponse(ctx, errors.BadRequest(reason.DuplicateRequestError), nil) - return true + return true, key +} + +// DuplicateRequestClear clear duplicate request record +func (rm *RateLimitMiddleware) DuplicateRequestClear(ctx *gin.Context, key string) { + err := rm.limitRepo.ClearRecord(ctx, key) + if err != nil { + log.Errorf("clear rate limit error: %s", err.Error()) + } } diff --git a/internal/controller/answer_controller.go b/internal/controller/answer_controller.go index 63347c9a..aa2ecf42 100644 --- a/internal/controller/answer_controller.go +++ b/internal/controller/answer_controller.go @@ -2,6 +2,7 @@ package controller import ( "fmt" + "net/http" "github.com/answerdev/answer/internal/base/handler" "github.com/answerdev/answer/internal/base/middleware" @@ -171,9 +172,16 @@ func (ac *AnswerController) Add(ctx *gin.Context) { if handler.BindAndCheck(ctx, req) { return } - if ac.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) { + reject, rejectKey := ac.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) + if reject { return } + defer func() { + // If status is not 200 means that the bad request has been returned, so the record should be cleared + if ctx.Writer.Status() != http.StatusOK { + ac.rateLimitMiddleware.DuplicateRequestClear(ctx, rejectKey) + } + }() req.QuestionID = uid.DeShortID(req.QuestionID) req.UserID = middleware.GetLoginUserIDFromContext(ctx) diff --git a/internal/controller/comment_controller.go b/internal/controller/comment_controller.go index a1643426..3918b34f 100644 --- a/internal/controller/comment_controller.go +++ b/internal/controller/comment_controller.go @@ -15,6 +15,7 @@ import ( "github.com/answerdev/answer/pkg/uid" "github.com/gin-gonic/gin" "github.com/segmentfault/pacman/errors" + "net/http" ) // CommentController comment controller @@ -55,9 +56,16 @@ func (cc *CommentController) AddComment(ctx *gin.Context) { if handler.BindAndCheck(ctx, req) { return } - if cc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) { + reject, rejectKey := cc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) + if reject { return } + defer func() { + // If status is not 200 means that the bad request has been returned, so the record should be cleared + if ctx.Writer.Status() != http.StatusOK { + cc.rateLimitMiddleware.DuplicateRequestClear(ctx, rejectKey) + } + }() req.ObjectID = uid.DeShortID(req.ObjectID) req.UserID = middleware.GetLoginUserIDFromContext(ctx) diff --git a/internal/controller/question_controller.go b/internal/controller/question_controller.go index 19148cf2..12605b28 100644 --- a/internal/controller/question_controller.go +++ b/internal/controller/question_controller.go @@ -18,6 +18,7 @@ import ( "github.com/gin-gonic/gin" "github.com/jinzhu/copier" "github.com/segmentfault/pacman/errors" + "net/http" ) // QuestionController question controller @@ -335,9 +336,16 @@ func (qc *QuestionController) AddQuestion(ctx *gin.Context) { if ctx.IsAborted() { return } - if qc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) { + reject, rejectKey := qc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) + if reject { return } + defer func() { + // If status is not 200 means that the bad request has been returned, so the record should be cleared + if ctx.Writer.Status() != http.StatusOK { + qc.rateLimitMiddleware.DuplicateRequestClear(ctx, rejectKey) + } + }() req.UserID = middleware.GetLoginUserIDFromContext(ctx) canList, requireRanks, err := qc.rankService.CheckOperationPermissionsForRanks(ctx, req.UserID, []string{ diff --git a/internal/repo/limit/limit.go b/internal/repo/limit/limit.go index 0217d602..0249d122 100644 --- a/internal/repo/limit/limit.go +++ b/internal/repo/limit/limit.go @@ -2,10 +2,12 @@ package limit import ( "context" + "fmt" "github.com/answerdev/answer/internal/base/constant" "github.com/answerdev/answer/internal/base/data" "github.com/answerdev/answer/internal/base/reason" "github.com/segmentfault/pacman/errors" + "time" ) // LimitRepo auth repository @@ -29,9 +31,15 @@ func (lr *LimitRepo) CheckAndRecord(ctx context.Context, key string) (limit bool if exist { return true, nil } - err = lr.data.Cache.SetString(ctx, constant.RateLimitCacheKeyPrefix+key, "1", constant.RateLimitCacheTime) + err = lr.data.Cache.SetString(ctx, constant.RateLimitCacheKeyPrefix+key, + fmt.Sprintf("%d", time.Now().Unix()), constant.RateLimitCacheTime) if err != nil { return false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return false, nil } + +// ClearRecord clear +func (lr *LimitRepo) ClearRecord(ctx context.Context, key string) error { + return lr.data.Cache.Del(ctx, constant.RateLimitCacheKeyPrefix+key) +}