From b77c8b55800d03b1caf481d93585a5295f67d0da Mon Sep 17 00:00:00 2001 From: 0ceanSlim Date: Tue, 23 Jul 2024 10:53:01 -0400 Subject: [PATCH] events.go removed and funcs added to server --- main.go | 10 +-- server/db/db.go | 34 ++++++++- server/events/events.go | 147 ------------------------------------- server/events/kind0.go | 4 +- server/events/kind1.go | 4 +- server/events/unhandled.go | 3 +- server/server.go | 48 +++++++++--- server/utils/checkSig.go | 73 ++++++++++++++++++ 8 files changed, 152 insertions(+), 171 deletions(-) delete mode 100644 server/events/events.go create mode 100644 server/utils/checkSig.go diff --git a/main.go b/main.go index bc07424..1deba8a 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "grain/server" "grain/server/db" - "grain/server/events" "grain/server/utils" "golang.org/x/net/websocket" @@ -21,16 +20,11 @@ func main() { } // Initialize MongoDB client - client, err := db.InitDB(config.MongoDB.URI, config.MongoDB.Database) + _, err = db.InitDB(config.MongoDB.URI, config.MongoDB.Database) if err != nil { log.Fatal("Error initializing database: ", err) } - defer db.DisconnectDB(client) - - // Initialize collections (dynamically handled in the events package) - events.SetClient(client) - - server.SetClient(client) + defer db.DisconnectDB() // Start WebSocket server http.Handle("/", websocket.Handler(server.Handler)) diff --git a/server/db/db.go b/server/db/db.go index c7cf06d..9ab4250 100644 --- a/server/db/db.go +++ b/server/db/db.go @@ -4,14 +4,24 @@ import ( "context" "fmt" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) +var client *mongo.Client +var collections = make(map[int]*mongo.Collection) + +// GetClient returns the MongoDB client +func GetClient() *mongo.Client { + return client +} + // Initialize MongoDB client func InitDB(uri, database string) (*mongo.Client, error) { clientOptions := options.Client().ApplyURI(uri) - client, err := mongo.Connect(context.TODO(), clientOptions) + var err error + client, err = mongo.Connect(context.TODO(), clientOptions) if err != nil { return nil, err } @@ -26,10 +36,30 @@ func InitDB(uri, database string) (*mongo.Client, error) { return client, nil } +func GetCollection(kind int) *mongo.Collection { + if collection, exists := collections[kind]; exists { + return collection + } + client := GetClient() + collectionName := fmt.Sprintf("event-kind%d", kind) + collection := client.Database("grain").Collection(collectionName) + collections[kind] = collection + indexModel := mongo.IndexModel{ + Keys: bson.D{{Key: "id", Value: 1}}, + Options: options.Index().SetUnique(true), + } + _, err := collection.Indexes().CreateOne(context.TODO(), indexModel) + if err != nil { + fmt.Printf("Failed to create index on %s: %v\n", collectionName, err) + } + return collection +} + // Disconnect from MongoDB -func DisconnectDB(client *mongo.Client) { +func DisconnectDB() { if err := client.Disconnect(context.TODO()); err != nil { fmt.Println("Error disconnecting from MongoDB:", err) } fmt.Println("Disconnected from MongoDB!") } + diff --git a/server/events/events.go b/server/events/events.go deleted file mode 100644 index 58b4734..0000000 --- a/server/events/events.go +++ /dev/null @@ -1,147 +0,0 @@ -package events - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "log" - - "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcec/v2/schnorr" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" - "golang.org/x/net/websocket" -) - -type Event struct { - ID string `json:"id"` - PubKey string `json:"pubkey"` - CreatedAt int64 `json:"created_at"` - Kind int `json:"kind"` - Tags [][]string `json:"tags"` - Content string `json:"content"` - Sig string `json:"sig"` -} - -var ( - client *mongo.Client - collections = make(map[int]*mongo.Collection) -) - -func SetClient(mongoClient *mongo.Client) { - client = mongoClient -} - -func GetCollection(kind int) *mongo.Collection { - if collection, exists := collections[kind]; exists { - return collection - } - collectionName := fmt.Sprintf("event-kind%d", kind) - collection := client.Database("grain").Collection(collectionName) - collections[kind] = collection - indexModel := mongo.IndexModel{ - Keys: bson.D{{Key: "id", Value: 1}}, - Options: options.Index().SetUnique(true), - } - _, err := collection.Indexes().CreateOne(context.TODO(), indexModel) - if err != nil { - fmt.Printf("Failed to create index on %s: %v\n", collectionName, err) - } - return collection -} - -func HandleEvent(ctx context.Context, evt Event, ws *websocket.Conn) { - if !CheckSignature(evt) { - sendOKResponse(ws, evt.ID, false, "invalid: signature verification failed") - return - } - - collection := GetCollection(evt.Kind) - - var err error - switch evt.Kind { - case 0: - err = HandleEventKind0(ctx, evt, collection) - case 1: - err = HandleEventKind1(ctx, evt, collection) - default: - err = HandleUnknownEvent(ctx, evt, collection) - } - - if err != nil { - sendOKResponse(ws, evt.ID, false, fmt.Sprintf("error: %v", err)) - return - } - - sendOKResponse(ws, evt.ID, true, "") -} - -func sendOKResponse(ws *websocket.Conn, eventID string, status bool, message string) { - response := []interface{}{"OK", eventID, status, message} - responseBytes, _ := json.Marshal(response) - websocket.Message.Send(ws, string(responseBytes)) -} - -func SerializeEvent(evt Event) []byte { - eventData := []interface{}{ - 0, - evt.PubKey, - evt.CreatedAt, - evt.Kind, - evt.Tags, - evt.Content, - } - serializedEvent, _ := json.Marshal(eventData) - return serializedEvent -} - -func CheckSignature(evt Event) bool { - serializedEvent := SerializeEvent(evt) - hash := sha256.Sum256(serializedEvent) - eventID := hex.EncodeToString(hash[:]) - if eventID != evt.ID { - log.Printf("Invalid ID: expected %s, got %s\n", eventID, evt.ID) - return false - } - - sigBytes, err := hex.DecodeString(evt.Sig) - if err != nil { - log.Printf("Error decoding signature: %v\n", err) - return false - } - - sig, err := schnorr.ParseSignature(sigBytes) - if err != nil { - log.Printf("Error parsing signature: %v\n", err) - return false - } - - pubKeyBytes, err := hex.DecodeString(evt.PubKey) - if err != nil { - log.Printf("Error decoding public key: %v\n", err) - return false - } - - var pubKey *btcec.PublicKey - if len(pubKeyBytes) == 32 { - // Handle 32-byte public key (x-coordinate only) - pubKey, err = btcec.ParsePubKey(append([]byte{0x02}, pubKeyBytes...)) - } else { - // Handle standard compressed or uncompressed public key - pubKey, err = btcec.ParsePubKey(pubKeyBytes) - } - if err != nil { - log.Printf("Error parsing public key: %v\n", err) - return false - } - - verified := sig.Verify(hash[:], pubKey) - if !verified { - log.Printf("Signature verification failed for event ID: %s\n", evt.ID) - } - - return verified -} diff --git a/server/events/kind0.go b/server/events/kind0.go index 6db5e3e..ac0fb18 100644 --- a/server/events/kind0.go +++ b/server/events/kind0.go @@ -4,12 +4,14 @@ import ( "context" "fmt" + server "grain/server/types" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) -func HandleEventKind0(ctx context.Context, evt Event, collection *mongo.Collection) error { +func HandleKind0(ctx context.Context, evt server.Event, collection *mongo.Collection) error { // Replace the existing event if it has the same pubkey filter := bson.M{"pubkey": evt.PubKey} update := bson.M{ diff --git a/server/events/kind1.go b/server/events/kind1.go index d7bd206..952c0ab 100644 --- a/server/events/kind1.go +++ b/server/events/kind1.go @@ -4,10 +4,12 @@ import ( "context" "fmt" + server "grain/server/types" + "go.mongodb.org/mongo-driver/mongo" ) -func HandleEventKind1(ctx context.Context, evt Event, collection *mongo.Collection) error { +func HandleKind1(ctx context.Context, evt server.Event, collection *mongo.Collection) error { // Insert event into MongoDB _, err := collection.InsertOne(ctx, evt) if err != nil { diff --git a/server/events/unhandled.go b/server/events/unhandled.go index 9cb2dda..62069b1 100644 --- a/server/events/unhandled.go +++ b/server/events/unhandled.go @@ -3,11 +3,12 @@ package events import ( "context" "fmt" + server "grain/server/types" "go.mongodb.org/mongo-driver/mongo" ) -func HandleUnknownEvent(ctx context.Context, evt Event, collection *mongo.Collection) error { +func HandleUnknownKind(ctx context.Context, evt server.Event, collection *mongo.Collection) error { _, err := collection.InsertOne(ctx, evt) if err != nil { return fmt.Errorf("Error inserting unknown event into MongoDB: %v", err) diff --git a/server/server.go b/server/server.go index a6f7c47..316d3bc 100644 --- a/server/server.go +++ b/server/server.go @@ -4,21 +4,15 @@ import ( "context" "encoding/json" "fmt" + "grain/server/db" "grain/server/events" server "grain/server/types" "grain/server/utils" - "go.mongodb.org/mongo-driver/mongo" "golang.org/x/net/websocket" ) var subscriptions = make(map[string]server.Subscription) -var client *mongo.Client - -func SetClient(mongoClient *mongo.Client) { - client = mongoClient - events.SetClient(mongoClient) // Ensure the events package has the MongoDB client -} func Handler(ws *websocket.Conn) { defer ws.Close() @@ -80,19 +74,51 @@ func handleEvent(ws *websocket.Conn, message []interface{}) { return } - var evt events.Event + var evt server.Event err = json.Unmarshal(eventBytes, &evt) if err != nil { fmt.Println("Error unmarshaling event data:", err) return } - // Call the HandleEvent function from the events package - events.HandleEvent(context.TODO(), evt, ws) + // Call the HandleKind function from the events package + HandleKind(context.TODO(), evt, ws) fmt.Println("Event processed:", evt.ID) } +func HandleKind(ctx context.Context, evt server.Event, ws *websocket.Conn) { + if !utils.CheckSignature(evt) { + sendOKResponse(ws, evt.ID, false, "invalid: signature verification failed") + return + } + + collection := db.GetCollection(evt.Kind) + + var err error + switch evt.Kind { + case 0: + err = events.HandleKind0(ctx, evt, collection) + case 1: + err = events.HandleKind1(ctx, evt, collection) + default: + err = events.HandleUnknownKind(ctx, evt, collection) + } + + if err != nil { + sendOKResponse(ws, evt.ID, false, fmt.Sprintf("error: %v", err)) + return + } + + sendOKResponse(ws, evt.ID, true, "") +} + +func sendOKResponse(ws *websocket.Conn, eventID string, status bool, message string) { + response := []interface{}{"OK", eventID, status, message} + responseBytes, _ := json.Marshal(response) + websocket.Message.Send(ws, string(responseBytes)) +} + func handleReq(ws *websocket.Conn, message []interface{}) { if len(message) < 3 { fmt.Println("Invalid REQ message format") @@ -129,7 +155,7 @@ func handleReq(ws *websocket.Conn, message []interface{}) { fmt.Println("Subscription added:", subID) // Query the database with filters and send back the results - queriedEvents, err := QueryEvents(filters, client, "grain", "event-kind1") + queriedEvents, err := QueryEvents(filters, db.GetClient(), "grain", "event-kind1") if err != nil { fmt.Println("Error querying events:", err) return diff --git a/server/utils/checkSig.go b/server/utils/checkSig.go new file mode 100644 index 0000000..ded613b --- /dev/null +++ b/server/utils/checkSig.go @@ -0,0 +1,73 @@ +package utils + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "log" + + server "grain/server/types" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr" +) +func SerializeEvent(evt server.Event) []byte { + eventData := []interface{}{ + 0, + evt.PubKey, + evt.CreatedAt, + evt.Kind, + evt.Tags, + evt.Content, + } + serializedEvent, _ := json.Marshal(eventData) + return serializedEvent +} + +func CheckSignature(evt server.Event) bool { + serializedEvent := SerializeEvent(evt) + hash := sha256.Sum256(serializedEvent) + eventID := hex.EncodeToString(hash[:]) + if eventID != evt.ID { + log.Printf("Invalid ID: expected %s, got %s\n", eventID, evt.ID) + return false + } + + sigBytes, err := hex.DecodeString(evt.Sig) + if err != nil { + log.Printf("Error decoding signature: %v\n", err) + return false + } + + sig, err := schnorr.ParseSignature(sigBytes) + if err != nil { + log.Printf("Error parsing signature: %v\n", err) + return false + } + + pubKeyBytes, err := hex.DecodeString(evt.PubKey) + if err != nil { + log.Printf("Error decoding public key: %v\n", err) + return false + } + + var pubKey *btcec.PublicKey + if len(pubKeyBytes) == 32 { + // Handle 32-byte public key (x-coordinate only) + pubKey, err = btcec.ParsePubKey(append([]byte{0x02}, pubKeyBytes...)) + } else { + // Handle standard compressed or uncompressed public key + pubKey, err = btcec.ParsePubKey(pubKeyBytes) + } + if err != nil { + log.Printf("Error parsing public key: %v\n", err) + return false + } + + verified := sig.Verify(hash[:], pubKey) + if !verified { + log.Printf("Signature verification failed for event ID: %s\n", evt.ID) + } + + return verified +}