瀏覽代碼

feature(worker): move command execution logic to a runner object

bigeagle 9 年之前
父節點
當前提交
276ab233c5
共有 3 個文件被更改,包括 149 次插入44 次删除
  1. 12 41
      worker/cmd_provider.go
  2. 25 3
      worker/provider_test.go
  3. 112 0
      worker/runner.go

+ 12 - 41
worker/cmd_provider.go

@@ -1,12 +1,10 @@
 package worker
 
 import (
+	"errors"
 	"os"
-	"os/exec"
-	"strings"
 
 	"github.com/anmitsu/go-shlex"
-	"github.com/codeskyblue/go-sh"
 )
 
 type cmdConfig struct {
@@ -21,8 +19,7 @@ type cmdProvider struct {
 	baseProvider
 	cmdConfig
 	command []string
-	cmd     *exec.Cmd
-	session *sh.Session
+	cmd     *cmdJob
 }
 
 func newCmdProvider(c cmdConfig) (*cmdProvider, error) {
@@ -49,39 +46,10 @@ func newCmdProvider(c cmdConfig) (*cmdProvider, error) {
 	return provider, nil
 }
 
-// Copied from go-sh
-func newEnviron(env map[string]string, inherit bool) []string { //map[string]string {
-	environ := make([]string, 0, len(env))
-	if inherit {
-		for _, line := range os.Environ() {
-			// if os environment and env collapses,
-			// omit the os one
-			k := strings.Split(line, "=")[0]
-			if _, ok := env[k]; ok {
-				continue
-			}
-			environ = append(environ, line)
-		}
-	}
-	for k, v := range env {
-		environ = append(environ, k+"="+v)
-	}
-	return environ
+func (p *cmdProvider) InitRunner() {
 }
 
-// TODO: implement this
 func (p *cmdProvider) Run() error {
-	if len(p.command) == 1 {
-		p.cmd = exec.Command(p.command[0])
-	} else if len(p.command) > 1 {
-		c := p.command[0]
-		args := p.command[1:]
-		p.cmd = exec.Command(c, args...)
-	} else if len(p.command) == 0 {
-		panic("Command length should be at least 1!")
-	}
-	p.cmd.Dir = p.WorkingDir()
-
 	env := map[string]string{
 		"TUNASYNC_MIRROR_NAME":  p.Name(),
 		"TUNASYNC_WORKING_DIR":  p.WorkingDir(),
@@ -91,14 +59,14 @@ func (p *cmdProvider) Run() error {
 	for k, v := range p.env {
 		env[k] = v
 	}
-	p.cmd.Env = newEnviron(env, true)
+	p.cmd = newCmdJob(p.command, p.WorkingDir(), env)
 
 	logFile, err := os.OpenFile(p.LogFile(), os.O_WRONLY|os.O_CREATE, 0644)
 	if err != nil {
 		return err
 	}
-	p.cmd.Stdout = logFile
-	p.cmd.Stderr = logFile
+	// defer logFile.Close()
+	p.cmd.SetLogFile(logFile)
 
 	return p.cmd.Start()
 }
@@ -107,9 +75,12 @@ func (p *cmdProvider) Wait() error {
 	return p.cmd.Wait()
 }
 
-// TODO: implement this
-func (p *cmdProvider) Terminate() {
-
+func (p *cmdProvider) Terminate() error {
+	if p.cmd == nil {
+		return errors.New("provider command job not initialized")
+	}
+	err := p.cmd.Terminate()
+	return err
 }
 
 // TODO: implement this

+ 25 - 3
worker/provider_test.go

@@ -6,6 +6,7 @@ import (
 	"os"
 	"path/filepath"
 	"testing"
+	"time"
 
 	. "github.com/smartystreets/goconvey/convey"
 )
@@ -64,7 +65,7 @@ func TestRsyncProvider(t *testing.T) {
 }
 
 func TestCmdProvider(t *testing.T) {
-	Convey("Command Provider should work", t, func() {
+	Convey("Command Provider should work", t, func(ctx C) {
 		tmpDir, err := ioutil.TempDir("", "tunasync")
 		defer os.RemoveAll(tmpDir)
 		So(err, ShouldBeNil)
@@ -112,7 +113,7 @@ echo $TUNASYNC_LOG_FILE
 
 			err = provider.Run()
 			So(err, ShouldBeNil)
-			err = provider.cmd.Wait()
+			err = provider.Wait()
 			So(err, ShouldBeNil)
 
 			loggedContent, err := ioutil.ReadFile(provider.LogFile())
@@ -130,9 +131,30 @@ echo $TUNASYNC_LOG_FILE
 
 			err = provider.Run()
 			So(err, ShouldBeNil)
-			err = provider.cmd.Wait()
+			err = provider.Wait()
 			So(err, ShouldNotBeNil)
 
 		})
+
+		Convey("If a long job is killed", func(ctx C) {
+			scriptContent := `#!/bin/bash
+sleep 5
+			`
+			err = ioutil.WriteFile(scriptFile, []byte(scriptContent), 0755)
+			So(err, ShouldBeNil)
+
+			err = provider.Run()
+			So(err, ShouldBeNil)
+
+			go func() {
+				err = provider.Wait()
+				ctx.So(err, ShouldNotBeNil)
+			}()
+
+			time.Sleep(2)
+			err = provider.Terminate()
+			So(err, ShouldBeNil)
+
+		})
 	})
 }

+ 112 - 0
worker/runner.go

@@ -0,0 +1,112 @@
+package worker
+
+import (
+	"errors"
+	"os"
+	"os/exec"
+	"strings"
+	"syscall"
+	"time"
+
+	"golang.org/x/sys/unix"
+)
+
+// runner is to run os commands giving command line, env and log file
+// it's an alternative to python-sh or go-sh
+// TODO: cgroup excution
+
+type cmdJob struct {
+	cmd        *exec.Cmd
+	workingDir string
+	env        map[string]string
+	logFile    *os.File
+	finished   chan struct{}
+}
+
+func newCmdJob(cmdAndArgs []string, workingDir string, env map[string]string) *cmdJob {
+	var cmd *exec.Cmd
+	if len(cmdAndArgs) == 1 {
+		cmd = exec.Command(cmdAndArgs[0])
+	} else if len(cmdAndArgs) > 1 {
+		c := cmdAndArgs[0]
+		args := cmdAndArgs[1:]
+		cmd = exec.Command(c, args...)
+	} else if len(cmdAndArgs) == 0 {
+		panic("Command length should be at least 1!")
+	}
+
+	cmd.Dir = workingDir
+	cmd.Env = newEnviron(env, true)
+
+	return &cmdJob{
+		cmd:        cmd,
+		workingDir: workingDir,
+		env:        env,
+	}
+}
+
+// start job and wait
+func (c *cmdJob) Run() error {
+	err := c.cmd.Start()
+	if err != nil {
+		return err
+	}
+	return c.Wait()
+}
+
+func (c *cmdJob) Start() error {
+	c.finished = make(chan struct{}, 1)
+	return c.cmd.Start()
+}
+
+func (c *cmdJob) Wait() error {
+	err := c.cmd.Wait()
+	c.finished <- struct{}{}
+	return err
+}
+
+func (c *cmdJob) SetLogFile(logFile *os.File) {
+	c.cmd.Stdout = logFile
+	c.cmd.Stderr = logFile
+}
+
+func (c *cmdJob) Terminate() error {
+	if c.cmd == nil {
+		return errors.New("Command not initialized")
+	}
+	if c.cmd.Process == nil {
+		return errors.New("No Process Running")
+	}
+	err := unix.Kill(c.cmd.Process.Pid, syscall.SIGTERM)
+	if err != nil {
+		return err
+	}
+
+	select {
+	case <-time.After(2 * time.Second):
+		unix.Kill(c.cmd.Process.Pid, syscall.SIGKILL)
+		return errors.New("SIGTERM failed to kill the job")
+	case <-c.finished:
+		return nil
+	}
+}
+
+// Copied from go-sh
+func newEnviron(env map[string]string, inherit bool) []string { //map[string]string {
+	environ := make([]string, 0, len(env))
+	if inherit {
+		for _, line := range os.Environ() {
+			// if os environment and env collapses,
+			// omit the os one
+			k := strings.Split(line, "=")[0]
+			if _, ok := env[k]; ok {
+				continue
+			}
+			environ = append(environ, line)
+		}
+	}
+	for k, v := range env {
+		environ = append(environ, k+"="+v)
+	}
+	return environ
+}