Commit 49f40a95 authored by Niall Sheridan's avatar Niall Sheridan
Browse files

Add some handlers tests

parent dee5a19d
package main
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/oauth2"
"github.com/gorilla/sessions"
"github.com/nsheridan/cashier/lib"
"github.com/nsheridan/cashier/server/auth"
"github.com/nsheridan/cashier/server/auth/testprovider"
"github.com/nsheridan/cashier/server/config"
"github.com/nsheridan/cashier/server/signer"
"github.com/nsheridan/cashier/server/store"
"github.com/nsheridan/cashier/testdata"
)
func newContext(t *testing.T) *appContext {
f, err := ioutil.TempFile(os.TempDir(), "signing_key_")
if err != nil {
t.Error(err)
}
defer os.Remove(f.Name())
f.Write(testdata.Priv)
f.Close()
signer, err := signer.New(&config.SSH{
SigningKey: f.Name(),
MaxAge: "1h",
})
if err != nil {
t.Error(err)
}
return &appContext{
cookiestore: sessions.NewCookieStore([]byte("secret")),
authprovider: testprovider.New(),
certstore: store.NewMemoryStore(),
authsession: &auth.Session{AuthURL: "https://www.example.com/auth"},
sshKeySigner: signer,
}
}
func TestLoginHandler(t *testing.T) {
req, _ := http.NewRequest("GET", "/auth/login", nil)
resp := httptest.NewRecorder()
loginHandler(newContext(t), resp, req)
if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" {
t.Error("Unexpected response")
}
}
func TestCallbackHandler(t *testing.T) {
req, _ := http.NewRequest("GET", "/auth/callback", nil)
req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}}
resp := httptest.NewRecorder()
ctx := newContext(t)
ctx.setAuthStateCookie(resp, req, "state")
callbackHandler(ctx, resp, req)
if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" {
t.Error("Unexpected response")
}
}
func TestRootHandler(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := newContext(t)
tok := &oauth2.Token{
AccessToken: "XXX_TEST_TOKEN_STRING_XXX",
Expiry: time.Now().Add(1 * time.Hour),
}
ctx.setAuthTokenCookie(resp, req, tok)
rootHandler(ctx, resp, req)
if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") {
t.Error("Unable to find token in response")
}
}
func TestRootHandlerNoSession(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := newContext(t)
rootHandler(ctx, resp, req)
if resp.Code != http.StatusSeeOther {
t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther))
}
}
func TestSignRevoke(t *testing.T) {
s, _ := json.Marshal(&lib.SignRequest{
Key: string(testdata.Pub),
})
req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s))
resp := httptest.NewRecorder()
ctx := newContext(t)
req.Header.Set("Authorization", "Bearer abcdef")
signHandler(ctx, resp, req)
if resp.Code != http.StatusOK {
t.Error("Unexpected response")
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Error(err)
}
r := &lib.SignResponse{}
if err := json.Unmarshal(b, r); err != nil {
t.Error(err)
}
if r.Status != "ok" {
t.Error("Unexpected response")
}
k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.Response))
if err != nil {
t.Error(err)
}
cert, ok := k.(*ssh.Certificate)
if !ok {
t.Error("Did not receive a certificate")
}
// Revoke the cert and verify
req, _ = http.NewRequest("POST", "/revoke", nil)
req.Form = url.Values{"cert_id": []string{cert.KeyId}}
revokeCertHandler(ctx, resp, req)
req, _ = http.NewRequest("GET", "/revoked", nil)
revokedCertsHandler(ctx, resp, req)
revoked, _ := ioutil.ReadAll(resp.Body)
if string(revoked[:len(revoked)-1]) != r.Response {
t.Error("omg")
}
}
......@@ -123,11 +123,11 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
// Sign the pubkey and issue the cert.
req, err := parseKey(r)
req.Principal = a.authprovider.Username(token)
a.authprovider.Revoke(token) // We don't need this anymore.
if err != nil {
return http.StatusInternalServerError, err
}
req.Principal = a.authprovider.Username(token)
a.authprovider.Revoke(token) // We don't need this anymore.
cert, err := a.sshKeySigner.SignUserKey(req)
if err != nil {
return http.StatusInternalServerError, err
......@@ -199,9 +199,6 @@ func revokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request)
}
func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
if r.Method == "GET" {
return http.StatusMethodNotAllowed, errors.New(http.StatusText(http.StatusMethodNotAllowed))
}
r.ParseForm()
id := r.FormValue("cert_id")
if id == "" {
......@@ -268,7 +265,7 @@ func main() {
log.Fatal(err)
}
fs.Register(&config.AWS)
signer, err := signer.New(config.SSH)
signer, err := signer.New(&config.SSH)
if err != nil {
log.Fatal(err)
}
......@@ -304,12 +301,12 @@ func main() {
}
r := mux.NewRouter()
r.Handle("/", appHandler{ctx, rootHandler})
r.Handle("/auth/login", appHandler{ctx, loginHandler})
r.Handle("/auth/callback", appHandler{ctx, callbackHandler})
r.Handle("/sign", appHandler{ctx, signHandler})
r.Handle("/revoked", appHandler{ctx, revokedCertsHandler})
r.Handle("/revoke", appHandler{ctx, revokeCertHandler})
r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler})
r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler})
r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler})
r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler})
r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, revokedCertsHandler})
r.Methods("POST").Path("/revoke").Handler(appHandler{ctx, revokeCertHandler})
logfile := os.Stderr
if config.Server.HTTPLogFile != "" {
logfile, err = os.OpenFile(config.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660)
......
package testprovider
import (
"time"
"github.com/nsheridan/cashier/server/auth"
"golang.org/x/oauth2"
)
const (
name = "testprovider"
)
// Config is an implementation of `auth.Provider` for testing.
type Config struct{}
// New creates a new provider.
func New() auth.Provider {
return &Config{}
}
// Name returns the name of the provider.
func (c *Config) Name() string {
return name
}
// Valid validates the oauth token.
func (c *Config) Valid(token *oauth2.Token) bool {
return true
}
// Revoke disables the access token.
func (c *Config) Revoke(token *oauth2.Token) error {
return nil
}
// StartSession retrieves an authentication endpoint.
func (c *Config) StartSession(state string) *auth.Session {
return &auth.Session{
AuthURL: "https://www.example.com/auth",
}
}
// Exchange authorizes the session and returns an access token.
func (c *Config) Exchange(code string) (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "token",
Expiry: time.Now().Add(1 * time.Hour),
}, nil
}
// Username retrieves the username portion of the user's email address.
func (c *Config) Username(token *oauth2.Token) string {
return "test"
}
......@@ -69,7 +69,7 @@ func makeperms(perms []string) map[string]string {
}
// New creates a new KeySigner from the supplied configuration.
func New(conf config.SSH) (*KeySigner, error) {
func New(conf *config.SSH) (*KeySigner, error) {
data, err := wkfs.ReadFile(conf.SigningKey)
if err != nil {
return nil, fmt.Errorf("unable to read CA key %s: %v", conf.SigningKey, err)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment