commit 4c71bc58e916721df7d510e0aba57946f2b128da
Author: Přemysl Janouch
Date: Mon Nov 20 21:10:16 2017 +0100
Initial commit
The most basic REST database you'll find.
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c870b20
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+/tinydb
diff --git a/db.go b/db.go
new file mode 100644
index 0000000..c320ebe
--- /dev/null
+++ b/db.go
@@ -0,0 +1,209 @@
+package main
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "sync"
+)
+
+var ErrClosed = errors.New("database has been closed")
+
+type DB struct {
+ sync.RWMutex // Locking
+ data map[string]string // Current state of the database
+ file *os.File // Data storage
+}
+
+// Open an existing database file, loading the contents to memory
+func OpenDB(path string) (*DB, error) {
+ file, err := os.OpenFile(path, os.O_RDWR, 0 /* not used */)
+ if err != nil {
+ return nil, err
+ }
+
+ // TODO we might want a recover flag that just reads as much as it can
+ // instead of returning io.ErrUnexpectedEOF
+ db := &DB{data: make(map[string]string), file: file}
+ for {
+ var header struct{ KeyLen, ValueLen int32 }
+ err := binary.Read(db.file, binary.LittleEndian, &header)
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return nil, err
+ }
+
+ if header.KeyLen < 0 {
+ return nil, fmt.Errorf("invalid key length: %d", header.KeyLen)
+ }
+ key := make([]byte, header.KeyLen)
+ if n, err := file.Read(key); err != nil {
+ return nil, err
+ } else if n != len(key) {
+ return nil, io.ErrUnexpectedEOF
+ }
+
+ if header.ValueLen < 0 {
+ delete(db.data, string(key))
+ continue
+ }
+ value := make([]byte, header.ValueLen)
+ if n, err := file.Read(value); err != nil {
+ return nil, err
+ } else if n != len(value) {
+ return nil, io.ErrUnexpectedEOF
+ }
+
+ db.data[string(key)] = string(value)
+ }
+ // We've been successful, clean up after failed snapshots
+ os.Remove(path + ".1")
+ return db, nil
+}
+
+// Create a new database, overwriting any previous contents of the file
+func CreateDB(path string) (*DB, error) {
+ file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, err
+ }
+ return &DB{data: make(map[string]string), file: file}, nil
+}
+
+// Retrieve the value corresponding to the given key
+func (db *DB) Get(key string) (string, bool, error) {
+ db.RLock()
+ defer db.RUnlock()
+
+ if db.file == nil {
+ return "", false, ErrClosed
+ }
+
+ value, ok := db.data[key]
+ return value, ok, nil
+}
+
+func put(file *os.File, key, value string) error {
+ header := [2]int32{int32(len(key)), int32(len(value))}
+ if err := binary.Write(file, binary.LittleEndian, &header); err != nil {
+ return err
+ }
+
+ if n, err := file.WriteString(key); err != nil {
+ return err
+ } else if n != len(key) {
+ return io.ErrShortWrite
+ }
+
+ if n, err := file.WriteString(value); err != nil {
+ return err
+ } else if n != len(value) {
+ return io.ErrShortWrite
+ }
+ return nil
+}
+
+// Save a key-value pair in the database storage
+func (db *DB) Put(key, value string) error {
+ db.Lock()
+ defer db.Unlock()
+
+ if db.file == nil {
+ return ErrClosed
+ }
+
+ // XXX we should check whether the key and the value fit
+ if err := put(db.file, key, value); err != nil {
+ return err
+ }
+ if err := db.file.Sync(); err != nil {
+ return err
+ }
+
+ db.data[key] = value
+ return nil
+}
+
+// Delete a key from the database storage
+func (db *DB) Delete(key string) error {
+ db.Lock()
+ defer db.Unlock()
+
+ if db.file == nil {
+ return ErrClosed
+ }
+
+ // Like put(), just without the "value"
+ header := [2]int32{int32(len(key)), -1}
+ if err := binary.Write(db.file, binary.LittleEndian, &header); err != nil {
+ return err
+ }
+
+ if n, err := db.file.WriteString(key); err != nil {
+ return err
+ } else if n != len(key) {
+ return io.ErrShortWrite
+ }
+
+ if err := db.file.Sync(); err != nil {
+ return err
+ }
+
+ // TODO maybe return an indication whether anything was actually deleted
+ delete(db.data, key)
+ return nil
+}
+
+// Get rid of historical data in the database file
+func (db *DB) Checkpoint() error {
+ db.Lock()
+ defer db.Unlock()
+
+ if db.file == nil {
+ return ErrClosed
+ }
+
+ checkpoint, err := os.OpenFile(db.file.Name()+".1",
+ os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
+ if err != nil {
+ return err
+ }
+
+ // The checkpoint is made of the current value of every present key
+ for key, value := range db.data {
+ if err := put(checkpoint, key, value); err != nil {
+ return err
+ }
+ }
+ if err := checkpoint.Sync(); err != nil {
+ return err
+ }
+
+ // Atomically move the checkpoint over the old database file
+ if err := os.Rename(checkpoint.Name(), db.file.Name()); err != nil {
+ // Not sure how much sense this makes--when do we get here?
+ _ = os.Remove(checkpoint.Name())
+ return err
+ }
+ // The old file now points to unlinked storage, replace it with the new one
+ _ = db.file.Close()
+ db.file = checkpoint
+ return nil
+}
+
+// Close the database file, rendering the object unusable
+func (db *DB) Close() error {
+ db.Lock()
+ defer db.Unlock()
+
+ if db.file != nil {
+ return nil
+ }
+
+ err := db.file.Close()
+ db.file = nil
+ return err
+}
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..f432b32
--- /dev/null
+++ b/main.go
@@ -0,0 +1,87 @@
+// Demos a trivial key-value database backed by a file
+package main
+
+import (
+ "context"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "os"
+ "os/signal"
+ "strings"
+ "syscall"
+)
+
+func main() {
+ if len(os.Args) != 3 {
+ log.Fatalln("usage: %s LISTEN-ADDRESS DATABASE-FILE", os.Args[0])
+ }
+
+ listenAddr, dbFilename := os.Args[1], os.Args[2]
+ db, err := OpenDB(dbFilename)
+ if err != nil && os.IsNotExist(err) {
+ log.Println("database file does not exist, creating")
+ db, err = CreateDB(dbFilename)
+ }
+ if err != nil {
+ log.Fatalln(err)
+ }
+
+ http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ if _, ok := r.URL.Query()["checkpoint"]; ok && r.Method == "GET" {
+ if err := db.Checkpoint(); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ return
+ }
+
+ byteValue, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+
+ key, value := strings.TrimPrefix(r.URL.Path, "/"), string(byteValue)
+ switch r.Method {
+ case "GET":
+ if value, ok, err := db.Get(key); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ } else if ok {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(value))
+ } else {
+ w.WriteHeader(http.StatusNotFound)
+ }
+ case "PUT":
+ if err := db.Put(key, value); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ case "DELETE":
+ if err := db.Delete(key); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ default:
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ }
+ })
+
+ server := &http.Server{Addr: listenAddr}
+ go func() {
+ if err := http.ListenAndServe(listenAddr, nil); err != nil &&
+ err != http.ErrServerClosed {
+ log.Fatalln(err)
+ }
+ }()
+
+ sig := make(chan os.Signal)
+ signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
+ <-sig
+
+ // For simplicity, we'll wait for everything to finish, including snapshots
+ if err := server.Shutdown(context.Background()); err != nil {
+ log.Fatalln(err)
+ }
+ if err := db.Close(); err != nil {
+ log.Fatalln(err)
+ }
+}