serve.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. // Copyright 2018-present the CoreDHCP Authors. All rights reserved
  2. // This source code is licensed under the MIT license found in the
  3. // LICENSE file in the root directory of this source tree.
  4. package server
  5. import (
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "golang.org/x/net/ipv4"
  11. "golang.org/x/net/ipv6"
  12. "github.com/coredhcp/coredhcp/config"
  13. "github.com/coredhcp/coredhcp/handler"
  14. "github.com/coredhcp/coredhcp/logger"
  15. "github.com/coredhcp/coredhcp/plugins"
  16. "github.com/insomniacslk/dhcp/dhcpv4/server4"
  17. "github.com/insomniacslk/dhcp/dhcpv6/server6"
  18. )
// log is the package-level logger, obtained from logger.GetLogger under the
// name "server".
var log = logger.GetLogger("server")
// listener6 is a DHCPv6 listener: an IPv6 packet connection together with the
// network interface it is bound to and the handlers assigned to it by Start.
type listener6 struct {
	// PacketConn is the IPv6 packet connection created by listen6.
	*ipv6.PacketConn
	// Interface is filled in by listen6 only when the listen address names a
	// zone; otherwise it is left as the zero value.
	net.Interface
	// handlers is the DHCPv6 handler list set by Start.
	handlers []handler.Handler6
}
// listener4 is a DHCPv4 listener: an IPv4 packet connection together with the
// network interface it is bound to and the handlers assigned to it by Start.
type listener4 struct {
	// PacketConn is the IPv4 packet connection created by listen4.
	*ipv4.PacketConn
	// Interface is filled in by listen4 only when the listen address names a
	// zone; otherwise it is left as the zero value.
	net.Interface
	// handlers is the DHCPv4 handler list set by Start.
	handlers []handler.Handler4
}
// listener is the interface through which Servers holds its listeners; the
// only operation it requires is Close.
type listener interface {
	io.Closer
}
// Servers contains state for a running server (with possibly multiple interfaces/listeners)
type Servers struct {
	// listeners holds every listener opened by Start, DHCPv6 first, then DHCPv4.
	listeners []listener
	// errors receives the value returned by each listener's Serve call.
	errors chan error
}
  38. func listen4(a *net.UDPAddr) (*listener4, error) {
  39. var err error
  40. l4 := listener4{}
  41. udpConn, err := server4.NewIPv4UDPConn(a.Zone, a)
  42. if err != nil {
  43. return nil, err
  44. }
  45. l4.PacketConn = ipv4.NewPacketConn(udpConn)
  46. var ifi *net.Interface
  47. if a.Zone != "" {
  48. ifi, err = net.InterfaceByName(a.Zone)
  49. if err != nil {
  50. return nil, fmt.Errorf("DHCPv4: Listen could not find interface %s: %v", a.Zone, err)
  51. }
  52. l4.Interface = *ifi
  53. } else {
  54. // When not bound to an interface, we need the information in each
  55. // packet to know which interface it came on
  56. err = l4.SetControlMessage(ipv4.FlagInterface, true)
  57. if err != nil {
  58. return nil, err
  59. }
  60. }
  61. if a.IP.IsMulticast() {
  62. err = l4.JoinGroup(ifi, a)
  63. if err != nil {
  64. return nil, err
  65. }
  66. }
  67. return &l4, nil
  68. }
  69. func listen6(a *net.UDPAddr) (*listener6, error) {
  70. l6 := listener6{}
  71. udpconn, err := server6.NewIPv6UDPConn(a.Zone, a)
  72. if err != nil {
  73. return nil, err
  74. }
  75. l6.PacketConn = ipv6.NewPacketConn(udpconn)
  76. var ifi *net.Interface
  77. if a.Zone != "" {
  78. ifi, err = net.InterfaceByName(a.Zone)
  79. if err != nil {
  80. return nil, fmt.Errorf("DHCPv4: Listen could not find interface %s: %v", a.Zone, err)
  81. }
  82. l6.Interface = *ifi
  83. } else {
  84. // When not bound to an interface, we need the information in each
  85. // packet to know which interface it came on
  86. err = l6.SetControlMessage(ipv6.FlagInterface, true)
  87. if err != nil {
  88. return nil, err
  89. }
  90. }
  91. if a.IP.IsMulticast() {
  92. err = l6.JoinGroup(ifi, a)
  93. if err != nil {
  94. return nil, err
  95. }
  96. }
  97. return &l6, nil
  98. }
  99. // Start will start the server asynchronously. See `Wait` to wait until
  100. // the execution ends.
  101. func Start(config *config.Config) (*Servers, error) {
  102. handlers4, handlers6, err := plugins.LoadPlugins(config)
  103. if err != nil {
  104. return nil, err
  105. }
  106. srv := Servers{
  107. errors: make(chan error),
  108. }
  109. // listen
  110. if config.Server6 != nil {
  111. log.Println("Starting DHCPv6 server")
  112. for _, addr := range config.Server6.Addresses {
  113. var l6 *listener6
  114. l6, err = listen6(&addr)
  115. if err != nil {
  116. goto cleanup
  117. }
  118. l6.handlers = handlers6
  119. srv.listeners = append(srv.listeners, l6)
  120. go func() {
  121. srv.errors <- l6.Serve()
  122. }()
  123. }
  124. }
  125. if config.Server4 != nil {
  126. log.Println("Starting DHCPv4 server")
  127. for _, addr := range config.Server4.Addresses {
  128. var l4 *listener4
  129. l4, err = listen4(&addr)
  130. if err != nil {
  131. goto cleanup
  132. }
  133. l4.handlers = handlers4
  134. srv.listeners = append(srv.listeners, l4)
  135. go func() {
  136. srv.errors <- l4.Serve()
  137. }()
  138. }
  139. }
  140. return &srv, nil
  141. cleanup:
  142. srv.Close()
  143. return nil, err
  144. }
  145. // Wait waits until the end of the execution of the server.
  146. func (s *Servers) Wait() error {
  147. log.Debug("Waiting")
  148. errs := make([]error, 1, len(s.listeners))
  149. errs[0] = <-s.errors
  150. s.Close()
  151. // Wait for the other listeners to close
  152. for i := 1; i < len(s.listeners); i++ {
  153. errs = append(errs, <-s.errors)
  154. }
  155. return errors.Join(errs...)
  156. }
  157. // Close closes all listening connections
  158. func (s *Servers) Close() {
  159. for _, srv := range s.listeners {
  160. if srv != nil {
  161. srv.Close()
  162. }
  163. }
  164. }