cgroup.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package worker
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "path/filepath"
  8. "strconv"
  9. "syscall"
  10. "time"
  11. "golang.org/x/sys/unix"
  12. "github.com/codeskyblue/go-sh"
  13. )
  14. type cgroupHook struct {
  15. emptyHook
  16. basePath string
  17. baseGroup string
  18. created bool
  19. subsystem string
  20. memLimit MemBytes
  21. }
  22. func newCgroupHook(p mirrorProvider, basePath, baseGroup, subsystem string, memLimit MemBytes) *cgroupHook {
  23. if basePath == "" {
  24. basePath = "/sys/fs/cgroup"
  25. }
  26. if baseGroup == "" {
  27. baseGroup = "tunasync"
  28. }
  29. if subsystem == "" {
  30. subsystem = "cpu"
  31. }
  32. return &cgroupHook{
  33. emptyHook: emptyHook{
  34. provider: p,
  35. },
  36. basePath: basePath,
  37. baseGroup: baseGroup,
  38. subsystem: subsystem,
  39. }
  40. }
  41. func (c *cgroupHook) preExec() error {
  42. c.created = true
  43. if err := sh.Command("cgcreate", "-g", c.Cgroup()).Run(); err != nil {
  44. return err
  45. }
  46. if c.subsystem != "memory" {
  47. return nil
  48. }
  49. if c.memLimit != 0 {
  50. gname := fmt.Sprintf("%s/%s", c.baseGroup, c.provider.Name())
  51. return sh.Command(
  52. "cgset", "-r",
  53. fmt.Sprintf("memory.limit_in_bytes=%d", c.memLimit.Value()),
  54. gname,
  55. ).Run()
  56. }
  57. return nil
  58. }
  59. func (c *cgroupHook) postExec() error {
  60. err := c.killAll()
  61. if err != nil {
  62. logger.Errorf("Error killing tasks: %s", err.Error())
  63. }
  64. c.created = false
  65. return sh.Command("cgdelete", c.Cgroup()).Run()
  66. }
  67. func (c *cgroupHook) Cgroup() string {
  68. name := c.provider.Name()
  69. return fmt.Sprintf("%s:%s/%s", c.subsystem, c.baseGroup, name)
  70. }
  71. func (c *cgroupHook) killAll() error {
  72. if !c.created {
  73. return nil
  74. }
  75. name := c.provider.Name()
  76. readTaskList := func() ([]int, error) {
  77. taskList := []int{}
  78. taskFile, err := os.Open(filepath.Join(c.basePath, c.subsystem, c.baseGroup, name, "tasks"))
  79. if err != nil {
  80. return taskList, err
  81. }
  82. defer taskFile.Close()
  83. scanner := bufio.NewScanner(taskFile)
  84. for scanner.Scan() {
  85. pid, err := strconv.Atoi(scanner.Text())
  86. if err != nil {
  87. return taskList, err
  88. }
  89. taskList = append(taskList, pid)
  90. }
  91. return taskList, nil
  92. }
  93. for i := 0; i < 4; i++ {
  94. if i == 3 {
  95. return errors.New("Unable to kill all child tasks")
  96. }
  97. taskList, err := readTaskList()
  98. if err != nil {
  99. return err
  100. }
  101. if len(taskList) == 0 {
  102. return nil
  103. }
  104. for _, pid := range taskList {
  105. // TODO: deal with defunct processes
  106. logger.Debugf("Killing process: %d", pid)
  107. unix.Kill(pid, syscall.SIGKILL)
  108. }
  109. // sleep 10ms for the first round, and 1.01s, 2.01s, 3.01s for the rest
  110. time.Sleep(time.Duration(i)*time.Second + 10*time.Millisecond)
  111. }
  112. return nil
  113. }