Quellcode durchsuchen

tests(manager): add tests for server.go, validate workerID in middleware

walkerning vor 9 Jahren
Ursprung
Commit
401b6a694e
3 geänderte Dateien mit 190 neuen und 42 gelöschten Zeilen
  1. 17 0
      manager/middleware.go
  2. 15 11
      manager/server.go
  3. 158 31
      manager/server_test.go

+ 17 - 0
manager/middleware.go

@@ -1,6 +1,9 @@
 package manager
 
 import (
+	"fmt"
+	"net/http"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -14,3 +17,17 @@ func contextErrorLogger(c *gin.Context) {
 	// pass on to the next middleware in chain
 	c.Next()
 }
+
+func (s *managerServer) workerIDValidator(c *gin.Context) {
+	workerID := c.Param("id")
+	_, err := s.adapter.GetWorker(workerID)
+	if err != nil {
+		// no worker named `workerID` exists
+		err := fmt.Errorf("invalid workerID %s", workerID)
+		s.returnErrJSON(c, http.StatusBadRequest, err)
+		c.Abort()
+		return
+	}
+	// pass on to the next middleware in chain
+	c.Next()
+}

+ 15 - 11
manager/server.go

@@ -20,10 +20,8 @@ const (
 )
 
 type worker struct {
-	// worker name
-	id string
-	// session token
-	token string
+	ID    string `json:"id"`    // worker name
+	Token string `json:"token"` // session token
 }
 
 var (
@@ -64,7 +62,7 @@ func (s *managerServer) listWorkers(c *gin.Context) {
 	}
 	for _, w := range workers {
 		workerInfos = append(workerInfos,
-			WorkerInfoMsg{w.id})
+			WorkerInfoMsg{w.ID})
 	}
 	c.JSON(http.StatusOK, workerInfos)
 }
@@ -85,7 +83,7 @@ func (s *managerServer) registerWorker(c *gin.Context) {
 	// create workerCmd channel for this worker
 	workerChannelMu.Lock()
 	defer workerChannelMu.Unlock()
-	workerChannels[_worker.id] = make(chan WorkerCmd, maxQueuedCmdNum)
+	workerChannels[_worker.ID] = make(chan WorkerCmd, maxQueuedCmdNum)
 	c.JSON(http.StatusOK, newWorker)
 }
 
@@ -200,8 +198,12 @@ func makeHTTPServer(debug bool) *managerServer {
 		gin.Default(),
 		nil,
 	}
+
+	// common log middleware
+	s.Use(contextErrorLogger)
+
 	s.GET("/ping", func(c *gin.Context) {
-		c.JSON(http.StatusOK, gin.H{"msg": "pong"})
+		c.JSON(http.StatusOK, gin.H{_infoKey: "pong"})
 	})
 	// list jobs, status page
 	s.GET("/jobs", s.listAllJobs)
