You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
216 lines
4.8 KiB
216 lines
4.8 KiB
package zk |
|
|
|
import ( |
|
"fmt" |
|
"io" |
|
"io/ioutil" |
|
"math/rand" |
|
"os" |
|
"path/filepath" |
|
"strings" |
|
"time" |
|
) |
|
|
|
// init seeds math/rand's global source from the wall clock so that the
// random base port chosen by StartTestCluster varies between test runs.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
|
|
|
type TestServer struct { |
|
Port int |
|
Path string |
|
Srv *Server |
|
} |
|
|
|
type TestCluster struct { |
|
Path string |
|
Servers []TestServer |
|
} |
|
|
|
func StartTestCluster(size int, stdout, stderr io.Writer) (*TestCluster, error) { |
|
tmpPath, err := ioutil.TempDir("", "gozk") |
|
if err != nil { |
|
return nil, err |
|
} |
|
success := false |
|
startPort := int(rand.Int31n(6000) + 10000) |
|
cluster := &TestCluster{Path: tmpPath} |
|
defer func() { |
|
if !success { |
|
cluster.Stop() |
|
} |
|
}() |
|
for serverN := 0; serverN < size; serverN++ { |
|
srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN)) |
|
if err := os.Mkdir(srvPath, 0700); err != nil { |
|
return nil, err |
|
} |
|
port := startPort + serverN*3 |
|
cfg := ServerConfig{ |
|
ClientPort: port, |
|
DataDir: srvPath, |
|
} |
|
for i := 0; i < size; i++ { |
|
cfg.Servers = append(cfg.Servers, ServerConfigServer{ |
|
ID: i + 1, |
|
Host: "127.0.0.1", |
|
PeerPort: startPort + i*3 + 1, |
|
LeaderElectionPort: startPort + i*3 + 2, |
|
}) |
|
} |
|
cfgPath := filepath.Join(srvPath, "zoo.cfg") |
|
fi, err := os.Create(cfgPath) |
|
if err != nil { |
|
return nil, err |
|
} |
|
err = cfg.Marshall(fi) |
|
fi.Close() |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
fi, err = os.Create(filepath.Join(srvPath, "myid")) |
|
if err != nil { |
|
return nil, err |
|
} |
|
_, err = fmt.Fprintf(fi, "%d\n", serverN+1) |
|
fi.Close() |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
srv := &Server{ |
|
ConfigPath: cfgPath, |
|
Stdout: stdout, |
|
Stderr: stderr, |
|
} |
|
if err := srv.Start(); err != nil { |
|
return nil, err |
|
} |
|
cluster.Servers = append(cluster.Servers, TestServer{ |
|
Path: srvPath, |
|
Port: cfg.ClientPort, |
|
Srv: srv, |
|
}) |
|
} |
|
if err := cluster.waitForStart(10, time.Second); err != nil { |
|
return nil, err |
|
} |
|
success = true |
|
return cluster, nil |
|
} |
|
|
|
func (tc *TestCluster) Connect(idx int) (*Conn, error) { |
|
zk, _, err := Connect([]string{fmt.Sprintf("127.0.0.1:%d", tc.Servers[idx].Port)}, time.Second*15) |
|
return zk, err |
|
} |
|
|
|
func (tc *TestCluster) ConnectAll() (*Conn, <-chan Event, error) { |
|
return tc.ConnectAllTimeout(time.Second * 15) |
|
} |
|
|
|
func (tc *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, <-chan Event, error) { |
|
return tc.ConnectWithOptions(sessionTimeout) |
|
} |
|
|
|
func (tc *TestCluster) ConnectWithOptions(sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { |
|
hosts := make([]string, len(tc.Servers)) |
|
for i, srv := range tc.Servers { |
|
hosts[i] = fmt.Sprintf("127.0.0.1:%d", srv.Port) |
|
} |
|
zk, ch, err := Connect(hosts, sessionTimeout, options...) |
|
return zk, ch, err |
|
} |
|
|
|
func (tc *TestCluster) Stop() error { |
|
for _, srv := range tc.Servers { |
|
srv.Srv.Stop() |
|
} |
|
defer os.RemoveAll(tc.Path) |
|
return tc.waitForStop(5, time.Second) |
|
} |
|
|
|
// waitForStart blocks until the cluster is up |
|
func (tc *TestCluster) waitForStart(maxRetry int, interval time.Duration) error { |
|
// verify that the servers are up with SRVR |
|
serverAddrs := make([]string, len(tc.Servers)) |
|
for i, s := range tc.Servers { |
|
serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port) |
|
} |
|
|
|
for i := 0; i < maxRetry; i++ { |
|
_, ok := FLWSrvr(serverAddrs, time.Second) |
|
if ok { |
|
return nil |
|
} |
|
time.Sleep(interval) |
|
} |
|
return fmt.Errorf("unable to verify health of servers") |
|
} |
|
|
|
// waitForStop blocks until the cluster is down |
|
func (tc *TestCluster) waitForStop(maxRetry int, interval time.Duration) error { |
|
// verify that the servers are up with RUOK |
|
serverAddrs := make([]string, len(tc.Servers)) |
|
for i, s := range tc.Servers { |
|
serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port) |
|
} |
|
|
|
var success bool |
|
for i := 0; i < maxRetry && !success; i++ { |
|
success = true |
|
for _, ok := range FLWRuok(serverAddrs, time.Second) { |
|
if ok { |
|
success = false |
|
} |
|
} |
|
if !success { |
|
time.Sleep(interval) |
|
} |
|
} |
|
if !success { |
|
return fmt.Errorf("unable to verify servers are down") |
|
} |
|
return nil |
|
} |
|
|
|
func (tc *TestCluster) StartServer(server string) { |
|
for _, s := range tc.Servers { |
|
if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { |
|
s.Srv.Start() |
|
return |
|
} |
|
} |
|
panic(fmt.Sprintf("Unknown server: %s", server)) |
|
} |
|
|
|
func (tc *TestCluster) StopServer(server string) { |
|
for _, s := range tc.Servers { |
|
if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { |
|
s.Srv.Stop() |
|
return |
|
} |
|
} |
|
panic(fmt.Sprintf("Unknown server: %s", server)) |
|
} |
|
|
|
func (tc *TestCluster) StartAllServers() error { |
|
for _, s := range tc.Servers { |
|
if err := s.Srv.Start(); err != nil { |
|
return fmt.Errorf( |
|
"Failed to start server listening on port `%d` : %+v", s.Port, err) |
|
} |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (tc *TestCluster) StopAllServers() error { |
|
for _, s := range tc.Servers { |
|
if err := s.Srv.Stop(); err != nil { |
|
return fmt.Errorf( |
|
"Failed to stop server listening on port `%d` : %+v", s.Port, err) |
|
} |
|
} |
|
|
|
return nil |
|
}
|
|
|