plugin.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. // Copyright 2018-present the CoreDHCP Authors. All rights reserved
  2. // This source code is licensed under the MIT license found in the
  3. // LICENSE file in the root directory of this source tree.
  4. package rangeplugin
  5. import (
  6. "bufio"
  7. "encoding/binary"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "math/rand"
  12. "net"
  13. "os"
  14. "strings"
  15. "time"
  16. "github.com/coredhcp/coredhcp/handler"
  17. "github.com/coredhcp/coredhcp/logger"
  18. "github.com/coredhcp/coredhcp/plugins"
  19. "github.com/insomniacslk/dhcp/dhcpv4"
  20. "github.com/insomniacslk/dhcp/dhcpv6"
  21. )
  22. var log = logger.GetLogger("plugins/range")
  23. func init() {
  24. plugins.RegisterPlugin("range", setupRange6, setupRange4)
  25. }
  26. //Record holds an IP lease record
  27. type Record struct {
  28. IP net.IP
  29. expires time.Time
  30. }
  31. // various global variables
  32. var (
  33. // Recordsv4 holds a MAC -> IP address and lease time mapping
  34. Recordsv4 map[string]*Record
  35. Recordsv6 map[string]*Record
  36. LeaseTime time.Duration
  37. filename string
  38. ipRangeStart net.IP
  39. ipRangeEnd net.IP
  40. )
  41. // loadRecords loads the DHCPv6/v4 Records global map with records stored on
  42. // the specified file. The records have to be one per line, a mac address and an
  43. // IP address.
  44. func loadRecords(r io.Reader, v6 bool) (map[string]*Record, error) {
  45. sc := bufio.NewScanner(r)
  46. records := make(map[string]*Record)
  47. for sc.Scan() {
  48. line := sc.Text()
  49. if len(line) == 0 {
  50. continue
  51. }
  52. tokens := strings.Fields(line)
  53. if len(tokens) != 3 {
  54. return nil, fmt.Errorf("malformed line, want 3 fields, got %d: %s", len(tokens), line)
  55. }
  56. hwaddr, err := net.ParseMAC(tokens[0])
  57. if err != nil {
  58. return nil, fmt.Errorf("malformed hardware address: %s", tokens[0])
  59. }
  60. ipaddr := net.ParseIP(tokens[1])
  61. if v6 {
  62. if len(ipaddr) == net.IPv6len {
  63. return nil, fmt.Errorf("expected an IPv6 address, got: %v", ipaddr)
  64. }
  65. } else {
  66. if ipaddr.To4() == nil {
  67. return nil, fmt.Errorf("expected an IPv4 address, got: %v", ipaddr)
  68. }
  69. }
  70. expires, err := time.Parse(time.RFC3339, tokens[2])
  71. if err != nil {
  72. return nil, fmt.Errorf("expected time of exipry in RFC3339 format, got: %v", tokens[2])
  73. }
  74. records[hwaddr.String()] = &Record{IP: ipaddr, expires: expires}
  75. }
  76. return records, nil
  77. }
  78. // Handler6 handles DHCPv6 packets for the file plugin
  79. func Handler6(req, resp dhcpv6.DHCPv6) (dhcpv6.DHCPv6, bool) {
  80. // TODO add IPv6 netmask to the response
  81. return resp, false
  82. }
  83. // Handler4 handles DHCPv4 packets for the range plugin
  84. func Handler4(req, resp *dhcpv4.DHCPv4) (*dhcpv4.DHCPv4, bool) {
  85. record, ok := Recordsv4[req.ClientHWAddr.String()]
  86. if !ok {
  87. log.Printf("MAC address %s is new, leasing new IPv4 address", req.ClientHWAddr.String())
  88. rec, err := createIP(ipRangeStart, ipRangeEnd)
  89. if err != nil {
  90. log.Error(err)
  91. return nil, true
  92. }
  93. err = saveIPAddress(req.ClientHWAddr, rec)
  94. if err != nil {
  95. log.Printf("SaveIPAddress for MAC %s failed: %v", req.ClientHWAddr.String(), err)
  96. }
  97. Recordsv4[req.ClientHWAddr.String()] = rec
  98. record = rec
  99. }
  100. resp.YourIPAddr = record.IP
  101. resp.Options.Update(dhcpv4.OptIPAddressLeaseTime(LeaseTime))
  102. log.Printf("found IP address %s for MAC %s", record.IP, req.ClientHWAddr.String())
  103. return resp, false
  104. }
  105. func setupRange6(args ...string) (handler.Handler6, error) {
  106. // TODO setup function for IPv6
  107. log.Warning("not implemented for IPv6")
  108. return Handler6, nil
  109. }
  110. func setupRange4(args ...string) (handler.Handler4, error) {
  111. _, h4, err := setupRange(false, args...)
  112. return h4, err
  113. }
  114. func setupRange(v6 bool, args ...string) (handler.Handler6, handler.Handler4, error) {
  115. var err error
  116. if len(args) < 4 {
  117. return nil, nil, fmt.Errorf("invalid number of arguments, want: 4 (file name, start IP, end IP, lease time), got: %d", len(args))
  118. }
  119. filename = args[0]
  120. if filename == "" {
  121. return nil, nil, errors.New("file name cannot be empty")
  122. }
  123. ipRangeStart = net.ParseIP(args[1])
  124. if ipRangeStart.To4() == nil {
  125. return nil, nil, fmt.Errorf("invalid IPv4 address: %v", args[1])
  126. }
  127. ipRangeEnd = net.ParseIP(args[2])
  128. if ipRangeEnd.To4() == nil {
  129. return nil, nil, fmt.Errorf("invalid IPv4 address: %v", args[2])
  130. }
  131. if binary.BigEndian.Uint32(ipRangeStart.To4()) >= binary.BigEndian.Uint32(ipRangeEnd.To4()) {
  132. return nil, nil, errors.New("start of IP range has to be lower than the end of an IP range")
  133. }
  134. LeaseTime, err = time.ParseDuration(args[3])
  135. if err != nil {
  136. return Handler6, Handler4, fmt.Errorf("invalid duration: %v", args[3])
  137. }
  138. r, err := os.Open(filename)
  139. defer func() {
  140. if err := r.Close(); err != nil {
  141. log.Warningf("Failed to close file %s: %v", filename, err)
  142. }
  143. }()
  144. if err != nil {
  145. return nil, nil, fmt.Errorf("cannot open lease file %s: %v", filename, err)
  146. }
  147. if v6 {
  148. Recordsv6, err = loadRecords(r, true)
  149. } else {
  150. Recordsv4, err = loadRecords(r, false)
  151. }
  152. if err != nil {
  153. return nil, nil, fmt.Errorf("failed to load records: %v", err)
  154. }
  155. rand.Seed(time.Now().Unix())
  156. if v6 {
  157. log.Printf("Loaded %d DHCPv6 leases from %s", len(Recordsv6), filename)
  158. } else {
  159. log.Printf("Loaded %d DHCPv4 leases from %s", len(Recordsv4), filename)
  160. }
  161. return Handler6, Handler4, nil
  162. }
  163. // createIP allocates a new lease in the provided range.
  164. // TODO this is not concurrency-safe
  165. func createIP(rangeStart net.IP, rangeEnd net.IP) (*Record, error) {
  166. ip := make([]byte, 4)
  167. rangeStartInt := binary.BigEndian.Uint32(rangeStart.To4())
  168. rangeEndInt := binary.BigEndian.Uint32(rangeEnd.To4())
  169. binary.BigEndian.PutUint32(ip, random(rangeStartInt, rangeEndInt))
  170. taken := checkIfTaken(ip)
  171. for taken {
  172. ipInt := binary.BigEndian.Uint32(ip)
  173. ipInt++
  174. binary.BigEndian.PutUint32(ip, ipInt)
  175. if ipInt > rangeEndInt {
  176. break
  177. }
  178. taken = checkIfTaken(ip)
  179. }
  180. for taken {
  181. ipInt := binary.BigEndian.Uint32(ip)
  182. ipInt--
  183. binary.BigEndian.PutUint32(ip, ipInt)
  184. if ipInt < rangeStartInt {
  185. return &Record{}, errors.New("no new IP addresses available")
  186. }
  187. taken = checkIfTaken(ip)
  188. }
  189. return &Record{IP: ip, expires: time.Now().Add(LeaseTime)}, nil
  190. }
  191. func random(min uint32, max uint32) uint32 {
  192. return uint32(rand.Intn(int(max-min))) + min
  193. }
  194. // check if an IP address is already leased. DHCPv4 only.
  195. func checkIfTaken(ip net.IP) bool {
  196. taken := false
  197. for _, v := range Recordsv4 {
  198. if v.IP.String() == ip.String() && (v.expires.After(time.Now())) {
  199. taken = true
  200. break
  201. }
  202. }
  203. return taken
  204. }
  205. func saveIPAddress(mac net.HardwareAddr, record *Record) error {
  206. f, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
  207. if err != nil {
  208. return err
  209. }
  210. defer f.Close()
  211. _, err = f.WriteString(mac.String() + " " + record.IP.String() + " " + record.expires.Format(time.RFC3339) + "\n")
  212. if err != nil {
  213. return err
  214. }
  215. err = f.Sync()
  216. if err != nil {
  217. return err
  218. }
  219. return nil
  220. }