WIP: Add event handling with Schnorr signature verification

This commit is contained in:
0ceanSlim 2024-07-22 13:50:43 -04:00
parent c7b0e9c390
commit a0c5fe95ad
8 changed files with 238 additions and 163 deletions

View File

@ -2,11 +2,17 @@ package events
import ( import (
"context" "context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt" "fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
"golang.org/x/net/websocket"
) )
type Event struct { type Event struct {
@ -14,47 +20,128 @@ type Event struct {
PubKey string `json:"pubkey"` PubKey string `json:"pubkey"`
CreatedAt int64 `json:"created_at"` CreatedAt int64 `json:"created_at"`
Kind int `json:"kind"` Kind int `json:"kind"`
Tags []string `json:"tags"` Tags [][]string `json:"tags"`
Content string `json:"content"` Content string `json:"content"`
Sig string `json:"sig"` Sig string `json:"sig"`
} }
var eventKind0Collection *mongo.Collection var collections = make(map[int]*mongo.Collection)
var eventKind1Collection *mongo.Collection
func InitCollections(client *mongo.Client, eventKind0, eventKind1 string) {
eventKind0Collection = client.Database("grain").Collection(eventKind0)
eventKind1Collection = client.Database("grain").Collection(eventKind1)
func InitCollections(client *mongo.Client, kinds ...int) {
for _, kind := range kinds {
collectionName := fmt.Sprintf("event-kind%d", kind)
collections[kind] = client.Database("grain").Collection(collectionName)
indexModel := mongo.IndexModel{ indexModel := mongo.IndexModel{
Keys: bson.D{{Key: "id", Value: 1}}, Keys: bson.D{{Key: "id", Value: 1}},
Options: options.Index().SetUnique(true), Options: options.Index().SetUnique(true),
} }
_, err := eventKind0Collection.Indexes().CreateOne(context.TODO(), indexModel) _, err := collections[kind].Indexes().CreateOne(context.TODO(), indexModel)
if err != nil { if err != nil {
fmt.Println("Failed to create index on event-kind0: ", err) fmt.Printf("Failed to create index on %s: %v\n", collectionName, err)
} }
_, err = eventKind1Collection.Indexes().CreateOne(context.TODO(), indexModel)
if err != nil {
fmt.Println("Failed to create index on event-kind1: ", err)
} }
} }
func HandleEvent(ctx context.Context, evt Event) error { func GetCollection(kind int, client *mongo.Client) *mongo.Collection {
if collection, exists := collections[kind]; exists {
return collection
}
collectionName := fmt.Sprintf("event-kind%d", kind)
collection := client.Database("grain").Collection(collectionName)
collections[kind] = collection
indexModel := mongo.IndexModel{
Keys: bson.D{{Key: "id", Value: 1}},
Options: options.Index().SetUnique(true),
}
_, err := collection.Indexes().CreateOne(context.TODO(), indexModel)
if err != nil {
fmt.Printf("Failed to create index on %s: %v\n", collectionName, err)
}
return collection
}
func HandleEvent(ctx context.Context, evt Event, client *mongo.Client, ws *websocket.Conn) {
if !ValidateEvent(evt) {
sendOKResponse(ws, evt.ID, false, "invalid: signature verification failed")
return
}
collection := GetCollection(evt.Kind, client)
var err error
switch evt.Kind { switch evt.Kind {
case 0: case 0:
return HandleEventKind0(ctx, evt, eventKind0Collection) err = HandleEventKind0(ctx, evt, collection)
case 1: case 1:
return HandleEventKind1(ctx, evt, eventKind1Collection) err = HandleEventKind1(ctx, evt, collection)
default: default:
fmt.Println("Unknown event kind:", evt.Kind) err = HandleDefaultEvent(ctx, evt, collection)
return fmt.Errorf("unknown event kind: %d", evt.Kind)
}
} }
func GetCollections() map[string]*mongo.Collection { if err != nil {
return map[string]*mongo.Collection{ sendOKResponse(ws, evt.ID, false, fmt.Sprintf("error: %v", err))
"eventKind0": eventKind0Collection, return
"eventKind1": eventKind1Collection,
} }
sendOKResponse(ws, evt.ID, true, "")
}
func sendOKResponse(ws *websocket.Conn, eventID string, status bool, message string) {
response := []interface{}{"OK", eventID, status, message}
responseBytes, _ := json.Marshal(response)
websocket.Message.Send(ws, string(responseBytes))
}
func ValidateEvent(evt Event) bool {
serializedEvent := SerializeEvent(evt)
hash := sha256.Sum256(serializedEvent)
eventID := hex.EncodeToString(hash[:])
if eventID != evt.ID {
return false
}
sigBytes, err := hex.DecodeString(evt.Sig)
if err != nil {
return false
}
sig, err := schnorr.ParseSignature(sigBytes)
if err != nil {
return false
}
pubKeyBytes, err := hex.DecodeString(evt.PubKey)
if err != nil {
return false
}
pubKey, err := btcec.ParsePubKey(pubKeyBytes)
if err != nil {
return false
}
return sig.Verify(hash[:], pubKey)
}
func SerializeEvent(evt Event) []byte {
eventData := []interface{}{
0,
evt.PubKey,
evt.CreatedAt,
evt.Kind,
evt.Tags,
evt.Content,
}
serializedEvent, _ := json.Marshal(eventData)
return serializedEvent
}
func HandleDefaultEvent(ctx context.Context, evt Event, collection *mongo.Collection) error {
_, err := collection.InsertOne(ctx, evt)
if err != nil {
return fmt.Errorf("Error inserting default event into MongoDB: %v", err)
}
fmt.Println("Inserted default event into MongoDB:", evt.ID)
return nil
} }

View File

@ -10,11 +10,6 @@ import (
) )
func HandleEventKind0(ctx context.Context, evt Event, collection *mongo.Collection) error { func HandleEventKind0(ctx context.Context, evt Event, collection *mongo.Collection) error {
// Perform specific validation for event kind 0
if !isValidEventKind0(evt) {
return fmt.Errorf("validation failed for event kind 0: %s", evt.ID)
}
// Replace the existing event if it has the same pubkey // Replace the existing event if it has the same pubkey
filter := bson.M{"pubkey": evt.PubKey} filter := bson.M{"pubkey": evt.PubKey}
update := bson.M{ update := bson.M{
@ -28,21 +23,12 @@ func HandleEventKind0(ctx context.Context, evt Event, collection *mongo.Collecti
}, },
} }
options := options.Update().SetUpsert(true) // Insert if not found opts := options.Update().SetUpsert(true) // Insert if not found
_, err := collection.UpdateOne(ctx, filter, update, options) _, err := collection.UpdateOne(ctx, filter, update, opts)
if err != nil { if err != nil {
fmt.Println("Error updating/inserting event kind 0 into MongoDB:", err) return fmt.Errorf("Error updating/inserting event kind 0 into MongoDB: %v", err)
return err
} }
fmt.Println("Upserted event kind 0 into MongoDB:", evt.ID) fmt.Println("Upserted event kind 0 into MongoDB:", evt.ID)
return nil return nil
} }
func isValidEventKind0(evt Event) bool {
// Placeholder for actual validation logic
if evt.Content == "" {
return false
}
return true
}

View File

@ -8,26 +8,12 @@ import (
) )
func HandleEventKind1(ctx context.Context, evt Event, collection *mongo.Collection) error { func HandleEventKind1(ctx context.Context, evt Event, collection *mongo.Collection) error {
// Perform specific validation for event kind 1
if !isValidEventKind1(evt) {
return fmt.Errorf("validation failed for event kind 1: %s", evt.ID)
}
// Insert event into MongoDB // Insert event into MongoDB
_, err := collection.InsertOne(ctx, evt) _, err := collection.InsertOne(ctx, evt)
if err != nil { if err != nil {
fmt.Println("Error inserting event into MongoDB:", err) return fmt.Errorf("Error inserting event into MongoDB: %v", err)
return err
} }
fmt.Println("Inserted event kind 1 into MongoDB:", evt.ID) fmt.Println("Inserted event kind 1 into MongoDB:", evt.ID)
return nil return nil
} }
func isValidEventKind1(evt Event) bool {
// Placeholder for actual validation logic
if evt.Content == "" {
return false
}
return true
}

4
go.mod
View File

@ -8,6 +8,10 @@ require (
) )
require ( require (
github.com/btcsuite/btcd/btcec/v2 v2.3.4 // indirect
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 // indirect
github.com/decred/dcrd/crypto/blake256 v1.0.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect
github.com/golang/snappy v0.0.4 // indirect github.com/golang/snappy v0.0.4 // indirect
github.com/klauspost/compress v1.13.6 // indirect github.com/klauspost/compress v1.13.6 // indirect
github.com/montanaflynn/stats v0.7.1 // indirect github.com/montanaflynn/stats v0.7.1 // indirect

8
go.sum
View File

@ -1,5 +1,13 @@
github.com/btcsuite/btcd/btcec/v2 v2.3.4 h1:3EJjcN70HCu/mwqlUsGK8GcNVyLVxFDlWurTXGPFfiQ=
github.com/btcsuite/btcd/btcec/v2 v2.3.4/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04=
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 h1:q0rUy8C/TYNBQS1+CGKw68tLOFYSNEs0TFnxxnS9+4U=
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/crypto/blake256 v1.0.0 h1:/8DMNYp9SGi5f0w7uCm6d6M4OU2rGFK09Y2A4Xv7EE0=
github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=

View File

@ -28,7 +28,7 @@ func main() {
defer db.DisconnectDB(client) defer db.DisconnectDB(client)
// Initialize collections // Initialize collections
events.InitCollections(client, config.Collections.EventKind0, config.Collections.EventKind1) events.InitCollections(client, 0, 1) // Initialize known kinds
server.SetClient(client) server.SetClient(client)

View File

@ -7,6 +7,8 @@ import (
"grain/events" "grain/events"
"time" "time"
"grain/utils"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
@ -101,11 +103,7 @@ func handleEvent(ws *websocket.Conn, message []interface{}) {
return return
} }
err = events.HandleEvent(context.TODO(), evt) events.HandleEvent(context.TODO(), evt, client, ws)
if err != nil {
fmt.Println("Error handling event:", err)
return
}
fmt.Println("Event processed:", evt.ID) fmt.Println("Event processed:", evt.ID)
} }
@ -131,13 +129,13 @@ func handleReq(ws *websocket.Conn, message []interface{}) {
} }
var f Filter var f Filter
f.IDs = toStringArray(filterData["ids"]) f.IDs = utils.ToStringArray(filterData["ids"])
f.Authors = toStringArray(filterData["authors"]) f.Authors = utils.ToStringArray(filterData["authors"])
f.Kinds = toIntArray(filterData["kinds"]) f.Kinds = utils.ToIntArray(filterData["kinds"])
f.Tags = toTagsMap(filterData["tags"]) f.Tags = utils.ToTagsMap(filterData["tags"])
f.Since = toTime(filterData["since"]) f.Since = utils.ToTime(filterData["since"])
f.Until = toTime(filterData["until"]) f.Until = utils.ToTime(filterData["until"])
f.Limit = toInt(filterData["limit"]) f.Limit = utils.ToInt(filterData["limit"])
filters[i] = f filters[i] = f
} }
@ -195,92 +193,3 @@ func handleClose(ws *websocket.Conn, message []interface{}) {
return return
} }
} }
func toStringArray(i interface{}) []string {
if i == nil {
return nil
}
arr, ok := i.([]interface{})
if !ok {
return nil
}
var result []string
for _, v := range arr {
str, ok := v.(string)
if ok {
result = append(result, str)
}
}
return result
}
func toIntArray(i interface{}) []int {
if i == nil {
return nil
}
arr, ok := i.([]interface{})
if !ok {
return nil
}
var result []int
for _, v := range arr {
num, ok := v.(float64)
if ok {
result = append(result, int(num))
}
}
return result
}
func toTagsMap(i interface{}) map[string][]string {
if i == nil {
return nil
}
tags, ok := i.(map[string]interface{})
if !ok {
return nil
}
result := make(map[string][]string)
for k, v := range tags {
result[k] = toStringArray(v)
}
return result
}
func toInt64(i interface{}) *int64 {
if i == nil {
return nil
}
num, ok := i.(float64)
if !ok {
return nil
}
val := int64(num)
return &val
}
func toInt(i interface{}) *int {
if i == nil {
return nil
}
num, ok := i.(float64)
if !ok {
return nil
}
val := int(num)
return &val
}
func toTime(data interface{}) *time.Time {
if data == nil {
return nil
}
// Ensure data is a float64 which MongoDB uses for numbers
timestamp, ok := data.(float64)
if !ok {
fmt.Println("Invalid timestamp format")
return nil
}
t := time.Unix(int64(timestamp), 0).UTC()
return &t
}

95
utils/decode.go Normal file
View File

@ -0,0 +1,95 @@
package utils
import (
"fmt"
"time"
)
func ToStringArray(i interface{}) []string {
if i == nil {
return nil
}
arr, ok := i.([]interface{})
if !ok {
return nil
}
var result []string
for _, v := range arr {
str, ok := v.(string)
if ok {
result = append(result, str)
}
}
return result
}
func ToIntArray(i interface{}) []int {
if i == nil {
return nil
}
arr, ok := i.([]interface{})
if !ok {
return nil
}
var result []int
for _, v := range arr {
num, ok := v.(float64)
if ok {
result = append(result, int(num))
}
}
return result
}
func ToTagsMap(i interface{}) map[string][]string {
if i == nil {
return nil
}
tags, ok := i.(map[string]interface{})
if !ok {
return nil
}
result := make(map[string][]string)
for k, v := range tags {
result[k] = ToStringArray(v)
}
return result
}
func ToInt64(i interface{}) *int64 {
if i == nil {
return nil
}
num, ok := i.(float64)
if !ok {
return nil
}
val := int64(num)
return &val
}
func ToInt(i interface{}) *int {
if i == nil {
return nil
}
num, ok := i.(float64)
if !ok {
return nil
}
val := int(num)
return &val
}
func ToTime(data interface{}) *time.Time {
if data == nil {
return nil
}
// Ensure data is a float64 which MongoDB uses for numbers
timestamp, ok := data.(float64)
if !ok {
fmt.Println("Invalid timestamp format")
return nil
}
t := time.Unix(int64(timestamp), 0).UTC()
return &t
}