@@ -209,15 +211,17 @@ func makeHTTPServer(debug bool) *managerServer {
 	// list workers
 	s.GET("/workers", s.listWorkers)
 	// worker online
-	s.POST("/workers/:id", s.registerWorker)
+	s.POST("/workers", s.registerWorker)
 
+	// workerID should be valid in this route group
+	workerValidateGroup := s.Group("/workers", s.workerIDValidator)
 	// get job list
-	s.GET("/workers/:id/jobs", s.listJobsOfWorker)
+	workerValidateGroup.GET(":id/jobs", s.listJobsOfWorker)
 	// post job status
-	s.POST("/workers/:id/jobs/:job", s.updateJobOfWorker)
+	workerValidateGroup.POST(":id/jobs/:job", s.updateJobOfWorker)
 
 	// worker command polling
-	s.GET("/workers/:id/cmd_stream", s.getCmdOfWorker)
+	workerValidateGroup.GET(":id/cmd_stream", s.getCmdOfWorker)
 
 	// for tunasynctl to post commands
 	s.POST("/cmd/", s.handleClientCmd)

+ 158 - 31
manager/server_test.go

@@ -1,6 +1,7 @@
 package manager
 
 import (
+	"bytes"
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
@@ -11,8 +12,146 @@ import (
 	"time"
 
 	. "github.com/smartystreets/goconvey/convey"
+	. "github.com/tuna/tunasync/internal"
 )
 
+const (
+	_magicBadWorkerID = "magic_bad_worker_id"
+)
+
+func postJSON(url string, obj interface{}) (*http.Response, error) {
+	b := new(bytes.Buffer)
+	json.NewEncoder(b).Encode(obj)
+	return http.Post(url, "application/json; charset=utf-8", b)
+}
+
+func TestHTTPServer(t *testing.T) {
+	Convey("HTTP server should work", t, func() {
+		InitLogger(true, true, false)
+		s := makeHTTPServer(false)
+		So(s, ShouldNotBeNil)
+		s.setDBAdapter(&mockDBAdapter{
+			workerStore: map[string]worker{
+				_magicBadWorkerID: worker{
+					ID: _magicBadWorkerID,
+				}},
+			statusStore: make(map[string]mirrorStatus),
+		})
+		port := rand.Intn(10000) + 20000
+		baseURL := fmt.Sprintf("http://127.0.0.1:%d", port)
+		go func() {
+			s.Run(fmt.Sprintf("127.0.0.1:%d", port))
+		}()
+		time.Sleep(50 * time.Microsecond)
+		resp, err := http.Get(baseURL + "/ping")
+		So(err, ShouldBeNil)
+		So(resp.StatusCode, ShouldEqual, http.StatusOK)
+		So(resp.Header.Get("Content-Type"), ShouldEqual, "application/json; charset=utf-8")
+		defer resp.Body.Close()
+		body, err := ioutil.ReadAll(resp.Body)
+		So(err, ShouldBeNil)
+		var p map[string]string
+		err = json.Unmarshal(body, &p)
+		So(err, ShouldBeNil)
+		So(p[_infoKey], ShouldEqual, "pong")
+
+		Convey("when database fail", func() {
+			resp, err := http.Get(fmt.Sprintf("%s/workers/%s/jobs", baseURL, _magicBadWorkerID))
+			So(err, ShouldBeNil)
+			So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
+			defer resp.Body.Close()
+			var msg map[string]string
+			err = json.NewDecoder(resp.Body).Decode(&msg)
+			So(err, ShouldBeNil)
+			So(msg[_errorKey], ShouldEqual, fmt.Sprintf("failed to list jobs of worker %s: %s", _magicBadWorkerID, "database fail"))
+		})
+
+		Convey("when register a worker", func() {
+			w := worker{
+				ID: "test_worker1",
+			}
+			resp, err := postJSON(baseURL+"/workers", w)
+			So(err, ShouldBeNil)
+			So(resp.StatusCode, ShouldEqual, http.StatusOK)
+
+			Convey("list all workers", func() {
+				So(err, ShouldBeNil)
+				resp, err := http.Get(baseURL + "/workers")
+				So(err, ShouldBeNil)
+				defer resp.Body.Close()
+				var actualResponseObj []WorkerInfoMsg
+				err = json.NewDecoder(resp.Body).Decode(&actualResponseObj)
+				So(err, ShouldBeNil)
+				So(len(actualResponseObj), ShouldEqual, 2)
+			})
+
+			Convey("update mirror status of a existed worker", func() {
+				status := mirrorStatus{
+					Name:       "arch-sync1",
+					Worker:     "test_worker1",
+					IsMaster:   true,
+					Status:     Success,
+					LastUpdate: time.Now(),
+					Upstream:   "mirrors.tuna.tsinghua.edu.cn",
+					Size:       "3GB",
+				}
+				resp, err := postJSON(fmt.Sprintf("%s/workers/%s/jobs/%s", baseURL, status.Worker, status.Name), status)
+				So(err, ShouldBeNil)
+				So(resp.StatusCode, ShouldEqual, http.StatusOK)
+
+				Convey("list mirror status of an existed worker", func() {
+
+					expectedResponse, err := json.Marshal([]mirrorStatus{status})
+					So(err, ShouldBeNil)
+					resp, err := http.Get(baseURL + "/workers/test_worker1/jobs")
+					So(err, ShouldBeNil)
+					So(resp.StatusCode, ShouldEqual, http.StatusOK)
+					// err = json.NewDecoder(resp.Body).Decode(&mirrorStatusList)
+					body, err := ioutil.ReadAll(resp.Body)
+					defer resp.Body.Close()
+					So(err, ShouldBeNil)
+					So(strings.TrimSpace(string(body)), ShouldEqual, string(expectedResponse))
+				})
+
+				Convey("list all job status of all workers", func() {
+					expectedResponse, err := json.Marshal([]mirrorStatus{status})
+					So(err, ShouldBeNil)
+					resp, err := http.Get(baseURL + "/jobs")
+					So(err, ShouldBeNil)
+					So(resp.StatusCode, ShouldEqual, http.StatusOK)
+					body, err := ioutil.ReadAll(resp.Body)
+					defer resp.Body.Close()
+					So(err, ShouldBeNil)
+					So(strings.TrimSpace(string(body)), ShouldEqual, string(expectedResponse))
+
+				})
+			})
+
+			Convey("update mirror status of an inexisted worker", func() {
+				invalidWorker := "test_worker2"
+				status := mirrorStatus{
+					Name:       "arch-sync2",
+					Worker:     invalidWorker,
+					IsMaster:   true,
+					Status:     Success,
+					LastUpdate: time.Now(),
+					Upstream:   "mirrors.tuna.tsinghua.edu.cn",
+					Size:       "4GB",
+				}
+				resp, err := postJSON(fmt.Sprintf("%s/workers/%s/jobs/%s",
+					baseURL, status.Worker, status.Name), status)
+				So(err, ShouldBeNil)
+				So(resp.StatusCode, ShouldEqual, http.StatusBadRequest)
+				defer resp.Body.Close()
+				var msg map[string]string
+				err = json.NewDecoder(resp.Body).Decode(&msg)
+				So(err, ShouldBeNil)
+				So(msg[_errorKey], ShouldEqual, "invalid workerID "+invalidWorker)
+			})
+		})
+	})
+}
+
 type mockDBAdapter struct {
 	workerStore map[string]worker
 	statusStore map[string]mirrorStatus
@@ -31,23 +170,22 @@ func (b *mockDBAdapter) ListWorkers() ([]worker, error) {
 func (b *mockDBAdapter) GetWorker(workerID string) (worker, error) {
 	w, ok := b.workerStore[workerID]
 	if !ok {
-		return worker{}, fmt.Errorf("inexist workerId")
+		return worker{}, fmt.Errorf("invalid workerId")
 	}
 	return w, nil
 }
 
 func (b *mockDBAdapter) CreateWorker(w worker) (worker, error) {
-	_, ok := b.workerStore[w.id]
-	if ok {
-		return worker{}, fmt.Errorf("duplicate worker name")
-	}
-	b.workerStore[w.id] = w
+	// _, ok := b.workerStore[w.ID]
+	// if ok {
+	// 	return worker{}, fmt.Errorf("duplicate worker name")
+	// }
+	b.workerStore[w.ID] = w
 	return w, nil
 }
 
 func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (mirrorStatus, error) {
-	// TODO: need to check worker exist first
-	id := workerID + "/" + mirrorID
+	id := mirrorID + "/" + workerID
 	status, ok := b.statusStore[id]
 	if !ok {
 		return mirrorStatus{}, fmt.Errorf("no mirror %s exists in worker %s", mirrorID, workerID)
@@ -56,13 +194,22 @@ func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (mirrorStatus
 }
 
 func (b *mockDBAdapter) UpdateMirrorStatus(workerID, mirrorID string, status mirrorStatus) (mirrorStatus, error) {
-	id := workerID + "/" + mirrorID
+	// if _, ok := b.workerStore[workerID]; !ok {
+	// 	// unregistered worker
+	// 	return mirrorStatus{}, fmt.Errorf("invalid workerID %s", workerID)
+	// }
+
+	id := mirrorID + "/" + workerID
 	b.statusStore[id] = status
 	return status, nil
 }
 
 func (b *mockDBAdapter) ListMirrorStatus(workerID string) ([]mirrorStatus, error) {
 	var mirrorStatusList []mirrorStatus
+	// simulating a database fail
+	if workerID == _magicBadWorkerID {
+		return []mirrorStatus{}, fmt.Errorf("database fail")
+	}
 	for k, v := range b.statusStore {
 		if wID := strings.Split(k, "/")[1]; wID == workerID {
 			mirrorStatusList = append(mirrorStatusList, v)
@@ -79,26 +226,6 @@ func (b *mockDBAdapter) ListAllMirrorStatus() ([]mirrorStatus, error) {
 	return mirrorStatusList, nil
 }
 
-func TestHTTPServer(t *testing.T) {
-	Convey("HTTP server should work", t, func() {
-		s := makeHTTPServer(false)
-		So(s, ShouldNotBeNil)
-		port := rand.Intn(10000) + 20000
-		go func() {
-			s.Run(fmt.Sprintf("127.0.0.1:%d", port))
-		}()
-		time.Sleep(50 * time.Microsecond)
-		resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/ping", port))
-		So(err, ShouldBeNil)
-		So(resp.StatusCode, ShouldEqual, http.StatusOK)
-		So(resp.Header.Get("Content-Type"), ShouldEqual, "application/json; charset=utf-8")
-		defer resp.Body.Close()
-		body, err := ioutil.ReadAll(resp.Body)
-		So(err, ShouldBeNil)
-		var p map[string]string
-		err = json.Unmarshal(body, &p)
-		So(err, ShouldBeNil)
-		So(p["msg"], ShouldEqual, "pong")
-	})
-
+func (b *mockDBAdapter) Close() error {
+	return nil
 }