Sfoglia il codice sorgente

Unified config and plugin loading for v6 and v4

Andrea Barberio 7 anni fa
parent
commit
ad0749218a
3 ha cambiato i file con 116 aggiunte e 56 eliminazioni
  1. 2 2
      cmds/coredhcp/config.yml.example
  2. 59 31
      config/config.go
  3. 55 23
      coredhcp.go

+ 2 - 2
cmds/coredhcp/config.yml.example

@@ -5,5 +5,5 @@ server6:
         - file: "leases.txt"
         # - dns: 8.8.8.8 8.8.4.4 2001:4860:4860::8888 2001:4860:4860::8844
 
-#server4:
-#    listen: '127.0.0.1:67'
+server4:
+    listen: '127.0.0.1:67'

+ 59 - 31
config/config.go

@@ -1,7 +1,7 @@
 package config
 
 import (
-	"errors"
+	"fmt"
 	"net"
 	"strconv"
 	"strings"
@@ -13,6 +13,13 @@ import (
 
 var log = logger.GetLogger()
 
+type protocolVersion int
+
+const (
+	protocolV6 protocolVersion = 6
+	protocolV4 protocolVersion = 4
+)
+
 // Config holds the DHCPv6/v4 server configuration
 type Config struct {
 	v       *viper.Viper
@@ -51,10 +58,10 @@ func Load() (*Config, error) {
 	if err := c.v.ReadInConfig(); err != nil {
 		return nil, err
 	}
-	if err := c.parseV6Config(); err != nil {
+	if err := c.parseConfig(protocolV6); err != nil {
 		return nil, err
 	}
-	if err := c.parseV4Config(); err != nil {
+	if err := c.parseConfig(protocolV4); err != nil {
 		return nil, err
 	}
 	if c.Server6 == nil && c.Server4 == nil {
@@ -90,56 +97,77 @@ func parsePlugins(pluginList []interface{}) ([]*PluginConfig, error) {
 	return plugins, nil
 }
 
-func (c *Config) parseV6Config() error {
-	if exists := c.v.Get("server6"); exists == nil {
-		// it is valid to have no DHCPv6 configuration defined, so no
-		// server and no error are returned
-		return nil
+func (c *Config) getListenAddress(ver protocolVersion) (*net.UDPAddr, error) {
+	if exists := c.v.Get(fmt.Sprintf("server%d", ver)); exists == nil {
+		// it is valid to have no server configuration defined, and in this case
+		// no listening address and no error are returned.
+		return nil, nil
 	}
-	addr := c.v.GetString("server6.listen")
+	addr := c.v.GetString(fmt.Sprintf("server%d.listen", ver))
 	if addr == "" {
-		return ConfigErrorFromString("dhcpv6: missing `server6.listen` directive")
+		return nil, ConfigErrorFromString("dhcpv%d: missing `server%d.listen` directive", ver, ver)
 	}
 	ipStr, portStr, err := net.SplitHostPort(addr)
 	if err != nil {
-		return ConfigErrorFromString("dhcpv6: %v", err)
+		return nil, ConfigErrorFromString("dhcpv%d: %v", ver, err)
 	}
 	ip := net.ParseIP(ipStr)
-	if ip.To4() != nil {
-		return ConfigErrorFromString("dhcpv6: missing or invalid `listen` address")
+	if ip == nil {
+		return nil, ConfigErrorFromString("dhcpv%d: invalid IP address in `listen` directive", ver)
+	}
+	if ver == protocolV6 && ip.To4() != nil {
+		return nil, ConfigErrorFromString("dhcpv%d: not a valid IPv6 address in `listen` directive", ver)
+	} else if ver == protocolV4 && ip.To4() == nil {
+		return nil, ConfigErrorFromString("dhcpv%d: not a valid IPv4 address in `listen` directive", ver)
 	}
 	port, err := strconv.Atoi(portStr)
 	if err != nil {
-		return ConfigErrorFromString("dhcpv6: invalid `listen` port")
+		return nil, ConfigErrorFromString("dhcpv%d: invalid `listen` port", ver)
 	}
 	listener := net.UDPAddr{
 		IP:   ip,
 		Port: port,
 	}
-	sc := ServerConfig{
-		Listener: &listener,
-		Plugins:  nil,
-	}
-	// load plugins
-	pluginList := cast.ToSlice(c.v.Get("server6.plugins"))
+	return &listener, nil
+}
+
+func (c *Config) getPlugins(ver protocolVersion) ([]*PluginConfig, error) {
+	pluginList := cast.ToSlice(c.v.Get(fmt.Sprintf("server%d.plugins", ver)))
 	if pluginList == nil {
-		return ConfigErrorFromString("dhcpv6: invalid plugins section, not a list")
+		return nil, ConfigErrorFromString("dhcpv%d: invalid plugins section, not a list", ver)
+	}
+	return parsePlugins(pluginList)
+}
+
+func (c *Config) parseConfig(ver protocolVersion) error {
+	if ver != protocolV6 && ver != protocolV4 {
+		return ConfigErrorFromString("unknown protocol version: %d", ver)
+	}
+	listenAddr, err := c.getListenAddress(ver)
+	if err != nil {
+		return err
+	}
+	if listenAddr == nil {
+		// no listener is configured, so `c.Server6` (or `c.Server4` if v4)
+		// will stay nil.
+		return nil
 	}
-	plugins, err := parsePlugins(pluginList)
+	// read plugin configuration
+	plugins, err := c.getPlugins(ver)
 	if err != nil {
 		return err
 	}
 	for _, p := range plugins {
-		log.Printf("DHCPv6: found plugin `%s` with %d args: %v", p.Name, len(p.Args), p.Args)
+		log.Printf("DHCPv%d: found plugin `%s` with %d args: %v", ver, p.Name, len(p.Args), p.Args)
 	}
-	sc.Plugins = plugins
-	c.Server6 = &sc
-	return nil
-}
-
-func (c *Config) parseV4Config() error {
-	if exists := c.v.Get("server4"); exists != nil {
-		return errors.New("DHCPv4 config parser not implemented yet")
+	sc := ServerConfig{
+		Listener: listenAddr,
+		Plugins:  plugins,
+	}
+	if ver == protocolV6 {
+		c.Server6 = &sc
+	} else if ver == protocolV4 {
+		c.Server4 = &sc
 	}
 	return nil
 }

+ 55 - 23
coredhcp.go

@@ -29,39 +29,71 @@ type Server struct {
 // `plugins` section, in order. For a plugin to be available, it must have been
 // previously registered with plugins.RegisterPlugin. This is normally done at
 // plugin import time.
-func (s *Server) LoadPlugins(conf *config.Config) ([]*plugins.Plugin, error) {
+// This function returns the list of loaded v6 plugins, the list of loaded v4
+// plugins, and an error if any.
+func (s *Server) LoadPlugins(conf *config.Config) ([]*plugins.Plugin, []*plugins.Plugin, error) {
 	log.Print("Loading plugins...")
-	loadedPlugins := make([]*plugins.Plugin, 0)
+	loadedPlugins6 := make([]*plugins.Plugin, 0)
+	loadedPlugins4 := make([]*plugins.Plugin, 0)
 
-	if conf.Server4 != nil {
-		return nil, errors.New("plugin loading for DHCPv4 not implemented yet")
-	}
-	// load v6 plugins
-	if conf.Server6 == nil {
-		return nil, errors.New("no configuration found for DHCPv6 server")
+	if conf.Server6 == nil && conf.Server4 == nil {
+		return nil, nil, errors.New("no configuration found for either DHCPv6 or DHCPv4")
 	}
+
 	// now load the plugins. We need to call its setup function with
 	// the arguments extracted above. The setup function is mapped in
 	// plugins.RegisteredPlugins .
-	for _, pluginConf := range conf.Server6.Plugins {
-		if plugin, ok := plugins.RegisteredPlugins[pluginConf.Name]; ok {
-			log.Printf("Loading plugin `%s`", pluginConf.Name)
-			h6, err := plugin.Setup6(pluginConf.Args...)
-			if err != nil {
-				return nil, err
+
+	// Load DHCPv6 plugins.
+	if conf.Server6 != nil {
+		for _, pluginConf := range conf.Server6.Plugins {
+			if plugin, ok := plugins.RegisteredPlugins[pluginConf.Name]; ok {
+				log.Printf("DHCPv6: loading plugin `%s`", pluginConf.Name)
+				if plugin.Setup6 == nil {
+					log.Warningf("DHCPv6: plugin `%s` has no setup function for DHCPv6", pluginConf.Name)
+					continue
+				}
+				h6, err := plugin.Setup6(pluginConf.Args...)
+				if err != nil {
+					return nil, nil, err
+				}
+				loadedPlugins6 = append(loadedPlugins6, plugin)
+				if h6 == nil {
+					return nil, nil, config.ConfigErrorFromString("no DHCPv6 handler for plugin %s", pluginConf.Name)
+				}
+				s.Handlers6 = append(s.Handlers6, h6)
+			} else {
+				return nil, nil, config.ConfigErrorFromString("DHCPv6: unknown plugin `%s`", pluginConf.Name)
 			}
-			loadedPlugins = append(loadedPlugins, plugin)
-			if h6 == nil {
-				return nil, config.ConfigErrorFromString("no DHCPv6 handler for plugin %s", pluginConf.Name)
+		}
+	}
+	// Load DHCPv4 plugins. Yes, duplicated code, there's not really much that
+	// can be deduplicated here.
+	if conf.Server4 != nil {
+		for _, pluginConf := range conf.Server4.Plugins {
+			if plugin, ok := plugins.RegisteredPlugins[pluginConf.Name]; ok {
+				log.Printf("DHCPv4: loading plugin `%s`", pluginConf.Name)
+				if plugin.Setup4 == nil {
+					log.Warningf("DHCPv4: plugin `%s` has no setup function for DHCPv4", pluginConf.Name)
+					continue
+				}
+				h4, err := plugin.Setup4(pluginConf.Args...)
+				if err != nil {
+					return nil, nil, err
+				}
+				loadedPlugins4 = append(loadedPlugins4, plugin)
+				if h4 == nil {
+					return nil, nil, config.ConfigErrorFromString("no DHCPv4 handler for plugin %s", pluginConf.Name)
+				}
+				s.Handlers4 = append(s.Handlers4, h4)
+				//s.Handlers4 = append(s.Handlers4, h4)
+			} else {
+				return nil, nil, config.ConfigErrorFromString("DHCPv4: unknown plugin `%s`", pluginConf.Name)
 			}
-			s.Handlers6 = append(s.Handlers6, h6)
-			//s.Handlers4 = append(s.Handlers4, h4)
-		} else {
-			return nil, config.ConfigErrorFromString("unknown plugin `%s`", pluginConf.Name)
 		}
 	}
 
-	return loadedPlugins, nil
+	return loadedPlugins6, loadedPlugins4, nil
 }
 
 // MainHandler6 runs for every received DHCPv6 packet. It will run every
@@ -95,7 +127,7 @@ func (s *Server) MainHandler4(conn net.PacketConn, peer net.Addr, d *dhcpv4.DHCP
 // Start will start the server asynchronously. See `Wait` to wait until
 // the execution ends.
 func (s *Server) Start() error {
-	_, err := s.LoadPlugins(s.Config)
+	_, _, err := s.LoadPlugins(s.Config)
 	if err != nil {
 		return err
 	}