196 lines
4.1 KiB
Go
196 lines
4.1 KiB
Go
package db
|
|
|
|
import (
	"bytes"
	"encoding/gob"
	"errors"
	"fmt"
	"reflect"

	"github.com/syndtr/goleveldb/leveldb/opt"
	tmdb "github.com/tendermint/tm-db"
)
|
|
|
|
// getDB initializes the database. If it fails, it returns nil silently.
|
|
func getDB(name string, readOnly bool) tmdb.DB {
|
|
db, err := tmdb.NewGoLevelDBWithOpts(name, "./data", &opt.Options{ReadOnly: readOnly})
|
|
if err != nil {
|
|
// Extension database is optional, skip silently
|
|
return nil
|
|
}
|
|
return db
|
|
}
|
|
|
|
// GetTable returns a new Table instance for the generic type T, ensuring the struct has an "Id" field.
|
|
func GetTable[T any]() *Table[T] {
|
|
var t T
|
|
typeName := reflect.TypeOf(t).Name()
|
|
if !hasIdField(t) {
|
|
panic(fmt.Sprintf("Table %s must have a field named 'Id'", typeName))
|
|
}
|
|
return &Table[T]{name: typeName}
|
|
}
|
|
|
|
// hasIdField reports whether t is a struct (or a non-nil pointer to one) with a
// field named "Id", including a field promoted from an embedded struct.
// Non-struct values and nil pointers report false.
func hasIdField[T any](t T) bool {
	val := reflect.Indirect(reflect.ValueOf(t))
	if val.Kind() != reflect.Struct {
		return false
	}
	// Look the field up on the type rather than the value: Value.FieldByName
	// panics when "Id" is promoted through a nil embedded pointer, even though
	// the field does exist on the type.
	_, found := val.Type().FieldByName("Id")
	return found
}
|
|
|
|
// getIdBytes converts an ID to its byte representation for storage in the database.
|
|
func getIdBytes(id any) []byte {
|
|
res, err := SerializeKey(id)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return res
|
|
}
|
|
|
|
// if id == nil {
|
|
// return nil
|
|
// }
|
|
|
|
// var buf bytes.Buffer
|
|
// switch v := id.(type) {
|
|
// case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
|
// if err := binary.Write(&buf, binary.BigEndian, v); err == nil {
|
|
// return buf.Bytes()
|
|
// }
|
|
// case string:
|
|
// return []byte(v)
|
|
// case []byte:
|
|
// return v
|
|
// default:
|
|
// return []byte(fmt.Sprint(v))
|
|
// }
|
|
// return nil
|
|
// }
|
|
|
|
// getId returns the value of the "Id" field of t, where t is a struct or a
// pointer to one. A promoted field from an embedded struct also counts.
//
// It returns nil when t is not a struct or is a nil pointer, when there is no
// such field, or when the field is promoted through a nil embedded pointer
// (there is no value to read).
func getId[T any](t T) any {
	val := reflect.Indirect(reflect.ValueOf(t))
	if val.Kind() != reflect.Struct {
		return nil
	}

	sf, ok := val.Type().FieldByName("Id")
	if !ok {
		return nil
	}
	// FieldByIndexErr reports a nil embedded pointer as an error instead of
	// panicking the way Value.FieldByName does.
	field, err := val.FieldByIndexErr(sf.Index)
	if err != nil || !field.CanInterface() {
		return nil
	}
	return field.Interface()
}
|
|
|
|
// Table is a lightweight handle to the on-disk store for values of type T.
// It holds only the store's name; each Table method opens the database,
// does its work, and closes it again.
type Table[T any] struct {
	name string // database name; GetTable sets it to T's type name
}
|
|
|
|
// All retrieves all entries from the database and unmarshals them into a slice of T.
|
|
func (tbl *Table[T]) All() ([]*T, error) {
|
|
db := getDB(tbl.name, true)
|
|
if db == nil {
|
|
return nil, fmt.Errorf("failed to open database %s", tbl.name)
|
|
}
|
|
defer db.Close()
|
|
|
|
var items []*T
|
|
iter, err := db.Iterator(nil, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer iter.Close()
|
|
|
|
for ; iter.Valid(); iter.Next() {
|
|
|
|
item, err := Deserialize[T](iter.Value())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
// Serialize gob-encodes data for storage as a table value.
//
// On failure it returns a nil slice with the encoder's error. Returning the
// buffer in that case would hand back a partial stream, because the encoder
// may already have written part of its output (such as type descriptors).
func Serialize(data any) ([]byte, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(data); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
|
|
|
|
// SerializeKey gob-encodes data for use as a database key.
//
// Stored entries are looked up by these exact bytes, so changing this
// encoding makes previously written keys unreachable. It should therefore
// not follow changes to the value encoding used by Serialize.
//
// On failure it returns a nil slice with the encoder's error, never a
// partially written key.
//
// NOTE(review): for non-basic key types gob output embeds type descriptors;
// confirm such keys are byte-stable across program runs before relying on
// struct-typed IDs.
func SerializeKey(data any) ([]byte, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(data); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
|
|
|
|
// Deserialize gob-decodes data (as produced by Serialize) into a new T.
//
// On failure it returns nil and the decoder's error, rather than a pointer to
// a partially decoded value. Empty input yields the decoder's io.EOF error.
func Deserialize[T any](data []byte) (*T, error) {
	var obj T
	if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&obj); err != nil {
		return nil, err
	}
	return &obj, nil
}
|
|
|
|
// UpdateInsert inserts or updates multiple items in the database.
|
|
func (tbl *Table[T]) UpdateInsert(items ...*T) error {
|
|
db := getDB(tbl.name, false)
|
|
if db == nil {
|
|
return fmt.Errorf("failed to open database %s", tbl.name)
|
|
}
|
|
defer db.Close()
|
|
|
|
for _, item := range items {
|
|
// b, err := json.Marshal(item)
|
|
b, err := Serialize(item)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := db.Set(getIdBytes(getId(item)), b); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Delete removes entries by their IDs.
|
|
func (tbl *Table[T]) Delete(ids ...any) error {
|
|
db := getDB(tbl.name, false)
|
|
if db == nil {
|
|
return fmt.Errorf("failed to open database %s", tbl.name)
|
|
}
|
|
defer db.Close()
|
|
|
|
for _, id := range ids {
|
|
if err := db.Delete(getIdBytes(id)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get retrieves a single item by its ID.
|
|
func (tbl *Table[T]) Get(id any) (*T, error) {
|
|
db := getDB(tbl.name, true)
|
|
if db == nil {
|
|
return nil, fmt.Errorf("failed to open database %s", tbl.name)
|
|
}
|
|
defer db.Close()
|
|
|
|
b, err := db.Get(getIdBytes(id))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return Deserialize[T](b)
|
|
}
|