diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index e88de43f..59264617 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -23,7 +23,6 @@ import ( "github.com/answerdev/answer/internal/repo/captcha" "github.com/answerdev/answer/internal/repo/collection" "github.com/answerdev/answer/internal/repo/comment" - "github.com/answerdev/answer/internal/repo/common" "github.com/answerdev/answer/internal/repo/config" "github.com/answerdev/answer/internal/repo/export" "github.com/answerdev/answer/internal/repo/meta" @@ -182,9 +181,8 @@ func initApplication(debug bool, serverConf *conf.Server, dbConf *data.Database, serviceRevisionService := service.NewRevisionService(revisionRepo, userCommon, questionCommon, answerService, objService, questionRepo, answerRepo, tagRepo, tagCommonService) revisionController := controller.NewRevisionController(serviceRevisionService, rankService) rankController := controller.NewRankController(rankService) - commonRepo := common.NewCommonRepo(dataData, uniqueIDRepo) reportHandle := report_handle_admin.NewReportHandle(questionCommon, commentRepo, configRepo) - reportAdminService := report_admin.NewReportAdminService(reportRepo, userCommon, commonRepo, answerRepo, questionRepo, commentCommonRepo, reportHandle, configRepo) + reportAdminService := report_admin.NewReportAdminService(reportRepo, userCommon, answerRepo, questionRepo, commentCommonRepo, reportHandle, configRepo, objService) controller_adminReportController := controller_admin.NewReportController(reportAdminService) userAdminRepo := user.NewUserAdminRepo(dataData, authRepo) userAdminService := user_admin.NewUserAdminService(userAdminRepo, userRoleRelService, authService, userCommon, userActiveActivityRepo) diff --git a/internal/repo/activity/activity_repo.go b/internal/repo/activity/activity_repo.go index 3f8ad4b2..c1971eab 100644 --- a/internal/repo/activity/activity_repo.go +++ b/internal/repo/activity/activity_repo.go @@ -30,7 +30,7 @@ func NewActivityRepo( func (ar *activityRepo) GetObjectAllActivity(ctx context.Context, objectID string, showVote bool) ( activityList []*entity.Activity, err error) { activityList = make([]*entity.Activity, 0) - session := ar.data.DB.Desc("created_at") + session := ar.data.DB.Context(ctx).Desc("created_at") if !showVote { var activityTypeNotShown []int diff --git a/internal/repo/activity/answer_repo.go b/internal/repo/activity/answer_repo.go index 02e9c511..b13d4cb9 100644 --- a/internal/repo/activity/answer_repo.go +++ b/internal/repo/activity/answer_repo.go @@ -63,7 +63,7 @@ func NewQuestionActivityRepo( func (ar *AnswerActivityRepo) DeleteQuestion(ctx context.Context, questionID string) (err error) { questionInfo := &entity.Question{} - exist, err := ar.data.DB.Where("id = ?", questionID).Get(questionInfo) + exist, err := ar.data.DB.Context(ctx).Where("id = ?", questionID).Get(questionInfo) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -73,7 +73,7 @@ func (ar *AnswerActivityRepo) DeleteQuestion(ctx context.Context, questionID str // get all this object activity activityList := make([]*entity.Activity, 0) - session := ar.data.DB.Where("has_rank = 1") + session := ar.data.DB.Context(ctx).Where("has_rank = 1") session.Where("cancelled = ?", entity.ActivityAvailable) err = session.Find(&activityList, &entity.Activity{ObjectID: questionID}) if err != nil { @@ -86,6 +86,7 @@ func (ar *AnswerActivityRepo) DeleteQuestion(ctx context.Context, questionID str log.Infof("questionInfo %s deleted will rollback activity %d", questionID, len(activityList)) _, err = ar.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { + session = session.Context(ctx) for _, act := range activityList { log.Infof("user %s rollback rank %d", act.UserID, -act.Rank) _, e := ar.userRankRepo.TriggerUserRank( @@ -107,7 +108,7 @@ func (ar *AnswerActivityRepo) DeleteQuestion(ctx context.Context, questionID str // get all answers answerList := make([]*entity.Answer, 0) - err = ar.data.DB.Find(&answerList, &entity.Answer{QuestionID: questionID}) + err = ar.data.DB.Context(ctx).Find(&answerList, &entity.Answer{QuestionID: questionID}) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -154,6 +155,7 @@ func (ar *AnswerActivityRepo) AcceptAnswer(ctx context.Context, } _, err = ar.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { + session = session.Context(ctx) for _, addActivity := range addActivityList { existsActivity, exists, e := ar.activityRepo.GetActivity( ctx, session, answerObjID, addActivity.UserID, addActivity.ActivityType) @@ -251,6 +253,7 @@ func (ar *AnswerActivityRepo) CancelAcceptAnswer(ctx context.Context, } _, err = ar.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { + session = session.Context(ctx) for _, addActivity := range addActivityList { existsActivity, exists, e := ar.activityRepo.GetActivity( ctx, session, answerObjID, addActivity.UserID, addActivity.ActivityType) @@ -300,7 +303,7 @@ func (ar *AnswerActivityRepo) CancelAcceptAnswer(ctx context.Context, func (ar *AnswerActivityRepo) DeleteAnswer(ctx context.Context, answerID string) (err error) { answerInfo := &entity.Answer{} - exist, err := ar.data.DB.Where("id = ?", answerID).Get(answerInfo) + exist, err := ar.data.DB.Context(ctx).Where("id = ?", answerID).Get(answerInfo) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -310,7 +313,7 @@ func (ar *AnswerActivityRepo) DeleteAnswer(ctx context.Context, answerID string) // get all this object activity activityList := make([]*entity.Activity, 0) - session := ar.data.DB.Where("has_rank = 1") + session := ar.data.DB.Context(ctx).Where("has_rank = 1") session.Where("cancelled = ?", entity.ActivityAvailable) err = session.Find(&activityList, &entity.Activity{ObjectID: answerID}) if err != nil { @@ -323,6 +326,7 @@ func (ar *AnswerActivityRepo) DeleteAnswer(ctx context.Context, answerID string) log.Infof("answerInfo %s deleted will rollback activity %d", answerID, len(activityList)) _, err = ar.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { + session = session.Context(ctx) for _, act := range activityList { log.Infof("user %s rollback rank %d", act.UserID, -act.Rank) _, e := ar.userRankRepo.TriggerUserRank( diff --git a/internal/repo/activity/follow_repo.go b/internal/repo/activity/follow_repo.go index bb1a0a0c..6cc1bdca 100644 --- a/internal/repo/activity/follow_repo.go +++ b/internal/repo/activity/follow_repo.go @@ -45,6 +45,7 @@ func (ar *FollowRepo) Follow(ctx context.Context, objectID, userID string) error } _, err = ar.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { + session = session.Context(ctx) var ( existsActivity entity.Activity has bool @@ -107,6 +108,7 @@ func (ar *FollowRepo) FollowCancel(ctx context.Context, objectID, userID string) } _, err = ar.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { + session = session.Context(ctx) var ( existsActivity entity.Activity has bool diff --git a/internal/repo/activity/user_active_repo.go b/internal/repo/activity/user_active_repo.go index 1b75adac..b70a0dac 100644 --- a/internal/repo/activity/user_active_repo.go +++ b/internal/repo/activity/user_active_repo.go @@ -44,7 +44,7 @@ func NewUserActiveActivityRepo( // UserActive accept other answer func (ar *UserActiveActivityRepo) UserActive(ctx context.Context, userID string) (err error) { _, err = ar.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { - + session = session.Context(ctx) activityType, err := ar.configRepo.GetConfigType(UserActivated) if err != nil { return nil, err diff --git a/internal/repo/activity/vote_repo.go b/internal/repo/activity/vote_repo.go index 1303f3c7..ba0d774d 100644 --- a/internal/repo/activity/vote_repo.go +++ b/internal/repo/activity/vote_repo.go @@ -75,6 +75,7 @@ func (vr *VoteRepo) vote(ctx context.Context, objectID string, userID, objectUse sendInboxNotification := false upVote := false _, err = vr.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { + session = session.Context(ctx) result = nil for _, action := range actions { var ( @@ -185,6 +186,7 @@ func (vr *VoteRepo) voteCancel(ctx context.Context, objectID string, userID, obj resp = &schema.VoteResp{} notificationUserIDs := make([]string, 0) _, err = vr.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { + session = session.Context(ctx) for _, action := range actions { var ( existsActivity entity.Activity @@ -362,7 +364,7 @@ func (vr *VoteRepo) GetVoteResultByObjectId(ctx context.Context, objectID string activityType, _, _, _ = vr.activityRepo.GetActivityTypeByObjID(ctx, objectID, action) - votes, err = vr.data.DB.Where(builder.Eq{"object_id": objectID}). + votes, err = vr.data.DB.Context(ctx).Where(builder.Eq{"object_id": objectID}). And(builder.Eq{"activity_type": activityType}). And(builder.Eq{"cancelled": 0}). Count(&activity) @@ -389,7 +391,7 @@ func (vr *VoteRepo) ListUserVotes( req schema.GetVoteWithPageReq, activityTypes []int, ) (voteList []entity.Activity, total int64, err error) { - session := vr.data.DB.NewSession() + session := vr.data.DB.Context(ctx) cond := builder. And( builder.Eq{"user_id": userID}, diff --git a/internal/repo/activity_common/activity_repo.go b/internal/repo/activity_common/activity_repo.go index 008eeecb..3717ae0e 100644 --- a/internal/repo/activity_common/activity_repo.go +++ b/internal/repo/activity_common/activity_repo.go @@ -87,7 +87,7 @@ func (ar *ActivityRepo) GetActivity(ctx context.Context, session *xorm.Session, func (ar *ActivityRepo) GetUserIDObjectIDActivitySum(ctx context.Context, userID, objectID string) (int, error) { sum := &entity.ActivityRankSum{} - _, err := ar.data.DB.Table(entity.Activity{}.TableName()). + _, err := ar.data.DB.Context(ctx).Table(entity.Activity{}.TableName()). Select("sum(`rank`) as `rank`"). Where("user_id =?", userID). And("object_id = ?", objectID). @@ -102,7 +102,7 @@ func (ar *ActivityRepo) GetUserIDObjectIDActivitySum(ctx context.Context, userID // AddActivity add activity func (ar *ActivityRepo) AddActivity(ctx context.Context, activity *entity.Activity) (err error) { - _, err = ar.data.DB.Insert(activity) + _, err = ar.data.DB.Context(ctx).Insert(activity) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -113,7 +113,7 @@ func (ar *ActivityRepo) AddActivity(ctx context.Context, activity *entity.Activi func (ar *ActivityRepo) GetUsersWhoHasGainedTheMostReputation( ctx context.Context, startTime, endTime time.Time, limit int) (rankStat []*entity.ActivityUserRankStat, err error) { rankStat = make([]*entity.ActivityUserRankStat, 0) - session := ar.data.DB.Select("user_id, SUM(`rank`) AS rank_amount").Table("activity") + session := ar.data.DB.Context(ctx).Select("user_id, SUM(`rank`) AS rank_amount").Table("activity") session.Where("has_rank = 1 AND cancelled = 0") session.Where("created_at >= ?", startTime) session.Where("created_at <= ?", endTime) @@ -140,7 +140,7 @@ func (ar *ActivityRepo) GetUsersWhoHasVoteMost( } } - session := ar.data.DB.Select("user_id, COUNT(*) AS vote_count").Table("activity") + session := ar.data.DB.Context(ctx).Select("user_id, COUNT(*) AS vote_count").Table("activity") session.Where("cancelled = 0") session.In("activity_type", actIDs) session.Where("created_at >= ?", startTime) diff --git a/internal/repo/activity_common/follow.go b/internal/repo/activity_common/follow.go index 5f7b0ec0..0ab331b6 100644 --- a/internal/repo/activity_common/follow.go +++ b/internal/repo/activity_common/follow.go @@ -41,19 +41,19 @@ func (ar *FollowRepo) GetFollowAmount(ctx context.Context, objectID string) (fol switch objectType { case "question": model := &entity.Question{} - _, err = ar.data.DB.Where("id = ?", objectID).Cols("`follow_count`").Get(model) + _, err = ar.data.DB.Context(ctx).Where("id = ?", objectID).Cols("`follow_count`").Get(model) if err == nil { follows = int(model.FollowCount) } case "user": model := &entity.User{} - _, err = ar.data.DB.Where("id = ?", objectID).Cols("`follow_count`").Get(model) + _, err = ar.data.DB.Context(ctx).Where("id = ?", objectID).Cols("`follow_count`").Get(model) if err == nil { follows = int(model.FollowCount) } case "tag": model := &entity.Tag{} - _, err = ar.data.DB.Where("id = ?", objectID).Cols("`follow_count`").Get(model) + _, err = ar.data.DB.Context(ctx).Where("id = ?", objectID).Cols("`follow_count`").Get(model) if err == nil { follows = int(model.FollowCount) } @@ -79,7 +79,7 @@ func (ar *FollowRepo) GetFollowUserIDs(ctx context.Context, objectID string) (us } userIDs = make([]string, 0) - session := ar.data.DB.Select("user_id") + session := ar.data.DB.Context(ctx).Select("user_id") session.Table(entity.Activity{}.TableName()) session.Where("object_id = ?", objectID) session.Where("activity_type = ?", activityType) @@ -98,7 +98,7 @@ func (ar *FollowRepo) GetFollowIDs(ctx context.Context, userID, objectKey string if err != nil { return nil, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } - session := ar.data.DB.Select("object_id") + session := ar.data.DB.Context(ctx).Select("object_id") session.Table(entity.Activity{}.TableName()) session.Where("user_id = ? AND activity_type = ?", userID, activityType) session.Where("cancelled = 0") @@ -110,14 +110,14 @@ func (ar *FollowRepo) GetFollowIDs(ctx context.Context, userID, objectKey string } // IsFollowed check user if follow object or not -func (ar *FollowRepo) IsFollowed(userID, objectID string) (bool, error) { +func (ar *FollowRepo) IsFollowed(ctx context.Context, userID, objectID string) (bool, error) { activityType, _, _, err := ar.activityRepo.GetActivityTypeByObjID(context.TODO(), objectID, "follow") if err != nil { return false, err } at := &entity.Activity{} - has, err := ar.data.DB.Where("user_id = ? AND object_id = ? AND activity_type = ?", userID, objectID, activityType).Get(at) + has, err := ar.data.DB.Context(ctx).Where("user_id = ? AND object_id = ? AND activity_type = ?", userID, objectID, activityType).Get(at) if err != nil { return false, err } diff --git a/internal/repo/activity_common/vote.go b/internal/repo/activity_common/vote.go index 16f73fc4..a8589fc4 100644 --- a/internal/repo/activity_common/vote.go +++ b/internal/repo/activity_common/vote.go @@ -31,7 +31,7 @@ func (vr *VoteRepo) GetVoteStatus(ctx context.Context, objectID, userID string) if err != nil { return "" } - has, err := vr.data.DB.Where("object_id =? AND cancelled=0 AND activity_type=? AND user_id=?", objectID, activityType, userID).Get(at) + has, err := vr.data.DB.Context(ctx).Where("object_id =? AND cancelled=0 AND activity_type=? AND user_id=?", objectID, activityType, userID).Get(at) if err != nil { return "" } @@ -44,7 +44,7 @@ func (vr *VoteRepo) GetVoteStatus(ctx context.Context, objectID, userID string) func (vr *VoteRepo) GetVoteCount(ctx context.Context, activityTypes []int) (count int64, err error) { list := make([]*entity.Activity, 0) - count, err = vr.data.DB.Where("cancelled =0").In("activity_type", activityTypes).FindAndCount(&list) + count, err = vr.data.DB.Context(ctx).Where("cancelled =0").In("activity_type", activityTypes).FindAndCount(&list) if err != nil { return count, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/answer/answer_repo.go b/internal/repo/answer/answer_repo.go index 084231d2..e126b3fc 100644 --- a/internal/repo/answer/answer_repo.go +++ b/internal/repo/answer/answer_repo.go @@ -53,7 +53,7 @@ func (ar *answerRepo) AddAnswer(ctx context.Context, answer *entity.Answer) (err return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } answer.ID = ID - _, err = ar.data.DB.Insert(answer) + _, err = ar.data.DB.Context(ctx).Insert(answer) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -70,7 +70,7 @@ func (ar *answerRepo) RemoveAnswer(ctx context.Context, id string) (err error) { ID: id, Status: entity.AnswerStatusDeleted, } - _, err = ar.data.DB.Where("id = ?", id).Cols("status").Update(answer) + _, err = ar.data.DB.Context(ctx).Where("id = ?", id).Cols("status").Update(answer) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -81,7 +81,7 @@ func (ar *answerRepo) RemoveAnswer(ctx context.Context, id string) (err error) { func (ar *answerRepo) UpdateAnswer(ctx context.Context, answer *entity.Answer, Colar []string) (err error) { answer.ID = uid.DeShortID(answer.ID) answer.QuestionID = uid.DeShortID(answer.QuestionID) - _, err = ar.data.DB.ID(answer.ID).Cols(Colar...).Update(answer) + _, err = ar.data.DB.Context(ctx).ID(answer.ID).Cols(Colar...).Update(answer) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -92,7 +92,7 @@ func (ar *answerRepo) UpdateAnswerStatus(ctx context.Context, answer *entity.Ans now := time.Now() answer.ID = uid.DeShortID(answer.ID) answer.UpdatedAt = now - _, err = ar.data.DB.Where("id =?", answer.ID).Cols("status", "updated_at").Update(answer) + _, err = ar.data.DB.Context(ctx).Where("id =?", answer.ID).Cols("status", "updated_at").Update(answer) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -105,7 +105,7 @@ func (ar *answerRepo) GetAnswer(ctx context.Context, id string) ( ) { id = uid.DeShortID(id) answer = &entity.Answer{} - exist, err = ar.data.DB.ID(id).Get(answer) + exist, err = ar.data.DB.Context(ctx).ID(id).Get(answer) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -118,7 +118,7 @@ func (ar *answerRepo) GetAnswer(ctx context.Context, id string) ( // GetQuestionCount func (ar *answerRepo) GetAnswerCount(ctx context.Context) (count int64, err error) { list := make([]*entity.Answer, 0) - count, err = ar.data.DB.Where("status = ?", entity.AnswerStatusAvailable).FindAndCount(&list) + count, err = ar.data.DB.Context(ctx).Where("status = ?", entity.AnswerStatusAvailable).FindAndCount(&list) if err != nil { return count, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -130,7 +130,7 @@ func (ar *answerRepo) GetAnswerList(ctx context.Context, answer *entity.Answer) answerList = make([]*entity.Answer, 0) answer.ID = uid.DeShortID(answer.ID) answer.QuestionID = uid.DeShortID(answer.QuestionID) - err = ar.data.DB.Find(answerList, answer) + err = ar.data.DB.Context(ctx).Find(answerList, answer) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -146,7 +146,7 @@ func (ar *answerRepo) GetAnswerPage(ctx context.Context, page, pageSize int, ans answer.ID = uid.DeShortID(answer.ID) answer.QuestionID = uid.DeShortID(answer.QuestionID) answerList = make([]*entity.Answer, 0) - total, err = pager.Help(page, pageSize, answerList, answer, ar.data.DB.NewSession()) + total, err = pager.Help(page, pageSize, answerList, answer, ar.data.DB.Context(ctx)) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -169,13 +169,13 @@ func (ar *answerRepo) UpdateAccepted(ctx context.Context, id string, questionID data.ID = id data.Accepted = schema.AnswerAcceptedFailed - _, err := ar.data.DB.Where("question_id =?", questionID).Cols("adopted").Update(&data) + _, err := ar.data.DB.Context(ctx).Where("question_id =?", questionID).Cols("adopted").Update(&data) if err != nil { return err } if id != "0" { data.Accepted = schema.AnswerAcceptedEnable - _, err = ar.data.DB.Where("id = ?", id).Cols("adopted").Update(&data) + _, err = ar.data.DB.Context(ctx).Where("id = ?", id).Cols("adopted").Update(&data) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -187,7 +187,7 @@ func (ar *answerRepo) UpdateAccepted(ctx context.Context, id string, questionID func (ar *answerRepo) GetByID(ctx context.Context, id string) (*entity.Answer, bool, error) { var resp entity.Answer id = uid.DeShortID(id) - has, err := ar.data.DB.Where("id =? ", id).Get(&resp) + has, err := ar.data.DB.Context(ctx).Where("id =? ", id).Get(&resp) if err != nil { return &resp, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -199,7 +199,7 @@ func (ar *answerRepo) GetByID(ctx context.Context, id string) (*entity.Answer, b func (ar *answerRepo) GetByUserIDQuestionID(ctx context.Context, userID string, questionID string) (*entity.Answer, bool, error) { questionID = uid.DeShortID(questionID) var resp entity.Answer - has, err := ar.data.DB.Where("question_id =? and user_id = ?", questionID, userID).Get(&resp) + has, err := ar.data.DB.Context(ctx).Where("question_id =? and user_id = ?", questionID, userID).Get(&resp) if err != nil { return &resp, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -226,7 +226,7 @@ func (ar *answerRepo) SearchList(ctx context.Context, search *entity.AnswerSearc search.PageSize = constant.DefaultPageSize } offset := search.Page * search.PageSize - session := ar.data.DB.Where("") + session := ar.data.DB.Context(ctx).Where("") if search.QuestionID != "" { session = session.And("question_id = ?", search.QuestionID) @@ -262,7 +262,7 @@ func (ar *answerRepo) AdminSearchList(ctx context.Context, search *entity.AdminA var ( count int64 err error - session = ar.data.DB.Table([]string{entity.Answer{}.TableName(), "a"}).Select("a.*") + session = ar.data.DB.Context(ctx).Table([]string{entity.Answer{}.TableName(), "a"}).Select("a.*") ) if search.QuestionID != "" { search.QuestionID = uid.DeShortID(search.QuestionID) diff --git a/internal/repo/collection/collection_group_repo.go b/internal/repo/collection/collection_group_repo.go index 8f915bb6..df75de0c 100644 --- a/internal/repo/collection/collection_group_repo.go +++ b/internal/repo/collection/collection_group_repo.go @@ -26,7 +26,7 @@ func NewCollectionGroupRepo(data *data.Data) service.CollectionGroupRepo { // AddCollectionGroup add collection group func (cr *collectionGroupRepo) AddCollectionGroup(ctx context.Context, collectionGroup *entity.CollectionGroup) (err error) { - _, err = cr.data.DB.Insert(collectionGroup) + _, err = cr.data.DB.Context(ctx).Insert(collectionGroup) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -40,7 +40,7 @@ func (cr *collectionGroupRepo) AddCollectionDefaultGroup(ctx context.Context, us DefaultGroup: schema.CGDefault, UserID: userID, } - _, err = cr.data.DB.Insert(defaultGroup) + _, err = cr.data.DB.Context(ctx).Insert(defaultGroup) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return @@ -51,7 +51,7 @@ func (cr *collectionGroupRepo) AddCollectionDefaultGroup(ctx context.Context, us // UpdateCollectionGroup update collection group func (cr *collectionGroupRepo) UpdateCollectionGroup(ctx context.Context, collectionGroup *entity.CollectionGroup, cols []string) (err error) { - _, err = cr.data.DB.ID(collectionGroup.ID).Cols(cols...).Update(collectionGroup) + _, err = cr.data.DB.Context(ctx).ID(collectionGroup.ID).Cols(cols...).Update(collectionGroup) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -63,7 +63,7 @@ func (cr *collectionGroupRepo) GetCollectionGroup(ctx context.Context, id string collectionGroup *entity.CollectionGroup, exist bool, err error, ) { collectionGroup = &entity.CollectionGroup{} - exist, err = cr.data.DB.ID(id).Get(collectionGroup) + exist, err = cr.data.DB.Context(ctx).ID(id).Get(collectionGroup) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -74,7 +74,7 @@ func (cr *collectionGroupRepo) GetCollectionGroup(ctx context.Context, id string func (cr *collectionGroupRepo) GetCollectionGroupPage(ctx context.Context, page, pageSize int, collectionGroup *entity.CollectionGroup) (collectionGroupList []*entity.CollectionGroup, total int64, err error) { collectionGroupList = make([]*entity.CollectionGroup, 0) - session := cr.data.DB.NewSession() + session := cr.data.DB.Context(ctx) if collectionGroup.UserID != "" && collectionGroup.UserID != "0" { session = session.Where("user_id = ?", collectionGroup.UserID) } @@ -87,7 +87,7 @@ func (cr *collectionGroupRepo) GetCollectionGroupPage(ctx context.Context, page, func (cr *collectionGroupRepo) GetDefaultID(ctx context.Context, userID string) (collectionGroup *entity.CollectionGroup, has bool, err error) { collectionGroup = &entity.CollectionGroup{} - has, err = cr.data.DB.Where("user_id =? and default_group = ?", userID, schema.CGDefault).Get(collectionGroup) + has, err = cr.data.DB.Context(ctx).Where("user_id =? and default_group = ?", userID, schema.CGDefault).Get(collectionGroup) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return diff --git a/internal/repo/collection/collection_repo.go b/internal/repo/collection/collection_repo.go index 7fcf9335..6c05a38d 100644 --- a/internal/repo/collection/collection_repo.go +++ b/internal/repo/collection/collection_repo.go @@ -32,6 +32,7 @@ func NewCollectionRepo(data *data.Data, uniqueIDRepo unique.UniqueIDRepo) collec func (cr *collectionRepo) AddCollection(ctx context.Context, collection *entity.Collection) (err error) { needAdd := false _, err = cr.data.DB.Transaction(func(session *xorm.Session) (result any, err error) { + session = session.Context(ctx) var has bool dbcollection := &entity.Collection{} result = nil @@ -52,7 +53,7 @@ func (cr *collectionRepo) AddCollection(ctx context.Context, collection *entity. id, err := cr.uniqueIDRepo.GenUniqueIDStr(ctx, collection.TableName()) if err == nil { collection.ID = id - _, err = cr.data.DB.Insert(collection) + _, err = cr.data.DB.Context(ctx).Insert(collection) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -64,7 +65,7 @@ func (cr *collectionRepo) AddCollection(ctx context.Context, collection *entity. // RemoveCollection delete collection func (cr *collectionRepo) RemoveCollection(ctx context.Context, id string) (err error) { - _, err = cr.data.DB.Where("id =?", id).Delete(&entity.Collection{}) + _, err = cr.data.DB.Context(ctx).Where("id =?", id).Delete(&entity.Collection{}) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -73,14 +74,14 @@ func (cr *collectionRepo) RemoveCollection(ctx context.Context, id string) (err // UpdateCollection update collection func (cr *collectionRepo) UpdateCollection(ctx context.Context, collection *entity.Collection, cols []string) (err error) { - _, err = cr.data.DB.ID(collection.ID).Cols(cols...).Update(collection) + _, err = cr.data.DB.Context(ctx).ID(collection.ID).Cols(cols...).Update(collection) return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } // GetCollection get collection one func (cr *collectionRepo) GetCollection(ctx context.Context, id int) (collection *entity.Collection, exist bool, err error) { collection = &entity.Collection{} - exist, err = cr.data.DB.ID(id).Get(collection) + exist, err = cr.data.DB.Context(ctx).ID(id).Get(collection) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -90,7 +91,7 @@ func (cr *collectionRepo) GetCollection(ctx context.Context, id int) (collection // GetCollectionList get collection list all func (cr *collectionRepo) GetCollectionList(ctx context.Context, collection *entity.Collection) (collectionList []*entity.Collection, err error) { collectionList = make([]*entity.Collection, 0) - err = cr.data.DB.Find(collectionList, collection) + err = cr.data.DB.Context(ctx).Find(collectionList, collection) err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return } @@ -98,7 +99,7 @@ func (cr *collectionRepo) GetCollectionList(ctx context.Context, collection *ent // GetOneByObjectIDAndUser get one by object TagID and user func (cr *collectionRepo) GetOneByObjectIDAndUser(ctx context.Context, userID string, objectID string) (collection *entity.Collection, exist bool, err error) { collection = &entity.Collection{} - exist, err = cr.data.DB.Where("user_id = ? and object_id = ?", userID, objectID).Get(collection) + exist, err = cr.data.DB.Context(ctx).Where("user_id = ? and object_id = ?", userID, objectID).Get(collection) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -108,7 +109,7 @@ func (cr *collectionRepo) GetOneByObjectIDAndUser(ctx context.Context, userID st // SearchByObjectIDsAndUser search by object IDs and user func (cr *collectionRepo) SearchByObjectIDsAndUser(ctx context.Context, userID string, objectIDs []string) ([]*entity.Collection, error) { collectionList := make([]*entity.Collection, 0) - err := cr.data.DB.Where("user_id = ?", userID).In("object_id", objectIDs).Find(&collectionList) + err := cr.data.DB.Context(ctx).Where("user_id = ?", userID).In("object_id", objectIDs).Find(&collectionList) if err != nil { return collectionList, err } @@ -118,7 +119,7 @@ func (cr *collectionRepo) SearchByObjectIDsAndUser(ctx context.Context, userID s // CountByObjectID count by object TagID func (cr *collectionRepo) CountByObjectID(ctx context.Context, objectID string) (total int64, err error) { collection := &entity.Collection{} - total, err = cr.data.DB.Where("object_id = ?", objectID).Count(collection) + total, err = cr.data.DB.Context(ctx).Where("object_id = ?", objectID).Count(collection) if err != nil { return 0, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -129,7 +130,7 @@ func (cr *collectionRepo) CountByObjectID(ctx context.Context, objectID string) func (cr *collectionRepo) GetCollectionPage(ctx context.Context, page, pageSize int, collection *entity.Collection) (collectionList []*entity.Collection, total int64, err error) { collectionList = make([]*entity.Collection, 0) - session := cr.data.DB.NewSession() + session := cr.data.DB.Context(ctx) if collection.UserID != "" && collection.UserID != "0" { session = session.Where("user_id = ?", collection.UserID) } @@ -172,7 +173,7 @@ func (cr *collectionRepo) SearchList(ctx context.Context, search *entity.Collect search.PageSize = constant.DefaultPageSize } offset := search.Page * search.PageSize - session := cr.data.DB.Where("") + session := cr.data.DB.Context(ctx).Where("") if len(search.UserID) > 0 { session = session.And("user_id = ?", search.UserID) } else { diff --git a/internal/repo/comment/comment_repo.go b/internal/repo/comment/comment_repo.go index 992e6a2f..d3b3d8ac 100644 --- a/internal/repo/comment/comment_repo.go +++ b/internal/repo/comment/comment_repo.go @@ -41,7 +41,7 @@ func (cr *commentRepo) AddComment(ctx context.Context, comment *entity.Comment) if err != nil { return err } - _, err = cr.data.DB.Insert(comment) + _, err = cr.data.DB.Context(ctx).Insert(comment) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -50,7 +50,7 @@ func (cr *commentRepo) AddComment(ctx context.Context, comment *entity.Comment) // RemoveComment delete comment func (cr *commentRepo) RemoveComment(ctx context.Context, commentID string) (err error) { - session := cr.data.DB.ID(commentID) + session := cr.data.DB.Context(ctx).ID(commentID) _, err = session.Update(&entity.Comment{Status: entity.CommentStatusDeleted}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -60,7 +60,7 @@ func (cr *commentRepo) RemoveComment(ctx context.Context, commentID string) (err // UpdateComment update comment func (cr *commentRepo) UpdateComment(ctx context.Context, comment *entity.Comment) (err error) { - _, err = cr.data.DB.ID(comment.ID).Where("user_id = ?", comment.UserID).Update(comment) + _, err = cr.data.DB.Context(ctx).ID(comment.ID).Where("user_id = ?", comment.UserID).Update(comment) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -72,7 +72,7 @@ func (cr *commentRepo) GetComment(ctx context.Context, commentID string) ( comment *entity.Comment, exist bool, err error, ) { comment = &entity.Comment{} - exist, err = cr.data.DB.ID(commentID).Get(comment) + exist, err = cr.data.DB.Context(ctx).ID(commentID).Get(comment) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -81,7 +81,7 @@ func (cr *commentRepo) GetComment(ctx context.Context, commentID string) ( func (cr *commentRepo) GetCommentCount(ctx context.Context) (count int64, err error) { list := make([]*entity.Comment, 0) - count, err = cr.data.DB.Where("status = ?", entity.CommentStatusAvailable).FindAndCount(&list) + count, err = cr.data.DB.Context(ctx).Where("status = ?", entity.CommentStatusAvailable).FindAndCount(&list) if err != nil { return count, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -94,7 +94,7 @@ func (cr *commentRepo) GetCommentPage(ctx context.Context, commentQuery *comment ) { commentList = make([]*entity.Comment, 0) - session := cr.data.DB.NewSession() + session := cr.data.DB.Context(ctx) session.OrderBy(commentQuery.GetOrderBy()) session.Where("status = ?", entity.CommentStatusAvailable) diff --git a/internal/repo/common/common.go b/internal/repo/common/common.go deleted file mode 100644 index 59d1f304..00000000 --- a/internal/repo/common/common.go +++ /dev/null @@ -1,93 +0,0 @@ -package common - -import ( - "github.com/answerdev/answer/internal/base/data" - "github.com/answerdev/answer/internal/base/reason" - "github.com/answerdev/answer/internal/entity" - "github.com/answerdev/answer/internal/service/unique" - "github.com/answerdev/answer/pkg/obj" - "github.com/segmentfault/pacman/errors" - "github.com/segmentfault/pacman/log" -) - -type CommonRepo struct { - data *data.Data - uniqueIDRepo unique.UniqueIDRepo -} - -func NewCommonRepo(data *data.Data, uniqueIDRepo unique.UniqueIDRepo) *CommonRepo { - return &CommonRepo{ - data: data, - uniqueIDRepo: uniqueIDRepo, - } -} - -// GetRootObjectID get root object ID -func (cr *CommonRepo) GetRootObjectID(objectID string) (rootObjectID string, err error) { - var ( - exist bool - objectType string - answer = entity.Answer{} - comment = entity.Comment{} - ) - - objectType, err = obj.GetObjectTypeStrByObjectID(objectID) - switch objectType { - case "answer": - exist, err = cr.data.DB.ID(objectID).Get(&answer) - if !exist { - err = errors.BadRequest(reason.ObjectNotFound) - } - case "comment": - exist, _ = cr.data.DB.ID(objectID).Get(&comment) - if !exist { - err = errors.BadRequest(reason.ObjectNotFound) - } else { - _, err = cr.GetRootObjectID(comment.ObjectID) - } - default: - rootObjectID = objectID - } - return -} - -// GetObjectIDMap get object ID map from object id -func (cr *CommonRepo) GetObjectIDMap(objectID string) (objectIDMap map[string]string, err error) { - var ( - exist bool - ID, - objectType string - answer = entity.Answer{} - comment = entity.Comment{} - ) - - objectIDMap = map[string]string{} - // 10070000000000450 - objectType, err = obj.GetObjectTypeStrByObjectID(objectID) - if err != nil { - log.Error("get report object type:", objectID, ",err:", err) - return - } - switch objectType { - case "answer": - exist, _ = cr.data.DB.ID(objectID).Get(&answer) - if !exist { - err = errors.BadRequest(reason.ObjectNotFound) - } else { - objectIDMap, err = cr.GetObjectIDMap(answer.QuestionID) - ID = answer.ID - } - case "comment": - exist, _ = cr.data.DB.ID(objectID).Get(&comment) - if !exist { - err = errors.BadRequest(reason.ObjectNotFound) - } else { - objectIDMap, err = cr.GetObjectIDMap(comment.ObjectID) - ID = comment.ID - } - case "question": - ID = objectID - } - objectIDMap[objectType] = ID - return -} diff --git a/internal/repo/config/config_repo.go b/internal/repo/config/config_repo.go index 329e2b65..77d63773 100644 --- a/internal/repo/config/config_repo.go +++ b/internal/repo/config/config_repo.go @@ -1,6 +1,7 @@ package config import ( + "context" "encoding/json" "fmt" "sync" @@ -40,7 +41,7 @@ func (cr *configRepo) init() { cr.mu.Lock() defer cr.mu.Unlock() rows := &[]entity.Config{} - err := cr.data.DB.Find(rows) + err := cr.data.DB.Context(context.TODO()).Find(rows) if err == nil { for _, row := range *rows { Key2ValueMapping[row.Key] = row.Value @@ -128,9 +129,9 @@ func (cr *configRepo) GetJsonConfigByIDAndSetToObject(id int, object any) (err e } // SetConfig set config -func (cr *configRepo) SetConfig(key, value string) (err error) { +func (cr *configRepo) SetConfig(ctx context.Context, key, value string) (err error) { id := Key2IDMapping[key] - _, err = cr.data.DB.ID(id).Update(&entity.Config{Value: value}) + _, err = cr.data.DB.Context(ctx).ID(id).Update(&entity.Config{Value: value}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } else { diff --git a/internal/repo/meta/meta_repo.go b/internal/repo/meta/meta_repo.go index 415ad428..fefa6748 100644 --- a/internal/repo/meta/meta_repo.go +++ b/internal/repo/meta/meta_repo.go @@ -25,7 +25,7 @@ func NewMetaRepo(data *data.Data) meta.MetaRepo { // AddMeta add meta func (mr *metaRepo) AddMeta(ctx context.Context, meta *entity.Meta) (err error) { - _, err = mr.data.DB.Insert(meta) + _, err = mr.data.DB.Context(ctx).Insert(meta) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -34,7 +34,7 @@ func (mr *metaRepo) AddMeta(ctx context.Context, meta *entity.Meta) (err error) // RemoveMeta delete meta func (mr *metaRepo) RemoveMeta(ctx context.Context, id int) (err error) { - _, err = mr.data.DB.ID(id).Delete(&entity.Meta{}) + _, err = mr.data.DB.Context(ctx).ID(id).Delete(&entity.Meta{}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -43,7 +43,7 @@ func (mr *metaRepo) RemoveMeta(ctx context.Context, id int) (err error) { // UpdateMeta update meta func (mr *metaRepo) UpdateMeta(ctx context.Context, meta *entity.Meta) (err error) { - _, err = mr.data.DB.ID(meta.ID).Update(meta) + _, err = mr.data.DB.Context(ctx).ID(meta.ID).Update(meta) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -54,7 +54,7 @@ func (mr *metaRepo) UpdateMeta(ctx context.Context, meta *entity.Meta) (err erro func (mr *metaRepo) GetMetaByObjectIdAndKey(ctx context.Context, objectID, key string) ( meta *entity.Meta, exist bool, err error) { meta = &entity.Meta{} - exist, err = mr.data.DB.Where(builder.Eq{"object_id": objectID}.And(builder.Eq{"`key`": key})).Desc("created_at").Get(meta) + exist, err = mr.data.DB.Context(ctx).Where(builder.Eq{"object_id": objectID}.And(builder.Eq{"`key`": key})).Desc("created_at").Get(meta) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -64,7 +64,7 @@ func (mr *metaRepo) GetMetaByObjectIdAndKey(ctx context.Context, objectID, key s // GetMetaList get meta list all func (mr *metaRepo) GetMetaList(ctx context.Context, meta *entity.Meta) (metaList []*entity.Meta, err error) { metaList = make([]*entity.Meta, 0) - err = mr.data.DB.Find(&metaList, meta) + err = mr.data.DB.Context(ctx).Find(&metaList, meta) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/notification/notification_repo.go b/internal/repo/notification/notification_repo.go index 994113fc..71d5a0a6 100644 --- a/internal/repo/notification/notification_repo.go +++ b/internal/repo/notification/notification_repo.go @@ -29,7 +29,7 @@ func NewNotificationRepo(data *data.Data) notficationcommon.NotificationRepo { // AddNotification add notification func (nr *notificationRepo) AddNotification(ctx context.Context, notification *entity.Notification) (err error) { notification.ObjectID = uid.DeShortID(notification.ObjectID) - _, err = nr.data.DB.Insert(notification) + _, err = nr.data.DB.Context(ctx).Insert(notification) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -40,7 +40,7 @@ func (nr *notificationRepo) UpdateNotificationContent(ctx context.Context, notif now := time.Now() notification.UpdatedAt = now notification.ObjectID = uid.DeShortID(notification.ObjectID) - _, err = nr.data.DB.Where("id =?", notification.ID).Cols("content", "updated_at").Update(notification) + _, err = nr.data.DB.Context(ctx).Where("id =?", notification.ID).Cols("content", "updated_at").Update(notification) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -50,7 +50,7 @@ func (nr *notificationRepo) UpdateNotificationContent(ctx context.Context, notif func (nr *notificationRepo) ClearUnRead(ctx context.Context, userID string, notificationType int) (err error) { info := &entity.Notification{} info.IsRead = schema.NotificationRead - _, err = nr.data.DB.Where("user_id =?", userID).And("type =?", notificationType).Cols("is_read").Update(info) + _, err = nr.data.DB.Context(ctx).Where("user_id =?", userID).And("type =?", notificationType).Cols("is_read").Update(info) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -60,7 +60,7 @@ func (nr *notificationRepo) ClearUnRead(ctx context.Context, userID string, noti func (nr *notificationRepo) ClearIDUnRead(ctx context.Context, userID string, id string) (err error) { info := &entity.Notification{} info.IsRead = schema.NotificationRead - _, err = nr.data.DB.Where("user_id =?", userID).And("id =?", id).Cols("is_read").Update(info) + _, err = nr.data.DB.Context(ctx).Where("user_id =?", userID).And("id =?", id).Cols("is_read").Update(info) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -69,7 +69,7 @@ func (nr *notificationRepo) ClearIDUnRead(ctx context.Context, userID string, id func (nr *notificationRepo) GetById(ctx context.Context, id string) (*entity.Notification, bool, error) { info := &entity.Notification{} - exist, err := nr.data.DB.Where("id = ? ", id).Get(info) + exist, err := nr.data.DB.Context(ctx).Where("id = ? ", id).Get(info) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return info, false, err @@ -79,7 +79,7 @@ func (nr *notificationRepo) GetById(ctx context.Context, id string) (*entity.Not func (nr *notificationRepo) GetByUserIdObjectIdTypeId(ctx context.Context, userID, objectID string, notificationType int) (*entity.Notification, bool, error) { info := &entity.Notification{} - exist, err := nr.data.DB.Where("user_id = ? ", userID).And("object_id = ?", objectID).And("type = ?", notificationType).Get(info) + exist, err := nr.data.DB.Context(ctx).Where("user_id = ? ", userID).And("object_id = ?", objectID).And("type = ?", notificationType).Get(info) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return info, false, err @@ -94,7 +94,7 @@ func (nr *notificationRepo) GetNotificationPage(ctx context.Context, searchCond return notificationList, 0, nil } - session := nr.data.DB.NewSession() + session := nr.data.DB.Context(ctx) session = session.Desc("updated_at") cond := &entity.Notification{ UserID: searchCond.UserID, diff --git a/internal/repo/plugin_config/plugin_config_repo.go b/internal/repo/plugin_config/plugin_config_repo.go index d615577f..aad97657 100644 --- a/internal/repo/plugin_config/plugin_config_repo.go +++ b/internal/repo/plugin_config/plugin_config_repo.go @@ -23,15 +23,15 @@ func NewPluginConfigRepo(data *data.Data) plugin_common.PluginConfigRepo { func (ur *pluginConfigRepo) SavePluginConfig(ctx context.Context, pluginSlugName, configValue string) (err error) { old := &entity.PluginConfig{PluginSlugName: pluginSlugName} - exist, err := ur.data.DB.Get(old) + exist, err := ur.data.DB.Context(ctx).Get(old) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } if exist { old.Value = configValue - _, err = ur.data.DB.ID(old.ID).Update(old) + _, err = ur.data.DB.Context(ctx).ID(old.ID).Update(old) } else { - _, err = ur.data.DB.InsertOne(&entity.PluginConfig{PluginSlugName: pluginSlugName, Value: configValue}) + _, err = ur.data.DB.Context(ctx).InsertOne(&entity.PluginConfig{PluginSlugName: pluginSlugName, Value: configValue}) } if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -41,7 +41,7 @@ func (ur *pluginConfigRepo) SavePluginConfig(ctx context.Context, pluginSlugName func (ur *pluginConfigRepo) GetPluginConfigAll(ctx context.Context) (pluginConfigs []*entity.PluginConfig, err error) { pluginConfigs = make([]*entity.PluginConfig, 0) - err = ur.data.DB.Find(&pluginConfigs) + err = ur.data.DB.Context(ctx).Find(&pluginConfigs) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/provider.go b/internal/repo/provider.go index b0dc6f8a..10722cec 100644 --- a/internal/repo/provider.go +++ b/internal/repo/provider.go @@ -9,7 +9,6 @@ import ( "github.com/answerdev/answer/internal/repo/captcha" "github.com/answerdev/answer/internal/repo/collection" "github.com/answerdev/answer/internal/repo/comment" - "github.com/answerdev/answer/internal/repo/common" "github.com/answerdev/answer/internal/repo/config" "github.com/answerdev/answer/internal/repo/export" "github.com/answerdev/answer/internal/repo/meta" @@ -33,7 +32,6 @@ import ( // ProviderSetRepo is data providers. var ProviderSetRepo = wire.NewSet( - common.NewCommonRepo, data.NewData, data.NewDB, data.NewCache, diff --git a/internal/repo/question/question_repo.go b/internal/repo/question/question_repo.go index 2ee02fc4..54e58c80 100644 --- a/internal/repo/question/question_repo.go +++ b/internal/repo/question/question_repo.go @@ -46,7 +46,7 @@ func (qr *questionRepo) AddQuestion(ctx context.Context, question *entity.Questi if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } - _, err = qr.data.DB.Insert(question) + _, err = qr.data.DB.Context(ctx).Insert(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -57,7 +57,7 @@ func (qr *questionRepo) AddQuestion(ctx context.Context, question *entity.Questi // RemoveQuestion delete question func (qr *questionRepo) RemoveQuestion(ctx context.Context, id string) (err error) { id = uid.DeShortID(id) - _, err = qr.data.DB.Where("id =?", id).Delete(&entity.Question{}) + _, err = qr.data.DB.Context(ctx).Where("id =?", id).Delete(&entity.Question{}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -67,7 +67,7 @@ func (qr *questionRepo) RemoveQuestion(ctx context.Context, id string) (err erro // UpdateQuestion update question func (qr *questionRepo) UpdateQuestion(ctx context.Context, question *entity.Question, Cols []string) (err error) { question.ID = uid.DeShortID(question.ID) - _, err = qr.data.DB.Where("id =?", question.ID).Cols(Cols...).Update(question) + _, err = qr.data.DB.Context(ctx).Where("id =?", question.ID).Cols(Cols...).Update(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -78,7 +78,7 @@ func (qr *questionRepo) UpdateQuestion(ctx context.Context, question *entity.Que func (qr *questionRepo) UpdatePvCount(ctx context.Context, questionID string) (err error) { questionID = uid.DeShortID(questionID) question := &entity.Question{} - _, err = qr.data.DB.Where("id =?", questionID).Incr("view_count", 1).Update(question) + _, err = qr.data.DB.Context(ctx).Where("id =?", questionID).Incr("view_count", 1).Update(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -88,7 +88,7 @@ func (qr *questionRepo) UpdatePvCount(ctx context.Context, questionID string) (e func (qr *questionRepo) UpdateAnswerCount(ctx context.Context, questionID string, num int) (err error) { questionID = uid.DeShortID(questionID) question := &entity.Question{} - _, err = qr.data.DB.Where("id =?", questionID).Incr("answer_count", num).Update(question) + _, err = qr.data.DB.Context(ctx).Where("id =?", questionID).Incr("answer_count", num).Update(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -98,7 +98,7 @@ func (qr *questionRepo) UpdateAnswerCount(ctx context.Context, questionID string func (qr *questionRepo) UpdateCollectionCount(ctx context.Context, questionID string, num int) (err error) { questionID = uid.DeShortID(questionID) question := &entity.Question{} - _, err = qr.data.DB.Where("id =?", questionID).Incr("collection_count", num).Update(question) + _, err = qr.data.DB.Context(ctx).Where("id =?", questionID).Incr("collection_count", num).Update(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -109,7 +109,7 @@ func (qr *questionRepo) UpdateQuestionStatus(ctx context.Context, question *enti question.ID = uid.DeShortID(question.ID) now := time.Now() question.UpdatedAt = now - _, err = qr.data.DB.Where("id =?", question.ID).Cols("status", "updated_at").Update(question) + _, err = qr.data.DB.Context(ctx).Where("id =?", question.ID).Cols("status", "updated_at").Update(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -118,7 +118,7 @@ func (qr *questionRepo) UpdateQuestionStatus(ctx context.Context, question *enti func (qr *questionRepo) UpdateQuestionStatusWithOutUpdateTime(ctx context.Context, question *entity.Question) (err error) { question.ID = uid.DeShortID(question.ID) - _, err = qr.data.DB.Where("id =?", question.ID).Cols("status").Update(question) + _, err = qr.data.DB.Context(ctx).Where("id =?", question.ID).Cols("status").Update(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -127,7 +127,7 @@ func (qr *questionRepo) UpdateQuestionStatusWithOutUpdateTime(ctx context.Contex func (qr *questionRepo) UpdateQuestionOperation(ctx context.Context, question *entity.Question) (err error) { question.ID = uid.DeShortID(question.ID) - _, err = qr.data.DB.Where("id =?", question.ID).Cols("pin", "show").Update(question) + _, err = qr.data.DB.Context(ctx).Where("id =?", question.ID).Cols("pin", "show").Update(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -136,7 +136,7 @@ func (qr *questionRepo) UpdateQuestionOperation(ctx context.Context, question *e func (qr *questionRepo) UpdateAccepted(ctx context.Context, question *entity.Question) (err error) { question.ID = uid.DeShortID(question.ID) - _, err = qr.data.DB.Where("id =?", question.ID).Cols("accepted_answer_id").Update(question) + _, err = qr.data.DB.Context(ctx).Where("id =?", question.ID).Cols("accepted_answer_id").Update(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -145,7 +145,7 @@ func (qr *questionRepo) UpdateAccepted(ctx context.Context, question *entity.Que func (qr *questionRepo) UpdateLastAnswer(ctx context.Context, question *entity.Question) (err error) { question.ID = uid.DeShortID(question.ID) - _, err = qr.data.DB.Where("id =?", question.ID).Cols("last_answer_id").Update(question) + _, err = qr.data.DB.Context(ctx).Where("id =?", question.ID).Cols("last_answer_id").Update(question) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -159,7 +159,7 @@ func (qr *questionRepo) GetQuestion(ctx context.Context, id string) ( id = uid.DeShortID(id) question = &entity.Question{} question.ID = id - exist, err = qr.data.DB.Where("id = ?", id).Get(question) + exist, err = qr.data.DB.Context(ctx).Where("id = ?", id).Get(question) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -170,7 +170,7 @@ func (qr *questionRepo) GetQuestion(ctx context.Context, id string) ( // GetTagBySlugName get tag by slug name func (qr *questionRepo) SearchByTitleLike(ctx context.Context, title string) (questionList []*entity.Question, err error) { questionList = make([]*entity.Question, 0) - err = qr.data.DB.Table("question").Where("title like ?", "%"+title+"%").Limit(10, 0).Find(&questionList) + err = qr.data.DB.Context(ctx).Table("question").Where("title like ?", "%"+title+"%").Limit(10, 0).Find(&questionList) if err != nil { return nil, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -185,7 +185,7 @@ func (qr *questionRepo) FindByID(ctx context.Context, id []string) (questionList id[key] = uid.DeShortID(itemID) } questionList = make([]*entity.Question, 0) - err = qr.data.DB.Table("question").In("id", id).Find(&questionList) + err = qr.data.DB.Context(ctx).Table("question").In("id", id).Find(&questionList) if err != nil { return nil, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -199,7 +199,7 @@ func (qr *questionRepo) FindByID(ctx context.Context, id []string) (questionList func (qr *questionRepo) GetQuestionList(ctx context.Context, question *entity.Question) (questionList []*entity.Question, err error) { question.ID = uid.DeShortID(question.ID) questionList = make([]*entity.Question, 0) - err = qr.data.DB.Find(questionList, question) + err = qr.data.DB.Context(ctx).Find(questionList, question) if err != nil { return questionList, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -212,7 +212,7 @@ func (qr *questionRepo) GetQuestionList(ctx context.Context, question *entity.Qu func (qr *questionRepo) GetQuestionCount(ctx context.Context) (count int64, err error) { questionList := make([]*entity.Question, 0) - count, err = qr.data.DB.In("question.status", []int{entity.QuestionStatusAvailable, entity.QuestionStatusClosed}).FindAndCount(&questionList) + count, err = qr.data.DB.Context(ctx).In("question.status", []int{entity.QuestionStatusAvailable, entity.QuestionStatusClosed}).FindAndCount(&questionList) if err != nil { return count, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -231,7 +231,7 @@ func (qr *questionRepo) GetQuestionIDsPage(ctx context.Context, page, pageSize i pageSize = constant.DefaultPageSize } offset := page * pageSize - session := qr.data.DB.Table("question") + session := qr.data.DB.Context(ctx).Table("question") session = session.In("question.status", []int{entity.QuestionStatusAvailable, entity.QuestionStatusClosed}) session.And("question.show = ?", entity.QuestionShow) session = session.Limit(pageSize, offset) @@ -259,7 +259,7 @@ func (qr *questionRepo) GetQuestionPage(ctx context.Context, page, pageSize int, questionList []*entity.Question, total int64, err error) { questionList = make([]*entity.Question, 0) - session := qr.data.DB.Where("question.status = ? OR question.status = ?", + session := qr.data.DB.Context(ctx).Where("question.status = ? OR question.status = ?", entity.QuestionStatusAvailable, entity.QuestionStatusClosed) if len(tagID) > 0 { session.Join("LEFT", "tag_rel", "question.id = tag_rel.object_id") @@ -303,7 +303,7 @@ func (qr *questionRepo) AdminSearchList(ctx context.Context, search *schema.Admi var ( count int64 err error - session = qr.data.DB.Table("question") + session = qr.data.DB.Context(ctx).Table("question") ) session.Where(builder.Eq{ diff --git a/internal/repo/rank/user_rank_repo.go b/internal/repo/rank/user_rank_repo.go index 462f999e..f25bf819 100644 --- a/internal/repo/rank/user_rank_repo.go +++ b/internal/repo/rank/user_rank_repo.go @@ -138,7 +138,7 @@ func (ur *UserRankRepo) UserRankPage(ctx context.Context, userID string, page, p ) { rankPage = make([]*entity.Activity, 0) - session := ur.data.DB.Where(builder.Eq{"has_rank": 1}.And(builder.Eq{"cancelled": 0})) + session := ur.data.DB.Context(ctx).Where(builder.Eq{"has_rank": 1}.And(builder.Eq{"cancelled": 0})) session.Desc("created_at") cond := &entity.Activity{UserID: userID} diff --git a/internal/repo/report/report_repo.go b/internal/repo/report/report_repo.go index 3fa627f0..21f33cdf 100644 --- a/internal/repo/report/report_repo.go +++ b/internal/repo/report/report_repo.go @@ -35,7 +35,7 @@ func (rr *reportRepo) AddReport(ctx context.Context, report *entity.Report) (err if err != nil { return err } - _, err = rr.data.DB.Insert(report) + _, err = rr.data.DB.Context(ctx).Insert(report) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -48,7 +48,7 @@ func (rr *reportRepo) GetReportListPage(ctx context.Context, dto schema.GetRepor ok bool status int objectType int - session = rr.data.DB.NewSession() + session = rr.data.DB.Context(ctx) cond = entity.Report{} ) @@ -78,7 +78,7 @@ func (rr *reportRepo) GetReportListPage(ctx context.Context, dto schema.GetRepor // GetByID get report by ID func (rr *reportRepo) GetByID(ctx context.Context, id string) (report *entity.Report, exist bool, err error) { report = &entity.Report{} - exist, err = rr.data.DB.ID(id).Get(report) + exist, err = rr.data.DB.Context(ctx).ID(id).Get(report) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -87,7 +87,7 @@ func (rr *reportRepo) GetByID(ctx context.Context, id string) (report *entity.Re // UpdateByID handle report by ID func (rr *reportRepo) UpdateByID(ctx context.Context, id string, handleData entity.Report) (err error) { - _, err = rr.data.DB.ID(id).Update(&handleData) + _, err = rr.data.DB.Context(ctx).ID(id).Update(&handleData) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -96,7 +96,7 @@ func (rr *reportRepo) UpdateByID(ctx context.Context, id string, handleData enti func (rr *reportRepo) GetReportCount(ctx context.Context) (count int64, err error) { list := make([]*entity.Report, 0) - count, err = rr.data.DB.Where("status =?", entity.ReportStatusPending).FindAndCount(&list) + count, err = rr.data.DB.Context(ctx).Where("status =?", entity.ReportStatusPending).FindAndCount(&list) if err != nil { return count, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/revision/revision_repo.go b/internal/repo/revision/revision_repo.go index 813be967..314276af 100644 --- a/internal/repo/revision/revision_repo.go +++ b/internal/repo/revision/revision_repo.go @@ -46,6 +46,7 @@ func (rr *revisionRepo) AddRevision(ctx context.Context, revision *entity.Revisi return nil } _, err = rr.data.DB.Transaction(func(session *xorm.Session) (interface{}, error) { + session = session.Context(ctx) _, err = session.Insert(revision) if err != nil { _ = session.Rollback() @@ -90,7 +91,7 @@ func (rr *revisionRepo) UpdateStatus(ctx context.Context, id string, status int, data.ID = id data.Status = status data.ReviewUserID = converter.StringToInt64(reviewUserID) - _, err = rr.data.DB.Where("id =?", id).Cols("status", "review_user_id").Update(&data) + _, err = rr.data.DB.Context(ctx).Where("id =?", id).Cols("status", "review_user_id").Update(&data) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -102,7 +103,7 @@ func (rr *revisionRepo) GetRevision(ctx context.Context, id string) ( revision *entity.Revision, exist bool, err error, ) { revision = &entity.Revision{} - exist, err = rr.data.DB.ID(id).Get(revision) + exist, err = rr.data.DB.Context(ctx).ID(id).Get(revision) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -113,7 +114,7 @@ func (rr *revisionRepo) GetRevision(ctx context.Context, id string) ( func (rr *revisionRepo) GetRevisionByID(ctx context.Context, revisionID string) ( revision *entity.Revision, exist bool, err error) { revision = &entity.Revision{} - exist, err = rr.data.DB.Where("id = ?", revisionID).Get(revision) + exist, err = rr.data.DB.Context(ctx).Where("id = ?", revisionID).Get(revision) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -123,7 +124,7 @@ func (rr *revisionRepo) GetRevisionByID(ctx context.Context, revisionID string) func (rr *revisionRepo) ExistUnreviewedByObjectID(ctx context.Context, objectID string) ( revision *entity.Revision, exist bool, err error) { revision = &entity.Revision{} - exist, err = rr.data.DB.Where("object_id = ?", objectID).And("status = ?", entity.RevisionUnreviewedStatus).Get(revision) + exist, err = rr.data.DB.Context(ctx).Where("object_id = ?", objectID).And("status = ?", entity.RevisionUnreviewedStatus).Get(revision) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -135,7 +136,7 @@ func (rr *revisionRepo) GetLastRevisionByObjectID(ctx context.Context, objectID revision *entity.Revision, exist bool, err error, ) { revision = &entity.Revision{} - exist, err = rr.data.DB.Where("object_id = ?", objectID).OrderBy("created_at DESC").Get(revision) + exist, err = rr.data.DB.Context(ctx).Where("object_id = ?", objectID).OrderBy("created_at DESC").Get(revision) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -145,7 +146,7 @@ func (rr *revisionRepo) GetLastRevisionByObjectID(ctx context.Context, objectID // GetRevisionList get revision list all func (rr *revisionRepo) GetRevisionList(ctx context.Context, revision *entity.Revision) (revisionList []entity.Revision, err error) { revisionList = []entity.Revision{} - err = rr.data.DB.Where(builder.Eq{ + err = rr.data.DB.Context(ctx).Where(builder.Eq{ "object_id": revision.ObjectID, }).OrderBy("created_at DESC").Find(&revisionList) if err != nil { @@ -175,7 +176,7 @@ func (rr *revisionRepo) GetUnreviewedRevisionPage(ctx context.Context, page int, if len(objectTypeList) == 0 { return revisionList, 0, nil } - session := rr.data.DB.NewSession() + session := rr.data.DB.Context(ctx) session = session.And("status = ?", entity.RevisionUnreviewedStatus) session = session.In("object_type", objectTypeList) session = session.OrderBy("created_at asc") diff --git a/internal/repo/role/power_repo.go b/internal/repo/role/power_repo.go index 1c902e6c..c2a5a402 100644 --- a/internal/repo/role/power_repo.go +++ b/internal/repo/role/power_repo.go @@ -25,7 +25,7 @@ func NewPowerRepo(data *data.Data) role.PowerRepo { // GetPowerList get list all func (pr *powerRepo) GetPowerList(ctx context.Context, power *entity.Power) (powerList []*entity.Power, err error) { powerList = make([]*entity.Power, 0) - err = pr.data.DB.Find(powerList, power) + err = pr.data.DB.Context(ctx).Find(powerList, power) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/role/role_power_rel_repo.go b/internal/repo/role/role_power_rel_repo.go index dc553eaa..73761c98 100644 --- a/internal/repo/role/role_power_rel_repo.go +++ b/internal/repo/role/role_power_rel_repo.go @@ -25,7 +25,7 @@ func NewRolePowerRelRepo(data *data.Data) role.RolePowerRelRepo { // GetRolePowerTypeList get role power type list func (rr *rolePowerRelRepo) GetRolePowerTypeList(ctx context.Context, roleID int) (powers []string, err error) { powers = make([]string, 0) - err = rr.data.DB.Table("role_power_rel"). + err = rr.data.DB.Context(ctx).Table("role_power_rel"). Cols("power_type").Where(builder.Eq{"role_id": roleID}).Find(&powers) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() diff --git a/internal/repo/role/role_repo.go b/internal/repo/role/role_repo.go index 4534795e..6ee42df1 100644 --- a/internal/repo/role/role_repo.go +++ b/internal/repo/role/role_repo.go @@ -25,7 +25,7 @@ func NewRoleRepo(data *data.Data) service.RoleRepo { // GetRoleAllList get role list all func (rr *roleRepo) GetRoleAllList(ctx context.Context) (roleList []*entity.Role, err error) { roleList = make([]*entity.Role, 0) - err = rr.data.DB.Find(&roleList) + err = rr.data.DB.Context(ctx).Find(&roleList) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/role/user_role_rel_repo.go b/internal/repo/role/user_role_rel_repo.go index 9398a5d9..b840e86e 100644 --- a/internal/repo/role/user_role_rel_repo.go +++ b/internal/repo/role/user_role_rel_repo.go @@ -27,6 +27,7 @@ func NewUserRoleRelRepo(data *data.Data) role.UserRoleRelRepo { // SaveUserRoleRel save user role rel func (ur *userRoleRelRepo) SaveUserRoleRel(ctx context.Context, userID string, roleID int) (err error) { _, err = ur.data.DB.Transaction(func(session *xorm.Session) (interface{}, error) { + session = session.Context(ctx) item := &entity.UserRoleRel{UserID: userID} exist, err := session.Get(item) if err != nil { @@ -53,7 +54,7 @@ func (ur *userRoleRelRepo) SaveUserRoleRel(ctx context.Context, userID string, r func (ur *userRoleRelRepo) GetUserRoleRelList(ctx context.Context, userIDs []string) ( userRoleRelList []*entity.UserRoleRel, err error) { userRoleRelList = make([]*entity.UserRoleRel, 0) - err = ur.data.DB.In("user_id", userIDs).Find(&userRoleRelList) + err = ur.data.DB.Context(ctx).In("user_id", userIDs).Find(&userRoleRelList) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -64,7 +65,7 @@ func (ur *userRoleRelRepo) GetUserRoleRelList(ctx context.Context, userIDs []str func (ur *userRoleRelRepo) GetUserRoleRelListByRoleID(ctx context.Context, roleIDs []int) ( userRoleRelList []*entity.UserRoleRel, err error) { userRoleRelList = make([]*entity.UserRoleRel, 0) - err = ur.data.DB.In("role_id", roleIDs).Find(&userRoleRelList) + err = ur.data.DB.Context(ctx).In("role_id", roleIDs).Find(&userRoleRelList) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -75,7 +76,7 @@ func (ur *userRoleRelRepo) GetUserRoleRelListByRoleID(ctx context.Context, roleI func (ur *userRoleRelRepo) GetUserRoleRel(ctx context.Context, userID string) ( rolePowerRel *entity.UserRoleRel, exist bool, err error) { rolePowerRel = &entity.UserRoleRel{} - exist, err = ur.data.DB.Where(builder.Eq{"user_id": userID}).Get(rolePowerRel) + exist, err = ur.data.DB.Context(ctx).Where(builder.Eq{"user_id": userID}).Get(rolePowerRel) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/search_common/search_repo.go b/internal/repo/search_common/search_repo.go index da32503c..1f46bdf4 100644 --- a/internal/repo/search_common/search_repo.go +++ b/internal/repo/search_common/search_repo.go @@ -190,12 +190,12 @@ func (sr *searchRepo) SearchContents(ctx context.Context, words []string, tagIDs countArgs = append(countArgs, argsQ...) countArgs = append(countArgs, argsA...) - res, err := sr.data.DB.Query(queryArgs...) + res, err := sr.data.DB.Context(ctx).Query(queryArgs...) if err != nil { return } - tr, err := sr.data.DB.Query(countArgs...) + tr, err := sr.data.DB.Context(ctx).Query(countArgs...) if len(tr) != 0 { total = converter.StringToInt64(string(tr[0]["total"])) } @@ -297,12 +297,12 @@ func (sr *searchRepo) SearchQuestions(ctx context.Context, words []string, tagID countArgs = append(countArgs, countSQL) countArgs = append(countArgs, args...) - res, err := sr.data.DB.Query(queryArgs...) + res, err := sr.data.DB.Context(ctx).Query(queryArgs...) if err != nil { return } - tr, err := sr.data.DB.Query(countArgs...) + tr, err := sr.data.DB.Context(ctx).Query(countArgs...) if err != nil { return } @@ -392,12 +392,12 @@ func (sr *searchRepo) SearchAnswers(ctx context.Context, words []string, tagIDs countArgs = append(countArgs, countSQL) countArgs = append(countArgs, args...) - res, err := sr.data.DB.Query(queryArgs...) + res, err := sr.data.DB.Context(ctx).Query(queryArgs...) if err != nil { return } - tr, err := sr.data.DB.Query(countArgs...) + tr, err := sr.data.DB.Context(ctx).Query(countArgs...) if err != nil { return } @@ -451,7 +451,7 @@ func (sr *searchRepo) parseResult(ctx context.Context, res []map[string][]byte) } // get tags - err = sr.data.DB. + err = sr.data.DB.Context(ctx). Select("`display_name`,`slug_name`,`main_tag_slug_name`,`recommend`,`reserved`"). Table("tag"). Join("INNER", "tag_rel", "tag.id = tag_rel.tag_id"). diff --git a/internal/repo/site_info/siteinfo_repo.go b/internal/repo/site_info/siteinfo_repo.go index bee8ae50..03c89416 100644 --- a/internal/repo/site_info/siteinfo_repo.go +++ b/internal/repo/site_info/siteinfo_repo.go @@ -27,14 +27,14 @@ func NewSiteInfo(data *data.Data) siteinfo_common.SiteInfoRepo { // SaveByType save site setting by type func (sr *siteInfoRepo) SaveByType(ctx context.Context, siteType string, data *entity.SiteInfo) (err error) { old := &entity.SiteInfo{} - exist, err := sr.data.DB.Where(builder.Eq{"type": siteType}).Get(old) + exist, err := sr.data.DB.Context(ctx).Where(builder.Eq{"type": siteType}).Get(old) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } if exist { - _, err = sr.data.DB.ID(old.ID).Update(data) + _, err = sr.data.DB.Context(ctx).ID(old.ID).Update(data) } else { - _, err = sr.data.DB.Insert(data) + _, err = sr.data.DB.Context(ctx).Insert(data) } if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -50,7 +50,7 @@ func (sr *siteInfoRepo) GetByType(ctx context.Context, siteType string) (siteInf return siteInfo, true, nil } siteInfo = &entity.SiteInfo{} - exist, err = sr.data.DB.Where(builder.Eq{"type": siteType}).Get(siteInfo) + exist, err = sr.data.DB.Context(ctx).Where(builder.Eq{"type": siteType}).Get(siteInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/tag/tag_rel_repo.go b/internal/repo/tag/tag_rel_repo.go index 22d26621..a015dee8 100644 --- a/internal/repo/tag/tag_rel_repo.go +++ b/internal/repo/tag/tag_rel_repo.go @@ -32,7 +32,7 @@ func (tr *tagRelRepo) AddTagRelList(ctx context.Context, tagList []*entity.TagRe for _, item := range tagList { item.ObjectID = uid.DeShortID(item.ObjectID) } - _, err = tr.data.DB.Insert(tagList) + _, err = tr.data.DB.Context(ctx).Insert(tagList) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -45,7 +45,7 @@ func (tr *tagRelRepo) AddTagRelList(ctx context.Context, tagList []*entity.TagRe // RemoveTagRelListByObjectID delete tag list func (tr *tagRelRepo) RemoveTagRelListByObjectID(ctx context.Context, objectID string) (err error) { objectID = uid.DeShortID(objectID) - _, err = tr.data.DB.Where("object_id = ?", objectID).Update(&entity.TagRel{Status: entity.TagRelStatusDeleted}) + _, err = tr.data.DB.Context(ctx).Where("object_id = ?", objectID).Update(&entity.TagRel{Status: entity.TagRelStatusDeleted}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -54,7 +54,7 @@ func (tr *tagRelRepo) RemoveTagRelListByObjectID(ctx context.Context, objectID s // RemoveTagRelListByIDs delete tag list func (tr *tagRelRepo) RemoveTagRelListByIDs(ctx context.Context, ids []int64) (err error) { - _, err = tr.data.DB.In("id", ids).Update(&entity.TagRel{Status: entity.TagRelStatusDeleted}) + _, err = tr.data.DB.Context(ctx).In("id", ids).Update(&entity.TagRel{Status: entity.TagRelStatusDeleted}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -67,7 +67,7 @@ func (tr *tagRelRepo) GetObjectTagRelWithoutStatus(ctx context.Context, objectID ) { objectID = uid.DeShortID(objectID) tagRel = &entity.TagRel{} - session := tr.data.DB.Where("object_id = ?", objectID).And("tag_id = ?", tagID) + session := tr.data.DB.Context(ctx).Where("object_id = ?", objectID).And("tag_id = ?", tagID) exist, err = session.Get(tagRel) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -78,7 +78,7 @@ func (tr *tagRelRepo) GetObjectTagRelWithoutStatus(ctx context.Context, objectID // EnableTagRelByIDs update tag status to available func (tr *tagRelRepo) EnableTagRelByIDs(ctx context.Context, ids []int64) (err error) { - _, err = tr.data.DB.In("id", ids).Update(&entity.TagRel{Status: entity.TagRelStatusAvailable}) + _, err = tr.data.DB.Context(ctx).In("id", ids).Update(&entity.TagRel{Status: entity.TagRelStatusAvailable}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -89,7 +89,7 @@ func (tr *tagRelRepo) EnableTagRelByIDs(ctx context.Context, ids []int64) (err e func (tr *tagRelRepo) GetObjectTagRelList(ctx context.Context, objectID string) (tagListList []*entity.TagRel, err error) { objectID = uid.DeShortID(objectID) tagListList = make([]*entity.TagRel, 0) - session := tr.data.DB.Where("object_id = ?", objectID) + session := tr.data.DB.Context(ctx).Where("object_id = ?", objectID) session.Where("status = ?", entity.TagRelStatusAvailable) err = session.Find(&tagListList) if err != nil { @@ -107,7 +107,7 @@ func (tr *tagRelRepo) BatchGetObjectTagRelList(ctx context.Context, objectIds [] objectIds[num] = uid.DeShortID(item) } tagListList = make([]*entity.TagRel, 0) - session := tr.data.DB.In("object_id", objectIds) + session := tr.data.DB.Context(ctx).In("object_id", objectIds) session.Where("status = ?", entity.TagRelStatusAvailable) err = session.Find(&tagListList) if err != nil { @@ -121,7 +121,7 @@ func (tr *tagRelRepo) BatchGetObjectTagRelList(ctx context.Context, objectIds [] // CountTagRelByTagID count tag relation func (tr *tagRelRepo) CountTagRelByTagID(ctx context.Context, tagID string) (count int64, err error) { - count, err = tr.data.DB.Count(&entity.TagRel{TagID: tagID, Status: entity.AnswerStatusAvailable}) + count, err = tr.data.DB.Context(ctx).Count(&entity.TagRel{TagID: tagID, Status: entity.AnswerStatusAvailable}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/tag/tag_repo.go b/internal/repo/tag/tag_repo.go index 02604cb7..d5ea992b 100644 --- a/internal/repo/tag/tag_repo.go +++ b/internal/repo/tag/tag_repo.go @@ -32,7 +32,7 @@ func NewTagRepo( // RemoveTag delete tag func (tr *tagRepo) RemoveTag(ctx context.Context, tagID string) (err error) { - session := tr.data.DB.Where(builder.Eq{"id": tagID}) + session := tr.data.DB.Context(ctx).Where(builder.Eq{"id": tagID}) _, err = session.Update(&entity.Tag{Status: entity.TagStatusDeleted}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -42,7 +42,7 @@ func (tr *tagRepo) RemoveTag(ctx context.Context, tagID string) (err error) { // UpdateTag update tag func (tr *tagRepo) UpdateTag(ctx context.Context, tag *entity.Tag) (err error) { - _, err = tr.data.DB.Where(builder.Eq{"id": tag.ID}).Update(tag) + _, err = tr.data.DB.Context(ctx).Where(builder.Eq{"id": tag.ID}).Update(tag) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -54,7 +54,7 @@ func (tr *tagRepo) UpdateTagSynonym(ctx context.Context, tagSlugNameList []strin mainTagSlugName string, ) (err error) { bean := &entity.Tag{MainTagID: mainTagID, MainTagSlugName: mainTagSlugName} - session := tr.data.DB.In("slug_name", tagSlugNameList).MustCols("main_tag_id", "main_tag_slug_name") + session := tr.data.DB.Context(ctx).In("slug_name", tagSlugNameList).MustCols("main_tag_id", "main_tag_slug_name") _, err = session.Update(bean) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -63,7 +63,7 @@ func (tr *tagRepo) UpdateTagSynonym(ctx context.Context, tagSlugNameList []strin } func (tr *tagRepo) GetTagSynonymCount(ctx context.Context, tagID string) (count int64, err error) { - count, err = tr.data.DB.Count(&entity.Tag{MainTagID: converter.StringToInt64(tagID), Status: entity.TagStatusAvailable}) + count, err = tr.data.DB.Context(ctx).Count(&entity.Tag{MainTagID: converter.StringToInt64(tagID), Status: entity.TagStatusAvailable}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -73,7 +73,7 @@ func (tr *tagRepo) GetTagSynonymCount(ctx context.Context, tagID string) (count // GetTagList get tag list all func (tr *tagRepo) GetTagList(ctx context.Context, tag *entity.Tag) (tagList []*entity.Tag, err error) { tagList = make([]*entity.Tag, 0) - session := tr.data.DB.Where(builder.Eq{"status": entity.TagStatusAvailable}) + session := tr.data.DB.Context(ctx).Where(builder.Eq{"status": entity.TagStatusAvailable}) err = session.Find(&tagList, tag) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() diff --git a/internal/repo/tag_common/tag_common_repo.go b/internal/repo/tag_common/tag_common_repo.go index b12810e0..a7380028 100644 --- a/internal/repo/tag_common/tag_common_repo.go +++ b/internal/repo/tag_common/tag_common_repo.go @@ -34,7 +34,7 @@ func NewTagCommonRepo( // GetTagListByIDs get tag list all func (tr *tagCommonRepo) GetTagListByIDs(ctx context.Context, ids []string) (tagList []*entity.Tag, err error) { tagList = make([]*entity.Tag, 0) - session := tr.data.DB.In("id", ids) + session := tr.data.DB.Context(ctx).In("id", ids) session.Where(builder.Eq{"status": entity.TagStatusAvailable}) err = session.OrderBy("recommend desc,reserved desc,id desc").Find(&tagList) if err != nil { @@ -46,7 +46,7 @@ func (tr *tagCommonRepo) GetTagListByIDs(ctx context.Context, ids []string) (tag // GetTagBySlugName get tag by slug name func (tr *tagCommonRepo) GetTagBySlugName(ctx context.Context, slugName string) (tagInfo *entity.Tag, exist bool, err error) { tagInfo = &entity.Tag{} - session := tr.data.DB.Where("LOWER(slug_name) = ?", slugName) + session := tr.data.DB.Context(ctx).Where("LOWER(slug_name) = ?", slugName) session.Where(builder.Eq{"status": entity.TagStatusAvailable}) exist, err = session.Get(tagInfo) if err != nil { @@ -59,7 +59,7 @@ func (tr *tagCommonRepo) GetTagBySlugName(ctx context.Context, slugName string) func (tr *tagCommonRepo) GetTagListByName(ctx context.Context, name string, hasReserved bool) (tagList []*entity.Tag, err error) { tagList = make([]*entity.Tag, 0) cond := &entity.Tag{} - session := tr.data.DB.Where("") + session := tr.data.DB.Context(ctx).Where("") if name != "" { session.Where("slug_name LIKE LOWER(?) or display_name LIKE ?", name+"%", name+"%") } else { @@ -78,7 +78,7 @@ func (tr *tagCommonRepo) GetTagListByName(ctx context.Context, name string, hasR func (tr *tagCommonRepo) GetRecommendTagList(ctx context.Context) (tagList []*entity.Tag, err error) { tagList = make([]*entity.Tag, 0) cond := &entity.Tag{} - session := tr.data.DB.Where("") + session := tr.data.DB.Context(ctx).Where("") cond.Recommend = true // session.Where(builder.Eq{"status": entity.TagStatusAvailable}) session.Asc("slug_name") @@ -93,7 +93,7 @@ func (tr *tagCommonRepo) GetRecommendTagList(ctx context.Context) (tagList []*en func (tr *tagCommonRepo) GetReservedTagList(ctx context.Context) (tagList []*entity.Tag, err error) { tagList = make([]*entity.Tag, 0) cond := &entity.Tag{} - session := tr.data.DB.Where("") + session := tr.data.DB.Context(ctx).Where("") cond.Reserved = true // session.Where(builder.Eq{"status": entity.TagStatusAvailable}) session.Asc("slug_name") @@ -109,7 +109,7 @@ func (tr *tagCommonRepo) GetReservedTagList(ctx context.Context) (tagList []*ent func (tr *tagCommonRepo) GetTagListByNames(ctx context.Context, names []string) (tagList []*entity.Tag, err error) { tagList = make([]*entity.Tag, 0) - session := tr.data.DB.In("slug_name", names).UseBool("recommend", "reserved") + session := tr.data.DB.Context(ctx).In("slug_name", names).UseBool("recommend", "reserved") // session.Where(builder.Eq{"status": entity.TagStatusAvailable}) err = session.OrderBy("recommend desc,reserved desc,id desc").Find(&tagList) if err != nil { @@ -123,7 +123,7 @@ func (tr *tagCommonRepo) GetTagByID(ctx context.Context, tagID string, includeDe tag *entity.Tag, exist bool, err error, ) { tag = &entity.Tag{} - session := tr.data.DB.Where(builder.Eq{"id": tagID}) + session := tr.data.DB.Context(ctx).Where(builder.Eq{"id": tagID}) if !includeDeleted { session.Where(builder.Eq{"status": entity.TagStatusAvailable}) } @@ -139,7 +139,7 @@ func (tr *tagCommonRepo) GetTagPage(ctx context.Context, page, pageSize int, tag tagList []*entity.Tag, total int64, err error, ) { tagList = make([]*entity.Tag, 0) - session := tr.data.DB.NewSession() + session := tr.data.DB.Context(ctx) if len(tag.SlugName) > 0 { session.Where(builder.Or(builder.Like{"slug_name", fmt.Sprintf("LOWER(%s)", tag.SlugName)}, builder.Like{"display_name", tag.SlugName})) @@ -173,7 +173,7 @@ func (tr *tagCommonRepo) AddTagList(ctx context.Context, tagList []*entity.Tag) } item.RevisionID = "0" } - _, err = tr.data.DB.Insert(tagList) + _, err = tr.data.DB.Context(ctx).Insert(tagList) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -183,7 +183,7 @@ func (tr *tagCommonRepo) AddTagList(ctx context.Context, tagList []*entity.Tag) // UpdateTagQuestionCount update tag question count func (tr *tagCommonRepo) UpdateTagQuestionCount(ctx context.Context, tagID string, questionCount int) (err error) { cond := &entity.Tag{QuestionCount: questionCount} - _, err = tr.data.DB.Where(builder.Eq{"id": tagID}).MustCols("question_count").Update(cond) + _, err = tr.data.DB.Context(ctx).Where(builder.Eq{"id": tagID}).MustCols("question_count").Update(cond) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -200,7 +200,7 @@ func (tr *tagCommonRepo) UpdateTagsAttribute(ctx context.Context, tags []string, default: return } - session := tr.data.DB.In("slug_name", tags).Cols(attribute).UseBool(attribute) + session := tr.data.DB.Context(ctx).In("slug_name", tags).Cols(attribute).UseBool(attribute) _, err = session.Update(bean) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() diff --git a/internal/repo/unique/uniqid_repo.go b/internal/repo/unique/uniqid_repo.go index c460d904..eba2b527 100644 --- a/internal/repo/unique/uniqid_repo.go +++ b/internal/repo/unique/uniqid_repo.go @@ -29,7 +29,7 @@ func NewUniqueIDRepo(data *data.Data) unique.UniqueIDRepo { func (ur *uniqueIDRepo) GenUniqueIDStr(ctx context.Context, key string) (uniqueID string, err error) { objectType := constant.ObjectTypeStrMapping[key] bean := &entity.Uniqid{UniqidType: objectType} - _, err = ur.data.DB.Insert(bean) + _, err = ur.data.DB.Context(ctx).Insert(bean) if err != nil { return "", errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/repo/user/user_backyard_repo.go b/internal/repo/user/user_backyard_repo.go index cebb2112..1af0a41a 100644 --- a/internal/repo/user/user_backyard_repo.go +++ b/internal/repo/user/user_backyard_repo.go @@ -42,7 +42,7 @@ func (ur *userAdminRepo) UpdateUserStatus(ctx context.Context, userID string, us case entity.UserStatusDeleted: cond.DeletedAt = time.Now() } - _, err = ur.data.DB.ID(userID).Update(cond) + _, err = ur.data.DB.Context(ctx).ID(userID).Update(cond) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -63,7 +63,7 @@ func (ur *userAdminRepo) UpdateUserStatus(ctx context.Context, userID string, us // AddUser add user func (ur *userAdminRepo) AddUser(ctx context.Context, user *entity.User) (err error) { - _, err = ur.data.DB.Insert(user) + _, err = ur.data.DB.Context(ctx).Insert(user) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -72,7 +72,7 @@ func (ur *userAdminRepo) AddUser(ctx context.Context, user *entity.User) (err er // UpdateUserPassword update user password func (ur *userAdminRepo) UpdateUserPassword(ctx context.Context, userID string, password string) (err error) { - _, err = ur.data.DB.ID(userID).Update(&entity.User{Pass: password}) + _, err = ur.data.DB.Context(ctx).ID(userID).Update(&entity.User{Pass: password}) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -82,7 +82,7 @@ func (ur *userAdminRepo) UpdateUserPassword(ctx context.Context, userID string, // GetUserInfo get user info func (ur *userAdminRepo) GetUserInfo(ctx context.Context, userID string) (user *entity.User, exist bool, err error) { user = &entity.User{} - exist, err = ur.data.DB.ID(userID).Get(user) + exist, err = ur.data.DB.Context(ctx).ID(userID).Get(user) if err != nil { return nil, false, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -99,7 +99,7 @@ func (ur *userAdminRepo) GetUserInfo(ctx context.Context, userID string) (user * // GetUserInfoByEmail get user info func (ur *userAdminRepo) GetUserInfoByEmail(ctx context.Context, email string) (user *entity.User, exist bool, err error) { userInfo := &entity.User{} - exist, err = ur.data.DB.Where("e_mail = ?", email). + exist, err = ur.data.DB.Context(ctx).Where("e_mail = ?", email). Where("status != ?", entity.UserStatusDeleted).Get(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -119,7 +119,7 @@ func (ur *userAdminRepo) GetUserInfoByEmail(ctx context.Context, email string) ( func (ur *userAdminRepo) GetUserPage(ctx context.Context, page, pageSize int, user *entity.User, usernameOrDisplayName string, isStaff bool) (users []*entity.User, total int64, err error) { users = make([]*entity.User, 0) - session := ur.data.DB.NewSession() + session := ur.data.DB.Context(ctx) switch user.Status { case entity.UserStatusDeleted: session.Desc("user.deleted_at") diff --git a/internal/repo/user/user_repo.go b/internal/repo/user/user_repo.go index b8129f20..f0412250 100644 --- a/internal/repo/user/user_repo.go +++ b/internal/repo/user/user_repo.go @@ -34,6 +34,7 @@ func NewUserRepo(data *data.Data, configRepo config.ConfigRepo) usercommon.UserR // AddUser add user func (ur *userRepo) AddUser(ctx context.Context, user *entity.User) (err error) { _, err = ur.data.DB.Transaction(func(session *xorm.Session) (interface{}, error) { + session = session.Context(ctx) userInfo := &entity.User{} exist, err := session.Where("username = ?", user.Username).Get(userInfo) if err != nil { @@ -54,7 +55,7 @@ func (ur *userRepo) AddUser(ctx context.Context, user *entity.User) (err error) // IncreaseAnswerCount increase answer count func (ur *userRepo) IncreaseAnswerCount(ctx context.Context, userID string, amount int) (err error) { user := &entity.User{} - _, err = ur.data.DB.Where("id = ?", userID).Incr("answer_count", amount).Update(user) + _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Incr("answer_count", amount).Update(user) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -64,7 +65,7 @@ func (ur *userRepo) IncreaseAnswerCount(ctx context.Context, userID string, amou // IncreaseQuestionCount increase question count func (ur *userRepo) IncreaseQuestionCount(ctx context.Context, userID string, amount int) (err error) { user := &entity.User{} - _, err = ur.data.DB.Where("id = ?", userID).Incr("question_count", amount).Update(user) + _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Incr("question_count", amount).Update(user) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -74,7 +75,7 @@ func (ur *userRepo) IncreaseQuestionCount(ctx context.Context, userID string, am // UpdateLastLoginDate update last login date func (ur *userRepo) UpdateLastLoginDate(ctx context.Context, userID string) (err error) { user := &entity.User{LastLoginDate: time.Now()} - _, err = ur.data.DB.Where("id = ?", userID).Cols("last_login_date").Update(user) + _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("last_login_date").Update(user) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -84,7 +85,7 @@ func (ur *userRepo) UpdateLastLoginDate(ctx context.Context, userID string) (err // UpdateEmailStatus update email status func (ur *userRepo) UpdateEmailStatus(ctx context.Context, userID string, emailStatus int) error { cond := &entity.User{MailStatus: emailStatus} - _, err := ur.data.DB.Where("id = ?", userID).Cols("mail_status").Update(cond) + _, err := ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("mail_status").Update(cond) if err != nil { return err } @@ -94,7 +95,7 @@ func (ur *userRepo) UpdateEmailStatus(ctx context.Context, userID string, emailS // UpdateNoticeStatus update notice status func (ur *userRepo) UpdateNoticeStatus(ctx context.Context, userID string, noticeStatus int) error { cond := &entity.User{NoticeStatus: noticeStatus} - _, err := ur.data.DB.Where("id = ?", userID).Cols("notice_status").Update(cond) + _, err := ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("notice_status").Update(cond) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -102,7 +103,7 @@ func (ur *userRepo) UpdateNoticeStatus(ctx context.Context, userID string, notic } func (ur *userRepo) UpdatePass(ctx context.Context, userID, pass string) error { - _, err := ur.data.DB.Where("id = ?", userID).Cols("pass").Update(&entity.User{Pass: pass}) + _, err := ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("pass").Update(&entity.User{Pass: pass}) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -110,7 +111,7 @@ func (ur *userRepo) UpdatePass(ctx context.Context, userID, pass string) error { } func (ur *userRepo) UpdateEmail(ctx context.Context, userID, email string) (err error) { - _, err = ur.data.DB.Where("id = ?", userID).Update(&entity.User{EMail: email}) + _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Update(&entity.User{EMail: email}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -118,7 +119,7 @@ func (ur *userRepo) UpdateEmail(ctx context.Context, userID, email string) (err } func (ur *userRepo) UpdateLanguage(ctx context.Context, userID, language string) (err error) { - _, err = ur.data.DB.Where("id = ?", userID).Update(&entity.User{Language: language}) + _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Update(&entity.User{Language: language}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -127,7 +128,7 @@ func (ur *userRepo) UpdateLanguage(ctx context.Context, userID, language string) // UpdateInfo update user info func (ur *userRepo) UpdateInfo(ctx context.Context, userInfo *entity.User) (err error) { - _, err = ur.data.DB.Where("id = ?", userInfo.ID). + _, err = ur.data.DB.Context(ctx).Where("id = ?", userInfo.ID). Cols("username", "display_name", "avatar", "bio", "bio_html", "website", "location").Update(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -138,7 +139,7 @@ func (ur *userRepo) UpdateInfo(ctx context.Context, userInfo *entity.User) (err // GetByUserID get user info by user id func (ur *userRepo) GetByUserID(ctx context.Context, userID string) (userInfo *entity.User, exist bool, err error) { userInfo = &entity.User{} - exist, err = ur.data.DB.Where("id = ?", userID).Get(userInfo) + exist, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Get(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return @@ -152,7 +153,7 @@ func (ur *userRepo) GetByUserID(ctx context.Context, userID string) (userInfo *e func (ur *userRepo) BatchGetByID(ctx context.Context, ids []string) ([]*entity.User, error) { list := make([]*entity.User, 0) - err := ur.data.DB.In("id", ids).Find(&list) + err := ur.data.DB.Context(ctx).In("id", ids).Find(&list) if err != nil { return nil, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -163,7 +164,7 @@ func (ur *userRepo) BatchGetByID(ctx context.Context, ids []string) ([]*entity.U // GetByUsername get user by username func (ur *userRepo) GetByUsername(ctx context.Context, username string) (userInfo *entity.User, exist bool, err error) { userInfo = &entity.User{} - exist, err = ur.data.DB.Where("username = ?", username).Get(userInfo) + exist, err = ur.data.DB.Context(ctx).Where("username = ?", username).Get(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return @@ -178,7 +179,7 @@ func (ur *userRepo) GetByUsername(ctx context.Context, username string) (userInf // GetByEmail get user by email func (ur *userRepo) GetByEmail(ctx context.Context, email string) (userInfo *entity.User, exist bool, err error) { userInfo = &entity.User{} - exist, err = ur.data.DB.Where("e_mail = ?", email). + exist, err = ur.data.DB.Context(ctx).Where("e_mail = ?", email). Where("status != ?", entity.UserStatusDeleted).Get(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() @@ -188,7 +189,7 @@ func (ur *userRepo) GetByEmail(ctx context.Context, email string) (userInfo *ent func (ur *userRepo) GetUserCount(ctx context.Context) (count int64, err error) { list := make([]*entity.User, 0) - count, err = ur.data.DB.Where("mail_status =?", entity.EmailStatusAvailable).And("status =?", entity.UserStatusAvailable).FindAndCount(&list) + count, err = ur.data.DB.Context(ctx).Where("mail_status =?", entity.EmailStatusAvailable).And("status =?", entity.UserStatusAvailable).FindAndCount(&list) if err != nil { return count, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -205,7 +206,7 @@ func tryToDecorateUserInfoFromUserCenter(ctx context.Context, data *data.Data, o } userInfo := &entity.UserExternalLogin{} - session := data.DB.Where("user_id = ?", original.ID) + session := data.DB.Context(ctx).Where("user_id = ?", original.ID) session.Where("provider = ?", uc.Info().SlugName) exist, err := session.Get(userInfo) if err != nil { @@ -243,7 +244,7 @@ func tryToDecorateUserListFromUserCenter(ctx context.Context, data *data.Data, o } userExternalLoginList := make([]*entity.UserExternalLogin, 0) - session := data.DB.Where("provider = ?", uc.Info().SlugName) + session := data.DB.Context(ctx).Where("provider = ?", uc.Info().SlugName) session.In("user_id", ids) err := session.Find(&userExternalLoginList) if err != nil { diff --git a/internal/repo/user_external_login/user_external_login_repo.go b/internal/repo/user_external_login/user_external_login_repo.go index f7be790e..9c7c5e22 100644 --- a/internal/repo/user_external_login/user_external_login_repo.go +++ b/internal/repo/user_external_login/user_external_login_repo.go @@ -26,7 +26,7 @@ func NewUserExternalLoginRepo(data *data.Data) user_external_login.UserExternalL // AddUserExternalLogin add external login information func (ur *userExternalLoginRepo) AddUserExternalLogin(ctx context.Context, user *entity.UserExternalLogin) (err error) { - _, err = ur.data.DB.Insert(user) + _, err = ur.data.DB.Context(ctx).Insert(user) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -35,7 +35,7 @@ func (ur *userExternalLoginRepo) AddUserExternalLogin(ctx context.Context, user // UpdateInfo update user info func (ur *userExternalLoginRepo) UpdateInfo(ctx context.Context, userInfo *entity.UserExternalLogin) (err error) { - _, err = ur.data.DB.ID(userInfo.ID).Update(userInfo) + _, err = ur.data.DB.Context(ctx).ID(userInfo.ID).Update(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -46,7 +46,7 @@ func (ur *userExternalLoginRepo) UpdateInfo(ctx context.Context, userInfo *entit func (ur *userExternalLoginRepo) GetByExternalID(ctx context.Context, provider, externalID string) ( userInfo *entity.UserExternalLogin, exist bool, err error) { userInfo = &entity.UserExternalLogin{} - exist, err = ur.data.DB.Where("external_id = ?", externalID).Where("provider = ?", provider).Get(userInfo) + exist, err = ur.data.DB.Context(ctx).Where("external_id = ?", externalID).Where("provider = ?", provider).Get(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -57,7 +57,7 @@ func (ur *userExternalLoginRepo) GetByExternalID(ctx context.Context, provider, func (ur *userExternalLoginRepo) GetUserExternalLoginList(ctx context.Context, userID string) ( resp []*entity.UserExternalLogin, err error) { resp = make([]*entity.UserExternalLogin, 0) - err = ur.data.DB.Where("user_id = ?", userID).Find(&resp) + err = ur.data.DB.Context(ctx).Where("user_id = ?", userID).Find(&resp) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } @@ -67,7 +67,7 @@ func (ur *userExternalLoginRepo) GetUserExternalLoginList(ctx context.Context, u // DeleteUserExternalLogin delete external user login info func (ur *userExternalLoginRepo) DeleteUserExternalLogin(ctx context.Context, userID, externalID string) (err error) { cond := &entity.UserExternalLogin{} - _, err = ur.data.DB.Where("user_id = ? AND external_id = ?", userID, externalID).Delete(cond) + _, err = ur.data.DB.Context(ctx).Where("user_id = ? AND external_id = ?", userID, externalID).Delete(cond) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } diff --git a/internal/service/activity_common/follow.go b/internal/service/activity_common/follow.go index baeff38c..5d448868 100644 --- a/internal/service/activity_common/follow.go +++ b/internal/service/activity_common/follow.go @@ -6,5 +6,5 @@ type FollowRepo interface { GetFollowIDs(ctx context.Context, userID, objectType string) (followIDs []string, err error) GetFollowAmount(ctx context.Context, objectID string) (followAmount int, err error) GetFollowUserIDs(ctx context.Context, objectID string) (userIDs []string, err error) - IsFollowed(userId, objectId string) (bool, error) + IsFollowed(ctx context.Context, userId, objectId string) (bool, error) } diff --git a/internal/service/config/config_service.go b/internal/service/config/config_service.go index 882b8fa5..b5af20ae 100644 --- a/internal/service/config/config_service.go +++ b/internal/service/config/config_service.go @@ -1,5 +1,7 @@ package config +import "context" + // ConfigRepo config repository type ConfigRepo interface { Get(key string) (interface{}, error) @@ -8,7 +10,7 @@ type ConfigRepo interface { GetArrayString(key string) ([]string, error) GetConfigType(key string) (int, error) GetJsonConfigByIDAndSetToObject(id int, value any) (err error) - SetConfig(key, value string) (err error) + SetConfig(ctx context.Context, key, value string) (err error) } // ConfigService user service diff --git a/internal/service/export/email_service.go b/internal/service/export/email_service.go index 4fc9a2fa..4363bdd9 100644 --- a/internal/service/export/email_service.go +++ b/internal/service/export/email_service.go @@ -386,7 +386,7 @@ func (es *EmailService) GetEmailConfig() (ec *EmailConfig, err error) { } // SetEmailConfig set email config -func (es *EmailService) SetEmailConfig(ec *EmailConfig) (err error) { +func (es *EmailService) SetEmailConfig(ctx context.Context, ec *EmailConfig) (err error) { data, _ := json.Marshal(ec) - return es.configRepo.SetConfig("email.config", string(data)) + return es.configRepo.SetConfig(ctx, "email.config", string(data)) } diff --git a/internal/service/plugin_common/plugin_common_service.go b/internal/service/plugin_common/plugin_common_service.go index c8f7d799..69ad45bb 100644 --- a/internal/service/plugin_common/plugin_common_service.go +++ b/internal/service/plugin_common/plugin_common_service.go @@ -70,7 +70,7 @@ func (ps *PluginCommonService) UpdatePluginStatus(ctx context.Context) (err erro if err != nil { return errors.InternalServer(reason.UnknownError).WithError(err) } - return ps.configRepo.SetConfig(constant.PluginStatus, string(content)) + return ps.configRepo.SetConfig(ctx, constant.PluginStatus, string(content)) } // UpdatePluginConfig update plugin config diff --git a/internal/service/question_common/question.go b/internal/service/question_common/question.go index ce5b77e5..012e5e83 100644 --- a/internal/service/question_common/question.go +++ b/internal/service/question_common/question.go @@ -227,7 +227,7 @@ func (qs *QuestionCommon) Info(ctx context.Context, questionID string, loginUser showinfo.VoteStatus = qs.voteRepo.GetVoteStatus(ctx, questionID, loginUserID) // // check is followed - isFollowed, _ := qs.followCommon.IsFollowed(loginUserID, questionID) + isFollowed, _ := qs.followCommon.IsFollowed(ctx, loginUserID, questionID) showinfo.IsFollowed = isFollowed has, err = qs.AnswerCommon.SearchAnswered(ctx, loginUserID, dbinfo.ID) diff --git a/internal/service/report_admin/report_backyard.go b/internal/service/report_admin/report_backyard.go index 51065995..c3822ebc 100644 --- a/internal/service/report_admin/report_backyard.go +++ b/internal/service/report_admin/report_backyard.go @@ -4,13 +4,13 @@ import ( "context" "github.com/answerdev/answer/internal/service/config" + "github.com/answerdev/answer/internal/service/object_info" "github.com/answerdev/answer/pkg/htmltext" "github.com/segmentfault/pacman/log" "github.com/answerdev/answer/internal/base/pager" "github.com/answerdev/answer/internal/base/reason" "github.com/answerdev/answer/internal/entity" - "github.com/answerdev/answer/internal/repo/common" "github.com/answerdev/answer/internal/schema" answercommon "github.com/answerdev/answer/internal/service/answer_common" "github.com/answerdev/answer/internal/service/comment_common" @@ -26,33 +26,33 @@ import ( type ReportAdminService struct { reportRepo report_common.ReportRepo commonUser *usercommon.UserCommon - commonRepo *common.CommonRepo answerRepo answercommon.AnswerRepo questionRepo questioncommon.QuestionRepo commentCommonRepo comment_common.CommentCommonRepo reportHandle *report_handle_admin.ReportHandle configRepo config.ConfigRepo + objectInfoService *object_info.ObjService } // NewReportAdminService new report service func NewReportAdminService( reportRepo report_common.ReportRepo, commonUser *usercommon.UserCommon, - commonRepo *common.CommonRepo, answerRepo answercommon.AnswerRepo, questionRepo questioncommon.QuestionRepo, commentCommonRepo comment_common.CommentCommonRepo, reportHandle *report_handle_admin.ReportHandle, - configRepo config.ConfigRepo) *ReportAdminService { + configRepo config.ConfigRepo, + objectInfoService *object_info.ObjService) *ReportAdminService { return &ReportAdminService{ reportRepo: reportRepo, commonUser: commonUser, - commonRepo: commonRepo, answerRepo: answerRepo, questionRepo: questionRepo, commentCommonRepo: commentCommonRepo, reportHandle: reportHandle, configRepo: configRepo, + objectInfoService: objectInfoService, } } @@ -98,9 +98,8 @@ func (rs *ReportAdminService) ListReportPage(ctx context.Context, dto schema.Get for _, r := range resp { r.ReportedUser = flaggedUsers[r.ReportedUserID] r.ReportUser = users[r.UserID] + rs.decorateReportResp(ctx, r) } - - rs.parseObject(ctx, &resp) return pager.NewPageModel(total, resp), nil } @@ -139,99 +138,31 @@ func (rs *ReportAdminService) HandleReported(ctx context.Context, req schema.Rep return } -func (rs *ReportAdminService) parseObject(ctx context.Context, resp *[]*schema.GetReportListPageResp) { - var ( - res = *resp - ) +func (rs *ReportAdminService) decorateReportResp(ctx context.Context, resp *schema.GetReportListPageResp) { + objectInfo, err := rs.objectInfoService.GetInfo(ctx, resp.ObjectID) + if err != nil { + log.Error(err) + return + } - for i, r := range res { - var ( - objIds map[string]string - exists, - ok bool - err error - questionId, - answerId, - commentId string - question *entity.Question - answer *entity.Answer - cmt *entity.Comment - ) + resp.QuestionID = objectInfo.QuestionID + resp.AnswerID = objectInfo.AnswerID + resp.CommentID = objectInfo.CommentID + resp.Title = objectInfo.Title + resp.Excerpt = htmltext.FetchExcerpt(objectInfo.Content, "...", 240) - objIds, err = rs.commonRepo.GetObjectIDMap(r.ObjectID) + if resp.ReportType > 0 { + resp.Reason = &schema.ReasonItem{ReasonType: resp.ReportType} + err = rs.configRepo.GetJsonConfigByIDAndSetToObject(resp.ReportType, resp.Reason) if err != nil { log.Error(err) - continue } - - questionId, ok = objIds["question"] - if !ok { - continue + } + if resp.FlaggedType > 0 { + resp.FlaggedReason = &schema.ReasonItem{ReasonType: resp.FlaggedType} + err = rs.configRepo.GetJsonConfigByIDAndSetToObject(resp.FlaggedType, resp.FlaggedReason) + if err != nil { + log.Error(err) } - - question, exists, err = rs.questionRepo.GetQuestion(ctx, questionId) - if err != nil || !exists { - continue - } - - answerId, ok = objIds["answer"] - if ok { - answer, _, err = rs.answerRepo.GetAnswer(ctx, answerId) - if err != nil { - log.Error(err) - continue - } - } - - commentId, ok = objIds["comment"] - if ok { - cmt, _, err = rs.commentCommonRepo.GetComment(ctx, commentId) - if err != nil { - log.Error(err) - continue - } - } - - switch r.OType { - case "question": - r.QuestionID = questionId - r.Title = question.Title - r.Excerpt = htmltext.FetchExcerpt(question.ParsedText, "...", 240) - - case "answer": - r.QuestionID = questionId - r.AnswerID = answerId - r.Title = question.Title - r.Excerpt = htmltext.FetchExcerpt(answer.ParsedText, "...", 240) - - case "comment": - r.QuestionID = questionId - r.AnswerID = answerId - r.CommentID = commentId - r.Title = question.Title - r.Excerpt = htmltext.FetchExcerpt(cmt.ParsedText, "...", 240) - } - - // parse reason - if r.ReportType > 0 { - r.Reason = &schema.ReasonItem{ - ReasonType: r.ReportType, - } - err = rs.configRepo.GetJsonConfigByIDAndSetToObject(r.ReportType, r.Reason) - if err != nil { - log.Error(err) - } - } - if r.FlaggedType > 0 { - r.FlaggedReason = &schema.ReasonItem{ - ReasonType: r.FlaggedType, - } - err = rs.configRepo.GetJsonConfigByIDAndSetToObject(r.FlaggedType, r.FlaggedReason) - if err != nil { - log.Error(err) - } - } - - res[i] = r } } diff --git a/internal/service/siteinfo/siteinfo_service.go b/internal/service/siteinfo/siteinfo_service.go index 986beb78..7465b608 100644 --- a/internal/service/siteinfo/siteinfo_service.go +++ b/internal/service/siteinfo/siteinfo_service.go @@ -254,7 +254,7 @@ func (s *SiteInfoService) UpdateSMTPConfig(ctx context.Context, req *schema.Upda } _ = copier.Copy(oldEmailConfig, req) - err = s.emailService.SetEmailConfig(oldEmailConfig) + err = s.emailService.SetEmailConfig(ctx, oldEmailConfig) if err != nil { return err } @@ -370,7 +370,7 @@ func (s *SiteInfoService) UpdatePrivilegesConfig(ctx context.Context, req *schem // update privilege in config for _, privilege := range chooseOption.Privileges { - err = s.configRepo.SetConfig(privilege.Key, fmt.Sprintf("%d", privilege.Value)) + err = s.configRepo.SetConfig(ctx, privilege.Key, fmt.Sprintf("%d", privilege.Value)) if err != nil { return err } diff --git a/internal/service/tag/tag_service.go b/internal/service/tag/tag_service.go index e27cbc7d..23dc25de 100644 --- a/internal/service/tag/tag_service.go +++ b/internal/service/tag/tag_service.go @@ -380,7 +380,7 @@ func (ts *TagService) checkTagIsFollow(ctx context.Context, userID, tagID string if len(userID) == 0 { return false } - followed, err := ts.followCommon.IsFollowed(userID, tagID) + followed, err := ts.followCommon.IsFollowed(ctx, userID, tagID) if err != nil { log.Error(err) }