Quellcode durchsuchen

refresh records when leases file is updated

Signed-off-by: Reinier Schoof <reinier@skoef.nl>
Reinier Schoof vor 4 Jahren
Ursprung
Commit
c780ba84df
4 geänderte Dateien mit 104 neuen und 16 gelöschten Zeilen
  1. 3 1
      cmds/coredhcp/config.yml.example
  2. 1 0
      go.mod
  3. 67 5
      plugins/file/plugin.go
  4. 33 10
      plugins/file/plugin_test.go

+ 3 - 1
cmds/coredhcp/config.yml.example

@@ -73,8 +73,10 @@ server6:
         - server_id: LL 00:de:ad:be:ef:00
 
         # file serves leases defined in a static file, matching link-layer addresses to IPs
-        # - file: <file name>
+        # - file: <file name> [autorefresh]
         # The file format is one lease per line, "<hw address> <IPv6>"
+        # When the 'autorefresh' argument is given, the plugin will try to refresh
+        # the lease mapping during runtime whenever the lease file is updated.
         - file: "leases.txt"
 
         # dns adds information about available DNS resolvers to the responses

+ 1 - 0
go.mod

@@ -4,6 +4,7 @@ go 1.13
 
 require (
 	github.com/chappjc/logrus-prefix v0.0.0-20180227015900-3a1d64819adb
+	github.com/fsnotify/fsnotify v1.4.9 // indirect
 	github.com/google/gopacket v1.1.19
 	github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 // indirect
 	github.com/insomniacslk/dhcp v0.0.0-20210120172423-cc9239ac6294

+ 67 - 5
plugins/file/plugin.go

@@ -18,10 +18,13 @@
 //  server6:
 //     ...
 //     plugins:
-//       - file: "file_leases.txt"
+//       - file: "file_leases.txt" [autorefresh]
 //     ...
 //
 // If the file path is not absolute, it is relative to the cwd where coredhcp is run.
+//
+// Optionally, when the 'autorefresh' argument is given, the plugin will try to refresh
+// the lease mapping during runtime whenever the lease file is updated.
 package file
 
 import (
@@ -31,15 +34,21 @@ import (
 	"io/ioutil"
 	"net"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/coredhcp/coredhcp/handler"
 	"github.com/coredhcp/coredhcp/logger"
 	"github.com/coredhcp/coredhcp/plugins"
+	"github.com/fsnotify/fsnotify"
 	"github.com/insomniacslk/dhcp/dhcpv4"
 	"github.com/insomniacslk/dhcp/dhcpv6"
 )
 
+const (
+	autoRefreshArg = "autorefresh"
+)
+
 var log = logger.GetLogger("plugins/file")
 
 // Plugin wraps plugin registration information
@@ -49,6 +58,8 @@ var Plugin = plugins.Plugin{
 	Setup4: setup4,
 }
 
+var recLock sync.RWMutex
+
 // StaticRecords holds a MAC -> IP address mapping
 var StaticRecords map[string]net.IP
 
@@ -145,6 +156,9 @@ func Handler6(req, resp dhcpv6.DHCPv6) (dhcpv6.DHCPv6, bool) {
 	}
 	log.Debugf("looking up an IP address for MAC %s", mac.String())
 
+	recLock.RLock()
+	defer recLock.RUnlock()
+
 	ipaddr, ok := StaticRecords[mac.String()]
 	if !ok {
 		log.Warningf("MAC address %s is unknown", mac.String())
@@ -167,6 +181,9 @@ func Handler6(req, resp dhcpv6.DHCPv6) (dhcpv6.DHCPv6, bool) {
 
 // Handler4 handles DHCPv4 packets for the file plugin
 func Handler4(req, resp *dhcpv4.DHCPv4) (*dhcpv4.DHCPv4, bool) {
+	recLock.RLock()
+	defer recLock.RUnlock()
+
 	ipaddr, ok := StaticRecords[req.ClientHWAddr.String()]
 	if !ok {
 		log.Warningf("MAC address %s is unknown", req.ClientHWAddr.String())
@@ -189,7 +206,6 @@ func setup4(args ...string) (handler.Handler4, error) {
 
 func setupFile(v6 bool, args ...string) (handler.Handler6, handler.Handler4, error) {
 	var err error
-	var records map[string]net.IP
 	if len(args) < 1 {
 		return nil, nil, errors.New("need a file name")
 	}
@@ -197,6 +213,49 @@ func setupFile(v6 bool, args ...string) (handler.Handler6, handler.Handler4, err
 	if filename == "" {
 		return nil, nil, errors.New("got empty file name")
 	}
+
+	// load initial database from lease file
+	if err = loadFromFile(v6, filename); err != nil {
+		return nil, nil, err
+	}
+
+	// when the 'autorefresh' argument was passed, watch the lease file for
+	// changes and reload the lease mapping on any event
+	if len(args) > 1 && args[1] == autoRefreshArg {
+		// creates a new file watcher
+		watcher, err := fsnotify.NewWatcher()
+		if err != nil {
+			return nil, nil, fmt.Errorf("failed to create watcher: %w", err)
+		}
+
+		// have file watcher watch over lease file
+		if err = watcher.Add(filename); err != nil {
+			return nil, nil, fmt.Errorf("failed to watch %s: %w", filename, err)
+		}
+
+		// very simple watcher on the lease file to trigger a refresh on any event
+		// on the file
+		go func() {
+			for range watcher.Events {
+				err := loadFromFile(v6, filename)
+				if err != nil {
+					log.Warningf("failed to refresh from %s: %s", filename, err)
+
+					continue
+				}
+
+				log.Infof("updated to %d leases from %s", len(StaticRecords), filename)
+			}
+		}()
+	}
+
+	log.Infof("loaded %d leases from %s", len(StaticRecords), filename)
+	return Handler6, Handler4, nil
+}
+
+func loadFromFile(v6 bool, filename string) error {
+	var err error
+	var records map[string]net.IP
 	var protver int
 	if v6 {
 		protver = 6
@@ -206,10 +265,13 @@ func setupFile(v6 bool, args ...string) (handler.Handler6, handler.Handler4, err
 		records, err = LoadDHCPv4Records(filename)
 	}
 	if err != nil {
-		return nil, nil, fmt.Errorf("failed to load DHCPv%d records: %v", protver, err)
+		return fmt.Errorf("failed to load DHCPv%d records: %w", protver, err)
 	}
 
+	recLock.Lock()
+	defer recLock.Unlock()
+
 	StaticRecords = records
-	log.Infof("loaded %d leases from %s", len(records), filename)
-	return Handler6, Handler4, nil
+
+	return nil
 }

+ 33 - 10
plugins/file/plugin_test.go

@@ -9,6 +9,7 @@ import (
 	"net"
 	"os"
 	"testing"
+	"time"
 
 	"github.com/insomniacslk/dhcp/dhcpv4"
 	"github.com/insomniacslk/dhcp/dhcpv6"
@@ -319,16 +320,38 @@ func TestSetupFile(t *testing.T) {
 		os.Remove(tmp.Name())
 	}()
 
-	_, err = tmp.WriteString("00:11:22:33:44:55 2001:db8::10:1\n")
-	require.NoError(t, err)
-	_, err = tmp.WriteString("11:22:33:44:55:66 2001:db8::10:2\n")
-	require.NoError(t, err)
+	t.Run("typical case", func(t *testing.T) {
+		_, err = tmp.WriteString("00:11:22:33:44:55 2001:db8::10:1\n")
+		require.NoError(t, err)
+		_, err = tmp.WriteString("11:22:33:44:55:66 2001:db8::10:2\n")
+		require.NoError(t, err)
+
+		assert.Equal(t, 0, len(StaticRecords))
 
-	assert.Equal(t, 0, len(StaticRecords))
+		// leases should show up in StaticRecords
+		_, _, err = setupFile(true, tmp.Name())
+		if assert.NoError(t, err) {
+			assert.Equal(t, 2, len(StaticRecords))
+		}
+	})
 
-	// leases should show up in StaticRecords
-	_, _, err = setupFile(true, tmp.Name())
-	if assert.NoError(t, err) {
-		assert.Equal(t, 2, len(StaticRecords))
-	}
+	t.Run("autorefresh enabled", func(t *testing.T) {
+		_, _, err = setupFile(true, tmp.Name(), autoRefreshArg)
+		if assert.NoError(t, err) {
+			assert.Equal(t, 2, len(StaticRecords))
+		}
+		// we add more leases to the file
+		// this should trigger an event to refresh the leases database
+		// without calling setupFile again
+		_, err = tmp.WriteString("22:33:44:55:66:77 2001:db8::10:3\n")
+		require.NoError(t, err)
+		// since the event is processed asynchronously, give it a little time
+		time.Sleep(time.Millisecond * 100)
+		// an additional record should show up in the database
+		// but we should respect the locking first
+		recLock.RLock()
+		defer recLock.RUnlock()
+
+		assert.Equal(t, 3, len(StaticRecords))
+	})
 }