Przeglądaj źródła

Add DHCP RELEASE message support (#266)

Implement DHCP RELEASE message handling in the range plugin

Signed-off-by: Nikita Vakula <programmistov.programmist@gmail.com>
Nikita Vakula 1 miesiąc temu
rodzic
commit
da62c7b1bd

+ 1 - 0
go.mod

@@ -35,6 +35,7 @@ require (
 	github.com/sagikazarmark/locafero v0.7.0 // indirect
 	github.com/sourcegraph/conc v0.3.0 // indirect
 	github.com/spf13/afero v1.12.0 // indirect
+	github.com/stretchr/objx v0.5.2 // indirect
 	github.com/subosito/gotenv v1.6.0 // indirect
 	github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect
 	github.com/x-cray/logrus-prefixed-formatter v0.5.2 // indirect

+ 25 - 0
plugins/range/plugin.go

@@ -53,6 +53,11 @@ func (p *PluginState) Handler4(req, resp *dhcpv4.DHCPv4) (*dhcpv4.DHCPv4, bool)
 	defer p.Unlock()
 	record, ok := p.Recordsv4[req.ClientHWAddr.String()]
 	hostname := req.HostName()
+
+	if ok && req.MessageType() == dhcpv4.MessageTypeRelease {
+		return p.handleRelease(req, resp, record)
+	}
+
 	if !ok {
 		// Allocating new address since there isn't one allocated
 		log.Printf("MAC address %s is new, leasing new IPv4 address", req.ClientHWAddr.String())
@@ -90,6 +95,26 @@ func (p *PluginState) Handler4(req, resp *dhcpv4.DHCPv4) (*dhcpv4.DHCPv4, bool)
 	return resp, false
 }
 
+func (p *PluginState) handleRelease(req, _ *dhcpv4.DHCPv4, record *Record) (*dhcpv4.DHCPv4, bool) {
+	// Remove lease from storage
+	if freeErr := p.freeIPAddress(req.ClientHWAddr, record); freeErr != nil {
+		log.Errorf("Could not remove lease from storage for MAC %s: %v", req.ClientHWAddr.String(), freeErr)
+		return nil, true
+	}
+
+	// Remove from in-memory map
+	delete(p.Recordsv4, req.ClientHWAddr.String())
+
+	// Release the IP address from allocator
+	if freeErr := p.allocator.Free(net.IPNet{IP: record.IP}); freeErr != nil {
+		log.Errorf("Could not free IP %s for MAC %s: %v", record.IP.String(), req.ClientHWAddr.String(), freeErr)
+		return nil, true
+	}
+
+	log.Printf("Released IP address %s for MAC %s", record.IP.String(), req.ClientHWAddr.String())
+	return nil, true
+}
+
 func setupRange(args ...string) (handler.Handler4, error) {
 	var (
 		err error

+ 192 - 0
plugins/range/plugin_test.go

@@ -0,0 +1,192 @@
+package rangeplugin
+
+import (
+	"fmt"
+	"net"
+	"testing"
+
+	"github.com/insomniacslk/dhcp/dhcpv4"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+)
+
+// mockAllocator is a simple mock for testing
+type mockAllocator struct {
+	mock.Mock
+}
+
+func (m *mockAllocator) Allocate(hint net.IPNet) (net.IPNet, error) {
+	return m.Called(hint).Get(0).(net.IPNet), nil
+}
+
+func (m *mockAllocator) Free(ip net.IPNet) error {
+	m.Called(ip)
+	return nil
+}
+
+type mockFailingAllocator struct {
+	mock.Mock
+}
+
+func (m *mockFailingAllocator) Allocate(hint net.IPNet) (net.IPNet, error) {
+	args := m.Called(hint)
+	return args.Get(0).(net.IPNet), args.Error(1)
+}
+
+func (m *mockFailingAllocator) Free(ip net.IPNet) error {
+	args := m.Called(ip)
+	return args.Error(0)
+}
+
+func TestHandler4Release(t *testing.T) {
+	db, dbErr := testDBSetup()
+	if dbErr != nil {
+		t.Fatalf("Failed to set up test DB: %v", dbErr)
+	}
+
+	mockAlloc := &mockAllocator{}
+
+	pl := PluginState{
+		leasedb:   db,
+		Recordsv4: make(map[string]*Record),
+		allocator: mockAlloc,
+	}
+
+	loadedRecords, loadErr := loadRecords(db)
+	if loadErr != nil {
+		t.Fatalf("Failed to load records: %v", loadErr)
+	}
+	pl.Recordsv4 = loadedRecords
+
+	// Create a DHCP RELEASE request using existing test data
+	hwaddr, _ := net.ParseMAC(records[1].mac)
+	req := &dhcpv4.DHCPv4{
+		ClientHWAddr: hwaddr,
+	}
+	req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeRelease))
+
+	resp := &dhcpv4.DHCPv4{}
+
+	// Verify record exists before release
+	record, exists := pl.Recordsv4[hwaddr.String()]
+	assert.True(t, exists, "Record should exist before release")
+
+	expectedIPNet := net.IPNet{IP: record.IP}
+	mockAlloc.On("Free", expectedIPNet).Return(nil)
+
+	// Call Handler4 with RELEASE message
+	result, stop := pl.Handler4(req, resp)
+
+	assert.Nil(t, result, "Should return nil response for RELEASE")
+	assert.True(t, stop, "Should return true to stop processing")
+
+	_, exists = pl.Recordsv4[hwaddr.String()]
+	assert.False(t, exists, "Record should be removed from memory after release")
+
+	parsedRecords, parseErr := loadRecords(pl.leasedb)
+	if parseErr != nil {
+		t.Fatalf("Failed to load records after release: %v", parseErr)
+	}
+	_, exists = parsedRecords[hwaddr.String()]
+	assert.False(t, exists, "Record should be removed from storage after release")
+
+	mockAlloc.AssertExpectations(t)
+	mockAlloc.AssertNotCalled(t, "Allocate")
+}
+
+func TestHandler4ReleaseAllocatorError(t *testing.T) {
+	db, parseErr := testDBSetup()
+	if parseErr != nil {
+		t.Fatalf("Failed to set up test DB: %v", parseErr)
+	}
+
+	mockAlloc := &mockFailingAllocator{}
+
+	pl := PluginState{
+		leasedb:   db,
+		Recordsv4: make(map[string]*Record),
+		allocator: mockAlloc,
+	}
+
+	loadedRecords, err := loadRecords(db)
+	if err != nil {
+		t.Fatalf("Failed to load records: %v", err)
+	}
+	pl.Recordsv4 = loadedRecords
+
+	hwaddr, _ := net.ParseMAC(records[1].mac)
+	req := &dhcpv4.DHCPv4{
+		ClientHWAddr: hwaddr,
+	}
+	req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeRelease))
+
+	resp := &dhcpv4.DHCPv4{}
+
+	record := pl.Recordsv4[hwaddr.String()]
+	expectedIPNet := net.IPNet{IP: record.IP}
+
+	expectedError := fmt.Errorf("mock allocator free failure")
+	mockAlloc.On("Free", expectedIPNet).Return(expectedError)
+
+	// Call Handler4 - this should fail on allocator.Free()
+	result, stop := pl.Handler4(req, resp)
+
+	assert.Nil(t, result, "Should return nil on allocator failure")
+	assert.True(t, stop, "Should stop processing on allocator failure")
+
+	_, exists := pl.Recordsv4[hwaddr.String()]
+	assert.False(t, exists, "Record should be removed from memory even on allocator failure")
+
+	parsedRecords, parseErr := loadRecords(pl.leasedb)
+	if parseErr != nil {
+		t.Fatalf("Failed to load records after release: %v", parseErr)
+	}
+	_, exists = parsedRecords[hwaddr.String()]
+	assert.False(t, exists, "Record should be removed from storage even on allocator failure")
+
+	mockAlloc.AssertExpectations(t)
+	mockAlloc.AssertNotCalled(t, "Allocate")
+}
+
+func TestHandler4ReleaseStorageError(t *testing.T) {
+	db, parseErr := testDBSetup()
+	if parseErr != nil {
+		t.Fatalf("Failed to set up test DB: %v", parseErr)
+	}
+
+	mockAlloc := &mockAllocator{}
+
+	pl := PluginState{
+		leasedb:   db,
+		Recordsv4: make(map[string]*Record),
+		allocator: mockAlloc,
+	}
+
+	loadedRecords, err := loadRecords(db)
+	if err != nil {
+		t.Fatalf("Failed to load records: %v", err)
+	}
+	pl.Recordsv4 = loadedRecords
+
+	hwaddr, _ := net.ParseMAC(records[1].mac)
+	req := &dhcpv4.DHCPv4{
+		ClientHWAddr: hwaddr,
+	}
+	req.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeRelease))
+
+	resp := &dhcpv4.DHCPv4{}
+
+	// Close the database to simulate storage failure
+	db.Close()
+
+	result, stop := pl.Handler4(req, resp)
+
+	assert.Nil(t, result, "Should return nil on storage failure")
+	assert.True(t, stop, "Should stop processing on storage failure")
+
+	_, exists := pl.Recordsv4[hwaddr.String()]
+	assert.True(t, exists, "Record should still exist in memory after storage failure")
+
+	mockAlloc.AssertNotCalled(t, "Free")
+	mockAlloc.AssertNotCalled(t, "Allocate")
+}

