cgroup.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package worker
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "path/filepath"
  8. "strconv"
  9. "syscall"
  10. "time"
  11. "golang.org/x/sys/unix"
  12. "github.com/codeskyblue/go-sh"
  13. )
  14. type cgroupHook struct {
  15. emptyHook
  16. provider mirrorProvider
  17. basePath string
  18. baseGroup string
  19. created bool
  20. subsystem string
  21. memLimit string
  22. }
  23. func newCgroupHook(p mirrorProvider, basePath, baseGroup, subsystem, memLimit string) *cgroupHook {
  24. if basePath == "" {
  25. basePath = "/sys/fs/cgroup"
  26. }
  27. if baseGroup == "" {
  28. baseGroup = "tunasync"
  29. }
  30. if subsystem == "" {
  31. subsystem = "cpu"
  32. }
  33. return &cgroupHook{
  34. provider: p,
  35. basePath: basePath,
  36. baseGroup: baseGroup,
  37. subsystem: subsystem,
  38. }
  39. }
  40. func (c *cgroupHook) preExec() error {
  41. c.created = true
  42. if err := sh.Command("cgcreate", "-g", c.Cgroup()).Run(); err != nil {
  43. return err
  44. }
  45. if c.subsystem != "memory" {
  46. return nil
  47. }
  48. if c.memLimit != "" {
  49. gname := fmt.Sprintf("%s/%s", c.baseGroup, c.provider.Name())
  50. return sh.Command(
  51. "cgset", "-r",
  52. fmt.Sprintf("memory.limit_in_bytes=%s", c.memLimit),
  53. gname,
  54. ).Run()
  55. }
  56. return nil
  57. }
  58. func (c *cgroupHook) postExec() error {
  59. err := c.killAll()
  60. if err != nil {
  61. logger.Errorf("Error killing tasks: %s", err.Error())
  62. }
  63. c.created = false
  64. return sh.Command("cgdelete", c.Cgroup()).Run()
  65. }
  66. func (c *cgroupHook) Cgroup() string {
  67. name := c.provider.Name()
  68. return fmt.Sprintf("%s:%s/%s", c.subsystem, c.baseGroup, name)
  69. }
  70. func (c *cgroupHook) killAll() error {
  71. if !c.created {
  72. return nil
  73. }
  74. name := c.provider.Name()
  75. readTaskList := func() ([]int, error) {
  76. taskList := []int{}
  77. taskFile, err := os.Open(filepath.Join(c.basePath, c.subsystem, c.baseGroup, name, "tasks"))
  78. if err != nil {
  79. return taskList, err
  80. }
  81. defer taskFile.Close()
  82. scanner := bufio.NewScanner(taskFile)
  83. for scanner.Scan() {
  84. pid, err := strconv.Atoi(scanner.Text())
  85. if err != nil {
  86. return taskList, err
  87. }
  88. taskList = append(taskList, pid)
  89. }
  90. return taskList, nil
  91. }
  92. for i := 0; i < 4; i++ {
  93. if i == 3 {
  94. return errors.New("Unable to kill all child tasks")
  95. }
  96. taskList, err := readTaskList()
  97. if err != nil {
  98. return err
  99. }
  100. if len(taskList) == 0 {
  101. return nil
  102. }
  103. for _, pid := range taskList {
  104. // TODO: deal with defunct processes
  105. logger.Debugf("Killing process: %d", pid)
  106. unix.Kill(pid, syscall.SIGKILL)
  107. }
  108. // sleep 10ms for the first round, and 1.01s, 2.01s, 3.01s for the rest
  109. time.Sleep(time.Duration(i)*time.Second + 10*time.Millisecond)
  110. }
  111. return nil
  112. }