Explorar o código

[manager] protect DB with RW lock

z4yx %!s(int64=5) %!d(string=hai) anos
pai
achega
1b099520b2
Modificáronse 2 ficheiros con 61 adicións e 2 borrados
  1. 32 2
      manager/server.go
  2. 29 0
      manager/server_test.go

+ 32 - 2
manager/server.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"sync"
 	"time"
 
 	"github.com/gin-gonic/gin"
@@ -23,6 +24,7 @@ type Manager struct {
 	cfg        *Config
 	engine     *gin.Engine
 	adapter    dbAdapter
+	rwmu       sync.RWMutex
 	httpClient *http.Client
 }
 
@@ -127,9 +129,11 @@ func (s *Manager) Run() {
 	}
 }
 
-// listAllJobs repond with all jobs of specified workers
+// listAllJobs respond with all jobs of specified workers
 func (s *Manager) listAllJobs(c *gin.Context) {
+	s.rwmu.RLock()
 	mirrorStatusList, err := s.adapter.ListAllMirrorStatus()
+	s.rwmu.RUnlock()
 	if err != nil {
 		err := fmt.Errorf("failed to list all mirror status: %s",
 			err.Error(),
@@ -150,7 +154,9 @@ func (s *Manager) listAllJobs(c *gin.Context) {
 
 // flushDisabledJobs deletes all jobs that marks as deleted
 func (s *Manager) flushDisabledJobs(c *gin.Context) {
+	s.rwmu.Lock()
 	err := s.adapter.FlushDisabledJobs()
+	s.rwmu.Unlock()
 	if err != nil {
 		err := fmt.Errorf("failed to flush disabled jobs: %s",
 			err.Error(),
@@ -165,7 +171,9 @@ func (s *Manager) flushDisabledJobs(c *gin.Context) {
 // deleteWorker deletes one worker by id
 func (s *Manager) deleteWorker(c *gin.Context) {
 	workerID := c.Param("id")
+	s.rwmu.Lock()
 	err := s.adapter.DeleteWorker(workerID)
+	s.rwmu.Unlock()
 	if err != nil {
 		err := fmt.Errorf("failed to delete worker: %s",
 			err.Error(),
@@ -178,10 +186,12 @@ func (s *Manager) deleteWorker(c *gin.Context) {
 	c.JSON(http.StatusOK, gin.H{_infoKey: "deleted"})
 }
 
-// listWrokers respond with informations of all the workers
+// listWorkers respond with information of all the workers
 func (s *Manager) listWorkers(c *gin.Context) {
 	var workerInfos []WorkerStatus
+	s.rwmu.RLock()
 	workers, err := s.adapter.ListWorkers()
+	s.rwmu.RUnlock()
 	if err != nil {
 		err := fmt.Errorf("failed to list workers: %s",
 			err.Error(),
@@ -223,7 +233,9 @@ func (s *Manager) registerWorker(c *gin.Context) {
 // listJobsOfWorker respond with all the jobs of the specified worker
 func (s *Manager) listJobsOfWorker(c *gin.Context) {
 	workerID := c.Param("id")
+	s.rwmu.RLock()
 	mirrorStatusList, err := s.adapter.ListMirrorStatus(workerID)
+	s.rwmu.RUnlock()
 	if err != nil {
 		err := fmt.Errorf("failed to list jobs of worker %s: %s",
 			workerID, err.Error(),
@@ -255,7 +267,9 @@ func (s *Manager) updateSchedulesOfWorker(c *gin.Context) {
 			)
 		}
 
+		s.rwmu.RLock()
 		curStatus, err := s.adapter.GetMirrorStatus(workerID, mirrorName)
+		s.rwmu.RUnlock()
 		if err != nil {
 			fmt.Errorf("failed to get job %s of worker %s: %s",
 				mirrorName, workerID, err.Error(),
@@ -269,7 +283,9 @@ func (s *Manager) updateSchedulesOfWorker(c *gin.Context) {
 		}
 
 		curStatus.Scheduled = schedule.NextSchedule
+		s.rwmu.Lock()
 		_, err = s.adapter.UpdateMirrorStatus(workerID, mirrorName, curStatus)
+		s.rwmu.Unlock()
 		if err != nil {
 			err := fmt.Errorf("failed to update job %s of worker %s: %s",
 				mirrorName, workerID, err.Error(),
@@ -295,7 +311,9 @@ func (s *Manager) updateJobOfWorker(c *gin.Context) {
 		)
 	}
 
+	s.rwmu.RLock()
 	curStatus, _ := s.adapter.GetMirrorStatus(workerID, mirrorName)
+	s.rwmu.RUnlock()
 
 	curTime := time.Now()
 
@@ -331,7 +349,9 @@ func (s *Manager) updateJobOfWorker(c *gin.Context) {
 		logger.Noticef("Job [%s] @<%s> %s", status.Name, status.Worker, status.Status)
 	}
 
+	s.rwmu.Lock()
 	newStatus, err := s.adapter.UpdateMirrorStatus(workerID, mirrorName, status)
+	s.rwmu.Unlock()
 	if err != nil {
 		err := fmt.Errorf("failed to update job %s of worker %s: %s",
 			mirrorName, workerID, err.Error(),
@@ -353,7 +373,9 @@ func (s *Manager) updateMirrorSize(c *gin.Context) {
 	c.BindJSON(&msg)
 
 	mirrorName := msg.Name
+	s.rwmu.RLock()
 	status, err := s.adapter.GetMirrorStatus(workerID, mirrorName)
+	s.rwmu.RUnlock()
 	if err != nil {
 		logger.Errorf(
 			"Failed to get status of mirror %s @<%s>: %s",
@@ -370,7 +392,9 @@ func (s *Manager) updateMirrorSize(c *gin.Context) {
 
 	logger.Noticef("Mirror size of [%s] @<%s>: %s", status.Name, status.Worker, status.Size)
 
+	s.rwmu.Lock()
 	newStatus, err := s.adapter.UpdateMirrorStatus(workerID, mirrorName, status)
+	s.rwmu.Unlock()
 	if err != nil {
 		err := fmt.Errorf("failed to update job %s of worker %s: %s",
 			mirrorName, workerID, err.Error(),
@@ -393,7 +417,9 @@ func (s *Manager) handleClientCmd(c *gin.Context) {
 		return
 	}
 
+	s.rwmu.RLock()
 	w, err := s.adapter.GetWorker(workerID)
+	s.rwmu.RUnlock()
 	if err != nil {
 		err := fmt.Errorf("worker %s is not registered yet", workerID)
 		s.returnErrJSON(c, http.StatusBadRequest, err)
@@ -410,7 +436,9 @@ func (s *Manager) handleClientCmd(c *gin.Context) {
 
 	// update job status, even if the job did not disable successfully,
 	// this status should be set as disabled
+	s.rwmu.RLock()
 	curStat, _ := s.adapter.GetMirrorStatus(clientCmd.WorkerID, clientCmd.MirrorID)
+	s.rwmu.RUnlock()
 	changed := false
 	switch clientCmd.Cmd {
 	case CmdDisable:
@@ -421,7 +449,9 @@ func (s *Manager) handleClientCmd(c *gin.Context) {
 		changed = true
 	}
 	if changed {
+		s.rwmu.Lock()
 		s.adapter.UpdateMirrorStatus(clientCmd.WorkerID, clientCmd.MirrorID, curStat)
+		s.rwmu.Unlock()
 	}
 
 	logger.Noticef("Posting command '%s %s' to <%s>", clientCmd.Cmd, clientCmd.MirrorID, clientCmd.WorkerID)

+ 29 - 0
manager/server_test.go

@@ -7,6 +7,7 @@ import (
 	"math/rand"
 	"net/http"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -64,6 +65,34 @@ func TestHTTPServer(t *testing.T) {
 			So(msg[_errorKey], ShouldEqual, fmt.Sprintf("failed to list jobs of worker %s: %s", _magicBadWorkerID, "database fail"))
 		})
 
+		Convey("when register multiple workers", func(ctx C) {
+			N := 10
+			var cnt uint32
+			for i := 0; i < N; i++ {
+				go func(id int) {
+					w := WorkerStatus{
+						ID: fmt.Sprintf("worker%d", id),
+					}
+					resp, err := PostJSON(baseURL+"/workers", w, nil)
+					ctx.So(err, ShouldBeNil)
+					ctx.So(resp.StatusCode, ShouldEqual, http.StatusOK)
+					atomic.AddUint32(&cnt, 1)
+				}(i)
+			}
+			time.Sleep(2 * time.Second)
+			So(cnt, ShouldEqual, N)
+
+			Convey("list all workers", func(ctx C) {
+				resp, err := http.Get(baseURL + "/workers")
+				So(err, ShouldBeNil)
+				defer resp.Body.Close()
+				var actualResponseObj []WorkerStatus
+				err = json.NewDecoder(resp.Body).Decode(&actualResponseObj)
+				So(err, ShouldBeNil)
+				So(len(actualResponseObj), ShouldEqual, N+1)
+			})
+		})
+
 		Convey("when register a worker", func(ctx C) {
 			w := WorkerStatus{
 				ID: "test_worker1",