Răsfoiți Sursa

plugins/range: added hostname field to database

Signed-off-by: Andrea Barberio <insomniac@slackware.it>
Andrea Barberio 1 an în urmă
părinte
comite
4458e77f16
3 a modificat fișierele cu 24 adăugiri și 19 ștergeri
  1. 4 0
      plugins/range/plugin.go
  2. 7 6
      plugins/range/storage.go
  3. 13 13
      plugins/range/storage_test.go

+ 4 - 0
plugins/range/plugin.go

@@ -33,6 +33,7 @@ var Plugin = plugins.Plugin{
 type Record struct {
 	IP      net.IP
 	expires int
+	hostname string
 }
 
 // PluginState is the data held by an instance of the range plugin
@@ -51,6 +52,7 @@ func (p *PluginState) Handler4(req, resp *dhcpv4.DHCPv4) (*dhcpv4.DHCPv4, bool)
 	p.Lock()
 	defer p.Unlock()
 	record, ok := p.Recordsv4[req.ClientHWAddr.String()]
+	hostname := req.HostName()
 	if !ok {
 		// Allocating new address since there isn't one allocated
 		log.Printf("MAC address %s is new, leasing new IPv4 address", req.ClientHWAddr.String())
@@ -62,6 +64,7 @@ func (p *PluginState) Handler4(req, resp *dhcpv4.DHCPv4) (*dhcpv4.DHCPv4, bool)
 		rec := Record{
 			IP:      ip.IP.To4(),
 			expires: int(time.Now().Add(p.LeaseTime).Unix()),
+			hostname: hostname,
 		}
 		err = p.saveIPAddress(req.ClientHWAddr, &rec)
 		if err != nil {
@@ -74,6 +77,7 @@ func (p *PluginState) Handler4(req, resp *dhcpv4.DHCPv4) (*dhcpv4.DHCPv4, bool)
 		expiry := time.Unix(int64(record.expires), 0)
 		if expiry.Before(time.Now().Add(p.LeaseTime)) {
 			record.expires = int(time.Now().Add(p.LeaseTime).Round(time.Second).Unix())
+			record.hostname = hostname
 			err := p.saveIPAddress(req.ClientHWAddr, record)
 			if err != nil {
 				log.Errorf("Could not persist lease for MAC %s: %v", req.ClientHWAddr.String(), err)

+ 7 - 6
plugins/range/storage.go

@@ -18,7 +18,7 @@ func loadDB(path string) (*sql.DB, error) {
 	if err != nil {
 		return nil, fmt.Errorf("failed to open database (%T): %w", err, err)
 	}
-	if _, err := db.Exec("create table if not exists leases4 (mac string not null, ip string not null, expiry int, primary key (mac, ip))"); err != nil {
+	if _, err := db.Exec("create table if not exists leases4 (mac string not null, ip string not null, expiry int, hostname string not null, primary key (mac, ip))"); err != nil {
 		return nil, fmt.Errorf("table creation failed: %w", err)
 	}
 	return db, nil
@@ -28,18 +28,18 @@ func loadDB(path string) (*sql.DB, error) {
 // the specified file. The records have to be one per line, a mac address and an
 // IP address.
 func loadRecords(db *sql.DB) (map[string]*Record, error) {
-	rows, err := db.Query("select mac, ip, expiry from leases4")
+	rows, err := db.Query("select mac, ip, expiry, hostname from leases4")
 	if err != nil {
 		return nil, fmt.Errorf("failed to query leases database: %w", err)
 	}
 	defer rows.Close()
 	var (
-		mac, ip string
+		mac, ip, hostname string
 		expiry  int
 		records = make(map[string]*Record)
 	)
 	for rows.Next() {
-		if err := rows.Scan(&mac, &ip, &expiry); err != nil {
+		if err := rows.Scan(&mac, &ip, &expiry, &hostname); err != nil {
 			return nil, fmt.Errorf("failed to scan row: %w", err)
 		}
 		hwaddr, err := net.ParseMAC(mac)
@@ -50,7 +50,7 @@ func loadRecords(db *sql.DB) (map[string]*Record, error) {
 		if ipaddr.To4() == nil {
 			return nil, fmt.Errorf("expected an IPv4 address, got: %v", ipaddr)
 		}
-		records[hwaddr.String()] = &Record{IP: ipaddr, expires: expiry}
+		records[hwaddr.String()] = &Record{IP: ipaddr, expires: expiry, hostname: hostname}
 	}
 	if err := rows.Err(); err != nil {
 		return nil, fmt.Errorf("failed lease database row scanning: %w", err)
@@ -60,7 +60,7 @@ func loadRecords(db *sql.DB) (map[string]*Record, error) {
 
 // saveIPAddress writes out a lease to storage
 func (p *PluginState) saveIPAddress(mac net.HardwareAddr, record *Record) error {
-	stmt, err := p.leasedb.Prepare(`insert or replace into leases4(mac, ip, expiry) values (?, ?, ?)`)
+	stmt, err := p.leasedb.Prepare(`insert or replace into leases4(mac, ip, expiry, hostname) values (?, ?, ?, ?)`)
 	if err != nil {
 		return fmt.Errorf("statement preparation failed: %w", err)
 	}
@@ -68,6 +68,7 @@ func (p *PluginState) saveIPAddress(mac net.HardwareAddr, record *Record) error
 		mac.String(),
 		record.IP.String(),
 		record.expires,
+		record.hostname,
 	); err != nil {
 		return fmt.Errorf("record insert/update failed: %w", err)
 	}

+ 13 - 13
plugins/range/storage_test.go

@@ -20,12 +20,12 @@ func testDBSetup() (*sql.DB, error) {
 		return nil, err
 	}
 	for _, record := range records {
-		stmt, err := db.Prepare("insert into leases4(mac, ip, expiry) values (?, ?, ?)")
+		stmt, err := db.Prepare("insert into leases4(mac, ip, expiry, hostname) values (?, ?, ?, ?)")
 		if err != nil {
 			return nil, fmt.Errorf("failed to prepare insert statement: %w", err)
 		}
 		defer stmt.Close()
-		if _, err := stmt.Exec(record.mac, record.ip.IP.String(), record.ip.expires); err != nil {
+		if _, err := stmt.Exec(record.mac, record.ip.IP.String(), record.ip.expires, record.ip.hostname); err != nil {
 			return nil, fmt.Errorf("failed to insert record into test db: %w", err)
 		}
 	}
@@ -37,12 +37,12 @@ var records = []struct {
 	mac string
 	ip  *Record
 }{
-	{"02:00:00:00:00:00", &Record{net.IPv4(10, 0, 0, 0), expire}},
-	{"02:00:00:00:00:01", &Record{net.IPv4(10, 0, 0, 1), expire}},
-	{"02:00:00:00:00:02", &Record{net.IPv4(10, 0, 0, 2), expire}},
-	{"02:00:00:00:00:03", &Record{net.IPv4(10, 0, 0, 3), expire}},
-	{"02:00:00:00:00:04", &Record{net.IPv4(10, 0, 0, 4), expire}},
-	{"02:00:00:00:00:05", &Record{net.IPv4(10, 0, 0, 5), expire}},
+	{"02:00:00:00:00:00", &Record{IP: net.IPv4(10, 0, 0, 0), expires: expire, hostname: "zero"}},
+	{"02:00:00:00:00:01", &Record{IP: net.IPv4(10, 0, 0, 1), expires: expire, hostname: "one"}},
+	{"02:00:00:00:00:02", &Record{IP: net.IPv4(10, 0, 0, 2), expires: expire, hostname: "two"}},
+	{"02:00:00:00:00:03", &Record{IP: net.IPv4(10, 0, 0, 3), expires: expire, hostname: "three"}},
+	{"02:00:00:00:00:04", &Record{IP: net.IPv4(10, 0, 0, 4), expires: expire, hostname: "four"}},
+	{"02:00:00:00:00:05", &Record{IP: net.IPv4(10, 0, 0, 5), expires: expire, hostname: "five"}},
 }
 
 func TestLoadRecords(t *testing.T) {
@@ -59,13 +59,13 @@ func TestLoadRecords(t *testing.T) {
 	mapRec := make(map[string]*Record)
 	for _, rec := range records {
 		var (
-			ip, mac string
-			expiry  int
+			ip, mac, hostname string
+			expiry            int
 		)
-		if err := db.QueryRow("select mac, ip, expiry from leases4 where mac = ?", rec.mac).Scan(&mac, &ip, &expiry); err != nil {
+		if err := db.QueryRow("select mac, ip, expiry, hostname from leases4 where mac = ?", rec.mac).Scan(&mac, &ip, &expiry, &hostname); err != nil {
 			t.Fatalf("record not found for mac=%s: %v", rec.mac, err)
 		}
-		mapRec[mac] = &Record{IP: net.ParseIP(ip), expires: expiry}
+		mapRec[mac] = &Record{IP: net.ParseIP(ip), expires: expiry, hostname: hostname}
 	}
 
 	assert.Equal(t, mapRec, parsedRec, "Loaded records differ from what's in the DB")
@@ -87,7 +87,7 @@ func TestWriteRecords(t *testing.T) {
 		if err := pl.saveIPAddress(hwaddr, rec.ip); err != nil {
 			t.Errorf("Failed to save ip for %s: %v", hwaddr, err)
 		}
-		mapRec[hwaddr.String()] = &Record{IP: rec.ip.IP, expires: rec.ip.expires}
+		mapRec[hwaddr.String()] = &Record{IP: rec.ip.IP, expires: rec.ip.expires, hostname: rec.ip.hostname}
 	}
 
 	parsedRec, err := loadRecords(pl.leasedb)