Commit a5602dd8 authored by Niall Sheridan's avatar Niall Sheridan
Browse files

List only certs which haven't expired

parent 7dbcbcc7
......@@ -39,7 +39,9 @@ func (ms *memoryStore) List() ([]*CertRecord, error) {
ms.Lock()
defer ms.Unlock()
for _, value := range ms.certs {
records = append(records, value)
if value.Expires.After(time.Now().UTC()) {
records = append(records, value)
}
}
return records, nil
}
......
......@@ -72,8 +72,8 @@ func (m *mongoDB) List() ([]*CertRecord, error) {
return nil, err
}
var result []*CertRecord
m.collection.Find(nil).All(&result)
return result, nil
err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
return result, err
}
func (m *mongoDB) Revoke(id string) error {
......
......@@ -66,7 +66,7 @@ func NewSQLStore(config string) (CertStorer, error) {
if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
return nil, fmt.Errorf("sqldb: prepare get: %v", err)
}
if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
if db.list, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
return nil, fmt.Errorf("sqldb: prepare list: %v", err)
}
if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
......@@ -137,7 +137,7 @@ func (db *sqldb) List() ([]*CertRecord, error) {
return nil, err
}
var recs []*CertRecord
rows, _ := db.list.Query()
rows, _ := db.revoked.Query(time.Now().UTC())
defer rows.Close()
for rows.Next() {
cert, err := scanCert(rows)
......
......@@ -42,27 +42,21 @@ func TestParseCertificate(t *testing.T) {
func testStore(t *testing.T, db CertStorer) {
defer db.Close()
ids := []string{"a", "b"}
for _, id := range ids {
r := &CertRecord{
KeyID: id,
Expires: time.Now().UTC().Add(time.Second * -10),
}
if err := db.SetRecord(r); err != nil {
t.Error(err)
}
r := &CertRecord{
KeyID: "a",
Expires: time.Now().UTC().Add(1 * time.Minute),
}
recs, err := db.List()
if err != nil {
if err := db.SetRecord(r); err != nil {
t.Error(err)
}
if len(recs) != len(ids) {
t.Errorf("Want %d records, got %d", len(ids), len(recs))
if _, err := db.List(); err != nil {
t.Error(err)
}
c, _, _, _, _ := ssh.ParseAuthorizedKey(testdata.Cert)
cert := c.(*ssh.Certificate)
cert.ValidBefore = uint64(time.Now().Add(1 * time.Hour).UTC().Unix())
cert.ValidAfter = uint64(time.Now().Add(-5 * time.Minute).UTC().Unix())
if err := db.SetCert(cert); err != nil {
t.Error(err)
}
......@@ -74,9 +68,6 @@ func testStore(t *testing.T, db CertStorer) {
t.Error(err)
}
// A revoked key shouldn't get returned if it's already expired
db.Revoke("a")
revoked, err := db.GetRevoked()
if err != nil {
t.Error(err)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment