Add some helper functions to tag the db keys

This commit is contained in:
Sameer Rahmani 2022-07-17 12:46:00 +01:00
parent 78c6297518
commit 49b9afbc83
4 changed files with 77 additions and 17 deletions

View File

@ -34,17 +34,6 @@ var listCmd = &cobra.Command{
}
core.ProcessFiles(state, *files)
// state := core.State{
// }
// blocks, err := core.FindCommentBlocks(&(*files)[0])
// if err != nil {
// panic(err)
// }
// fmt.Printf("%s", blocks)
},
}

View File

@ -18,6 +18,8 @@
package core
import (
"io"
"github.com/cockroachdb/pebble"
log "github.com/sirupsen/logrus"
)
@ -26,6 +28,18 @@ type DB struct {
connection *pebble.DB
}
var FILE int8 = 0
var ITEM int8 = 1
var TYPE int8 = 2
var TAG int8 = 4
func tagFor(value []byte, objType int8) []byte {
// Just to make the stupid go-critic to shutup
tmp := value
tmp = append(tmp, byte(objType))
return tmp
}
func CreateDB(path string) (*DB, error) {
log.Debugf("Opening up the state db at '%s'", path)
db, err := pebble.Open(path, nil)
@ -38,18 +52,67 @@ func CreateDB(path string) (*DB, error) {
}, nil
}
func (db *DB) Set(k []byte, v []byte) error {
return db.connection.Set(k, v, pebble.Sync)
func (db *DB) Set(k []byte, v []byte, item int8) error {
taggedKey := tagFor(k, item)
return db.connection.Set(taggedKey, v, pebble.Sync)
}
func (db *DB) Get(k []byte) ([]byte, error) {
v, closer, err := db.connection.Get(k)
func (db *DB) SetItem(k []byte, v []byte) error {
return db.Set(k, v, ITEM)
}
func (db *DB) SetType(k []byte, v []byte) error {
return db.Set(k, v, TYPE)
}
func (db *DB) SetTAG(k []byte, v []byte) error {
return db.Set(k, v, TAG)
}
func (db *DB) SetFile(k []byte, v []byte) error {
return db.Set(k, v, FILE)
}
func (db *DB) GetString(k string) ([]byte, error) {
v, closer, err := db.connection.Get([]byte(k))
defer closer.Close()
if err != nil {
return nil, err
}
defer closer.Close()
return v, nil
}
func (db *DB) Get(k []byte, defaultValue *[]byte) ([]byte, io.Closer, error) {
v, closer, err := db.connection.Get(k)
if err == pebble.ErrNotFound {
var defaultV []byte
if defaultValue != nil {
defaultV = *defaultValue
}
return defaultV, closer, nil
}
return v, closer, err
}
func (db *DB) GetItem(k []byte, defaultValue *[]byte) ([]byte, io.Closer, error) {
return db.Get(k, defaultValue)
}
func (db *DB) GetStringOrDefault(k string, default_ interface{}) (interface{}, error) {
v, closer, err := db.connection.Get([]byte(k))
if err == pebble.ErrNotFound {
return default_, nil
}
if err != nil {
return nil, err
}
defer closer.Close()
return v, nil
}

View File

@ -26,7 +26,7 @@ var POUND_SIGN = "#"
var C_LANG = Lang{"C", &DOUBLE_SLASH}
var CPP_LANG = Lang{"C++", &DOUBLE_SLASH}
var Go_LANG = Lang{"Go", &DOUBLE_SLASH}
var GO_LANG = Lang{"Go", &DOUBLE_SLASH}
var SH = Lang{"Shell", &POUND_SIGN}
var MLIR = Lang{"MLIR", &DOUBLE_SLASH}
@ -38,5 +38,6 @@ var ExtsToLang map[string]*Lang = map[string]*Lang{
"hpp": &CPP_LANG,
"cpp": &CPP_LANG,
"sh": &SH,
"go": &GO_LANG,
"mlir": &MLIR,
}

View File

@ -27,6 +27,7 @@ import (
)
type State struct {
Version int
ProjectRoot string
Repo *git.Repository
// A mapping of file extensions to the language
@ -62,6 +63,12 @@ func CreateState(projectRoot string, debug bool) (*State, error) {
if err != nil {
return nil, err
}
_, err = state.DB.GetStringOrDefault("version", "1")
if err != nil {
log.Panic(err)
}
// Read the pattern from the global or project wide config file
// using viper
state.BlockTypePattern = regexp.MustCompile(`^(?P<type>[A-Z]+):`)