瀏覽代碼

test: add RWLock to fix data race in MockDBAdapter

Signed-off-by: Shengqi Chen <harry-chen@outlook.com>
Shengqi Chen 8 月之前
父節點
當前提交
15e87a5f48
共有 1 個文件被更改,包括 20 次插入0 次删除
  1. 20 0
      manager/server_test.go

+ 20 - 0
manager/server_test.go

@@ -7,6 +7,7 @@ import (
 	"math/rand"
 	"net/http"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -424,6 +425,8 @@ func TestHTTPServer(t *testing.T) {
 type mockDBAdapter struct {
 	workerStore map[string]WorkerStatus
 	statusStore map[string]MirrorStatus
+	workerLock  sync.RWMutex
+	statusLock  sync.RWMutex
 }
 
 func (b *mockDBAdapter) Init() error {
@@ -431,17 +434,22 @@ func (b *mockDBAdapter) Init() error {
 }
 
 func (b *mockDBAdapter) ListWorkers() ([]WorkerStatus, error) {
+	b.workerLock.RLock()
 	workers := make([]WorkerStatus, len(b.workerStore))
 	idx := 0
 	for _, w := range b.workerStore {
 		workers[idx] = w
 		idx++
 	}
+	b.workerLock.RUnlock()
 	return workers, nil
 }
 
 func (b *mockDBAdapter) GetWorker(workerID string) (WorkerStatus, error) {
+	b.workerLock.RLock()
+	defer b.workerLock.RUnlock()
 	w, ok := b.workerStore[workerID]
+
 	if !ok {
 		return WorkerStatus{}, fmt.Errorf("invalid workerId")
 	}
@@ -449,7 +457,9 @@ func (b *mockDBAdapter) GetWorker(workerID string) (WorkerStatus, error) {
 }
 
 func (b *mockDBAdapter) DeleteWorker(workerID string) error {
+	b.workerLock.Lock()
 	delete(b.workerStore, workerID)
+	b.workerLock.Unlock()
 	return nil
 }
 
@@ -458,7 +468,9 @@ func (b *mockDBAdapter) CreateWorker(w WorkerStatus) (WorkerStatus, error) {
 	// if ok {
 	// 	return workerStatus{}, fmt.Errorf("duplicate worker name")
 	// }
+	b.workerLock.Lock()
 	b.workerStore[w.ID] = w
+	b.workerLock.Unlock()
 	return w, nil
 }
 
@@ -473,7 +485,9 @@ func (b *mockDBAdapter) RefreshWorker(workerID string) (w WorkerStatus, err erro
 
 func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (MirrorStatus, error) {
 	id := mirrorID + "/" + workerID
+	b.statusLock.RLock()
 	status, ok := b.statusStore[id]
+	b.statusLock.RUnlock()
 	if !ok {
 		return MirrorStatus{}, fmt.Errorf("no mirror %s exists in worker %s", mirrorID, workerID)
 	}
@@ -487,7 +501,9 @@ func (b *mockDBAdapter) UpdateMirrorStatus(workerID, mirrorID string, status Mir
 	// }
 
 	id := mirrorID + "/" + workerID
+	b.statusLock.Lock()
 	b.statusStore[id] = status
+	b.statusLock.Unlock()
 	return status, nil
 }
 
@@ -497,19 +513,23 @@ func (b *mockDBAdapter) ListMirrorStatus(workerID string) ([]MirrorStatus, error
 	if workerID == _magicBadWorkerID {
 		return []MirrorStatus{}, fmt.Errorf("database fail")
 	}
+	b.statusLock.RLock()
 	for k, v := range b.statusStore {
 		if wID := strings.Split(k, "/")[1]; wID == workerID {
 			mirrorStatusList = append(mirrorStatusList, v)
 		}
 	}
+	b.statusLock.RUnlock()
 	return mirrorStatusList, nil
 }
 
 func (b *mockDBAdapter) ListAllMirrorStatus() ([]MirrorStatus, error) {
 	var mirrorStatusList []MirrorStatus
+	b.statusLock.RLock()
 	for _, v := range b.statusStore {
 		mirrorStatusList = append(mirrorStatusList, v)
 	}
+	b.statusLock.RUnlock()
 	return mirrorStatusList, nil
 }