+ 16 - 0
plugins/range/storage.go

@@ -76,6 +76,22 @@ func (p *PluginState) saveIPAddress(mac net.HardwareAddr, record *Record) error
 	return nil
 }
 
+// freeIPAddress removes a lease from storage
+func (p *PluginState) freeIPAddress(mac net.HardwareAddr, record *Record) error {
+	stmt, err := p.leasedb.Prepare(`delete from leases4 where mac = ? and ip = ?`)
+	if err != nil {
+		return fmt.Errorf("statement preparation failed: %w", err)
+	}
+	defer stmt.Close()
+	if _, err := stmt.Exec(
+		mac.String(),
+		record.IP.String(),
+	); err != nil {
+		return fmt.Errorf("record delete failed: %w", err)
+	}
+	return nil
+}
+
 // registerBackingDB installs a database connection string as the backing store for leases
 func (p *PluginState) registerBackingDB(filename string) error {
 	if p.leasedb != nil {

+ 142 - 0
plugins/range/storage_test.go

@@ -97,3 +97,145 @@ func TestWriteRecords(t *testing.T) {
 
 	assert.Equal(t, mapRec, parsedRec, "Loaded records differ from what's in the DB")
 }
+
+func TestFreeIPAddress(t *testing.T) {
+	db, err := testDBSetup()
+	if err != nil {
+		t.Fatalf("Failed to set up test DB: %v", err)
+	}
+
+	pl := PluginState{leasedb: db}
+
+	hwaddr, err := net.ParseMAC(records[1].mac)
+	if err != nil {
+		t.Fatalf("Failed to parse MAC address: %v", err)
+	}
+
+	record := records[1].ip
+
+	parsedRecords, err := loadRecords(pl.leasedb)
+	if err != nil {
+		t.Fatalf("Failed to load records: %v", err)
+	}
+	_, exists := parsedRecords[hwaddr.String()]
+	assert.True(t, exists, "Record should exist before deletion")
+
+	// Now free the IP address
+	if err := pl.freeIPAddress(hwaddr, record); err != nil {
+		t.Errorf("Failed to free IP address: %v", err)
+	}
+
+	parsedRecords, err = loadRecords(pl.leasedb)
+	if err != nil {
+		t.Fatalf("Failed to load records after deletion: %v", err)
+	}
+	_, exists = parsedRecords[hwaddr.String()]
+	assert.False(t, exists, "Record should not exist after deletion")
+}
+
+func TestFreeIPAddressNonExistent(t *testing.T) {
+	pl := PluginState{}
+	if err := pl.registerBackingDB(":memory:"); err != nil {
+		t.Fatalf("Could not setup file")
+	}
+
+	hwaddr, err := net.ParseMAC("02:00:00:00:00:99")
+	if err != nil {
+		t.Fatalf("Failed to parse MAC address: %v", err)
+	}
+
+	record := &Record{
+		IP:       net.IPv4(10, 0, 0, 99),
+		expires:  expire,
+		hostname: "non-existent",
+	}
+
+	err = pl.freeIPAddress(hwaddr, record)
+	assert.NoError(t, err, "Freeing a non-existent IP address should not return an error")
+
+	parsedRecords, err := loadRecords(pl.leasedb)
+	if err != nil {
+		t.Fatalf("Failed to load records: %v", err)
+	}
+	assert.Empty(t, parsedRecords, "Database should be empty")
+}
+
+func TestFreeIPAddressVerifyDeletion(t *testing.T) {
+	db, err := testDBSetup()
+	if err != nil {
+		t.Fatalf("Failed to set up test DB: %v", err)
+	}
+
+	pl := PluginState{leasedb: db}
+
+	parsedRecords, err := loadRecords(pl.leasedb)
+	if err != nil {
+		t.Fatalf("Failed to load records: %v", err)
+	}
+	assert.Len(t, parsedRecords, 6, "Should have 6 records from testDBSetup")
+
+	// Delete the middle record (records[2] = "02:00:00:00:00:02" with IP 10.0.0.2)
+	hwaddrToDelete, _ := net.ParseMAC(records[2].mac)
+	recordToDelete := records[2].ip
+
+	if err := pl.freeIPAddress(hwaddrToDelete, recordToDelete); err != nil {
+		t.Errorf("Failed to free IP address: %v", err)
+	}
+
+	parsedRecords, err = loadRecords(pl.leasedb)
+	if err != nil {
+		t.Fatalf("Failed to load records after deletion: %v", err)
+	}
+
+	assert.Len(t, parsedRecords, 5, "Should have 5 records after deletion")
+	_, exists := parsedRecords[hwaddrToDelete.String()]
+	assert.False(t, exists, "Deleted record should not exist")
+
+	// Verify some other records still exist
+	otherMacs := []string{records[1].mac, records[3].mac}
+	for _, mac := range otherMacs {
+		_, exists := parsedRecords[mac]
+		assert.True(t, exists, "Other records should still exist: %s", mac)
+	}
+}
+
+func TestFreeIPAddressExecutionError(t *testing.T) {
+	// This test triggers a statement execution failure using a SQLite trigger
+	// that aborts DELETE operations for records[0]
+
+	db, err := testDBSetup()
+	if err != nil {
+		t.Fatalf("Failed to set up test database: %v", err)
+	}
+	defer db.Close()
+
+	const triggerErrorMsg = "Custom deletion prevention trigger"
+	// Create a trigger that will cause DELETE operations to fail for records[0]
+	triggerSQL := fmt.Sprintf(`
+		CREATE TRIGGER prevent_delete
+		BEFORE DELETE ON leases4
+		WHEN OLD.mac = '%s'
+		BEGIN
+			SELECT RAISE(ABORT, '%s');
+		END
+	`, records[0].mac, triggerErrorMsg)
+	_, err = db.Exec(triggerSQL)
+	if err != nil {
+		t.Fatalf("Failed to create trigger: %v", err)
+	}
+
+	pl := PluginState{leasedb: db}
+
+	hwaddr, err := net.ParseMAC(records[0].mac)
+	if err != nil {
+		t.Fatalf("Failed to parse MAC address: %v", err)
+	}
+
+	record := records[0].ip
+
+	err = pl.freeIPAddress(hwaddr, record)
+
+	assert.Error(t, err, "Should return error due to trigger preventing deletion")
+	assert.Contains(t, err.Error(), "record delete failed", "Error should indicate record delete failure")
+	assert.Contains(t, err.Error(), triggerErrorMsg, "Error should contain trigger message")
+}

+ 1 - 0
server/handle.go

@@ -127,6 +127,7 @@ func (l *listener4) HandleMsg4(buf []byte, oob *ipv4.ControlMessage, _peer net.A
 		tmp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer))
 	case dhcpv4.MessageTypeRequest:
 		tmp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck))
+	case dhcpv4.MessageTypeRelease:
 	default:
 		log.Printf("plugins/server: Unhandled message type: %v", mt)
 		return