Merge remote-tracking branch 'origin/feat/1.2.0/img' into test

This commit is contained in:
LinkinStars 2023-10-17 10:31:03 +08:00
commit 86ed0ce383
5 changed files with 51 additions and 10 deletions

View File

@ -25,20 +25,29 @@ func NewRateLimitMiddleware(limitRepo *limit.LimitRepo) *RateLimitMiddleware {
// DuplicateRequestRejection detects and rejects duplicate requests // DuplicateRequestRejection detects and rejects duplicate requests
// It only works for the requests that post content. Such as add question, add answer, comment etc. // It only works for the requests that post content. Such as add question, add answer, comment etc.
func (rm *RateLimitMiddleware) DuplicateRequestRejection(ctx *gin.Context, req any) bool { func (rm *RateLimitMiddleware) DuplicateRequestRejection(ctx *gin.Context, req any) (reject bool, key string) {
userID := GetLoginUserIDFromContext(ctx) userID := GetLoginUserIDFromContext(ctx)
fullPath := ctx.FullPath() fullPath := ctx.FullPath()
reqJson, _ := json.Marshal(req) reqJson, _ := json.Marshal(req)
key := encryption.MD5(fmt.Sprintf("%s:%s:%s", userID, fullPath, string(reqJson))) key = encryption.MD5(fmt.Sprintf("%s:%s:%s", userID, fullPath, string(reqJson)))
reject, err := rm.limitRepo.CheckAndRecord(ctx, key) var err error
reject, err = rm.limitRepo.CheckAndRecord(ctx, key)
if err != nil { if err != nil {
log.Errorf("check and record rate limit error: %s", err.Error()) log.Errorf("check and record rate limit error: %s", err.Error())
return false return false, key
} }
if !reject { if !reject {
return false return false, key
} }
log.Debugf("duplicate request: [%s] %s", fullPath, string(reqJson)) log.Debugf("duplicate request: [%s] %s", fullPath, string(reqJson))
handler.HandleResponse(ctx, errors.BadRequest(reason.DuplicateRequestError), nil) handler.HandleResponse(ctx, errors.BadRequest(reason.DuplicateRequestError), nil)
return true return true, key
}
// DuplicateRequestClear clear duplicate request record
func (rm *RateLimitMiddleware) DuplicateRequestClear(ctx *gin.Context, key string) {
err := rm.limitRepo.ClearRecord(ctx, key)
if err != nil {
log.Errorf("clear rate limit error: %s", err.Error())
}
} }

View File

@ -2,6 +2,7 @@ package controller
import ( import (
"fmt" "fmt"
"net/http"
"github.com/answerdev/answer/internal/base/handler" "github.com/answerdev/answer/internal/base/handler"
"github.com/answerdev/answer/internal/base/middleware" "github.com/answerdev/answer/internal/base/middleware"
@ -171,9 +172,16 @@ func (ac *AnswerController) Add(ctx *gin.Context) {
if handler.BindAndCheck(ctx, req) { if handler.BindAndCheck(ctx, req) {
return return
} }
if ac.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) { reject, rejectKey := ac.rateLimitMiddleware.DuplicateRequestRejection(ctx, req)
if reject {
return return
} }
defer func() {
// If status is not 200 means that the bad request has been returned, so the record should be cleared
if ctx.Writer.Status() != http.StatusOK {
ac.rateLimitMiddleware.DuplicateRequestClear(ctx, rejectKey)
}
}()
req.QuestionID = uid.DeShortID(req.QuestionID) req.QuestionID = uid.DeShortID(req.QuestionID)
req.UserID = middleware.GetLoginUserIDFromContext(ctx) req.UserID = middleware.GetLoginUserIDFromContext(ctx)

View File

@ -15,6 +15,7 @@ import (
"github.com/answerdev/answer/pkg/uid" "github.com/answerdev/answer/pkg/uid"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/segmentfault/pacman/errors" "github.com/segmentfault/pacman/errors"
"net/http"
) )
// CommentController comment controller // CommentController comment controller
@ -55,9 +56,16 @@ func (cc *CommentController) AddComment(ctx *gin.Context) {
if handler.BindAndCheck(ctx, req) { if handler.BindAndCheck(ctx, req) {
return return
} }
if cc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) { reject, rejectKey := cc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req)
if reject {
return return
} }
defer func() {
// If status is not 200 means that the bad request has been returned, so the record should be cleared
if ctx.Writer.Status() != http.StatusOK {
cc.rateLimitMiddleware.DuplicateRequestClear(ctx, rejectKey)
}
}()
req.ObjectID = uid.DeShortID(req.ObjectID) req.ObjectID = uid.DeShortID(req.ObjectID)
req.UserID = middleware.GetLoginUserIDFromContext(ctx) req.UserID = middleware.GetLoginUserIDFromContext(ctx)

View File

@ -18,6 +18,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
"github.com/segmentfault/pacman/errors" "github.com/segmentfault/pacman/errors"
"net/http"
) )
// QuestionController question controller // QuestionController question controller
@ -335,9 +336,16 @@ func (qc *QuestionController) AddQuestion(ctx *gin.Context) {
if ctx.IsAborted() { if ctx.IsAborted() {
return return
} }
if qc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req) { reject, rejectKey := qc.rateLimitMiddleware.DuplicateRequestRejection(ctx, req)
if reject {
return return
} }
defer func() {
// If status is not 200 means that the bad request has been returned, so the record should be cleared
if ctx.Writer.Status() != http.StatusOK {
qc.rateLimitMiddleware.DuplicateRequestClear(ctx, rejectKey)
}
}()
req.UserID = middleware.GetLoginUserIDFromContext(ctx) req.UserID = middleware.GetLoginUserIDFromContext(ctx)
canList, requireRanks, err := qc.rankService.CheckOperationPermissionsForRanks(ctx, req.UserID, []string{ canList, requireRanks, err := qc.rankService.CheckOperationPermissionsForRanks(ctx, req.UserID, []string{

View File

@ -2,10 +2,12 @@ package limit
import ( import (
"context" "context"
"fmt"
"github.com/answerdev/answer/internal/base/constant" "github.com/answerdev/answer/internal/base/constant"
"github.com/answerdev/answer/internal/base/data" "github.com/answerdev/answer/internal/base/data"
"github.com/answerdev/answer/internal/base/reason" "github.com/answerdev/answer/internal/base/reason"
"github.com/segmentfault/pacman/errors" "github.com/segmentfault/pacman/errors"
"time"
) )
// LimitRepo auth repository // LimitRepo auth repository
@ -29,9 +31,15 @@ func (lr *LimitRepo) CheckAndRecord(ctx context.Context, key string) (limit bool
if exist { if exist {
return true, nil return true, nil
} }
err = lr.data.Cache.SetString(ctx, constant.RateLimitCacheKeyPrefix+key, "1", constant.RateLimitCacheTime) err = lr.data.Cache.SetString(ctx, constant.RateLimitCacheKeyPrefix+key,
fmt.Sprintf("%d", time.Now().Unix()), constant.RateLimitCacheTime)
if err != nil { if err != nil {
return false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack()
} }
return false, nil return false, nil
} }
// ClearRecord clear
func (lr *LimitRepo) ClearRecord(ctx context.Context, key string) error {
return lr.data.Cache.Del(ctx, constant.RateLimitCacheKeyPrefix+key)
}