events.go removed and funcs added to server

This commit is contained in:
0ceanSlim 2024-07-23 10:53:01 -04:00
parent a60104d73d
commit b77c8b5580
8 changed files with 152 additions and 171 deletions

10
main.go
View File

@ -7,7 +7,6 @@ import (
"grain/server" "grain/server"
"grain/server/db" "grain/server/db"
"grain/server/events"
"grain/server/utils" "grain/server/utils"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
@ -21,16 +20,11 @@ func main() {
} }
// Initialize MongoDB client // Initialize MongoDB client
client, err := db.InitDB(config.MongoDB.URI, config.MongoDB.Database) _, err = db.InitDB(config.MongoDB.URI, config.MongoDB.Database)
if err != nil { if err != nil {
log.Fatal("Error initializing database: ", err) log.Fatal("Error initializing database: ", err)
} }
defer db.DisconnectDB(client) defer db.DisconnectDB()
// Initialize collections (dynamically handled in the events package)
events.SetClient(client)
server.SetClient(client)
// Start WebSocket server // Start WebSocket server
http.Handle("/", websocket.Handler(server.Handler)) http.Handle("/", websocket.Handler(server.Handler))

View File

@ -4,14 +4,24 @@ import (
"context" "context"
"fmt" "fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
) )
var client *mongo.Client
var collections = make(map[int]*mongo.Collection)
// GetClient returns the MongoDB client
func GetClient() *mongo.Client {
return client
}
// Initialize MongoDB client // Initialize MongoDB client
func InitDB(uri, database string) (*mongo.Client, error) { func InitDB(uri, database string) (*mongo.Client, error) {
clientOptions := options.Client().ApplyURI(uri) clientOptions := options.Client().ApplyURI(uri)
client, err := mongo.Connect(context.TODO(), clientOptions) var err error
client, err = mongo.Connect(context.TODO(), clientOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -26,10 +36,30 @@ func InitDB(uri, database string) (*mongo.Client, error) {
return client, nil return client, nil
} }
func GetCollection(kind int) *mongo.Collection {
if collection, exists := collections[kind]; exists {
return collection
}
client := GetClient()
collectionName := fmt.Sprintf("event-kind%d", kind)
collection := client.Database("grain").Collection(collectionName)
collections[kind] = collection
indexModel := mongo.IndexModel{
Keys: bson.D{{Key: "id", Value: 1}},
Options: options.Index().SetUnique(true),
}
_, err := collection.Indexes().CreateOne(context.TODO(), indexModel)
if err != nil {
fmt.Printf("Failed to create index on %s: %v\n", collectionName, err)
}
return collection
}
// Disconnect from MongoDB // Disconnect from MongoDB
func DisconnectDB(client *mongo.Client) { func DisconnectDB() {
if err := client.Disconnect(context.TODO()); err != nil { if err := client.Disconnect(context.TODO()); err != nil {
fmt.Println("Error disconnecting from MongoDB:", err) fmt.Println("Error disconnecting from MongoDB:", err)
} }
fmt.Println("Disconnected from MongoDB!") fmt.Println("Disconnected from MongoDB!")
} }

View File

@ -1,147 +0,0 @@
package events
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"golang.org/x/net/websocket"
)
type Event struct {
ID string `json:"id"`
PubKey string `json:"pubkey"`
CreatedAt int64 `json:"created_at"`
Kind int `json:"kind"`
Tags [][]string `json:"tags"`
Content string `json:"content"`
Sig string `json:"sig"`
}
var (
client *mongo.Client
collections = make(map[int]*mongo.Collection)
)
func SetClient(mongoClient *mongo.Client) {
client = mongoClient
}
func GetCollection(kind int) *mongo.Collection {
if collection, exists := collections[kind]; exists {
return collection
}
collectionName := fmt.Sprintf("event-kind%d", kind)
collection := client.Database("grain").Collection(collectionName)
collections[kind] = collection
indexModel := mongo.IndexModel{
Keys: bson.D{{Key: "id", Value: 1}},
Options: options.Index().SetUnique(true),
}
_, err := collection.Indexes().CreateOne(context.TODO(), indexModel)
if err != nil {
fmt.Printf("Failed to create index on %s: %v\n", collectionName, err)
}
return collection
}
func HandleEvent(ctx context.Context, evt Event, ws *websocket.Conn) {
if !CheckSignature(evt) {
sendOKResponse(ws, evt.ID, false, "invalid: signature verification failed")
return
}
collection := GetCollection(evt.Kind)
var err error
switch evt.Kind {
case 0:
err = HandleEventKind0(ctx, evt, collection)
case 1:
err = HandleEventKind1(ctx, evt, collection)
default:
err = HandleUnknownEvent(ctx, evt, collection)
}
if err != nil {
sendOKResponse(ws, evt.ID, false, fmt.Sprintf("error: %v", err))
return
}
sendOKResponse(ws, evt.ID, true, "")
}
func sendOKResponse(ws *websocket.Conn, eventID string, status bool, message string) {
response := []interface{}{"OK", eventID, status, message}
responseBytes, _ := json.Marshal(response)
websocket.Message.Send(ws, string(responseBytes))
}
func SerializeEvent(evt Event) []byte {
eventData := []interface{}{
0,
evt.PubKey,
evt.CreatedAt,
evt.Kind,
evt.Tags,
evt.Content,
}
serializedEvent, _ := json.Marshal(eventData)
return serializedEvent
}
func CheckSignature(evt Event) bool {
serializedEvent := SerializeEvent(evt)
hash := sha256.Sum256(serializedEvent)
eventID := hex.EncodeToString(hash[:])
if eventID != evt.ID {
log.Printf("Invalid ID: expected %s, got %s\n", eventID, evt.ID)
return false
}
sigBytes, err := hex.DecodeString(evt.Sig)
if err != nil {
log.Printf("Error decoding signature: %v\n", err)
return false
}
sig, err := schnorr.ParseSignature(sigBytes)
if err != nil {
log.Printf("Error parsing signature: %v\n", err)
return false
}
pubKeyBytes, err := hex.DecodeString(evt.PubKey)
if err != nil {
log.Printf("Error decoding public key: %v\n", err)
return false
}
var pubKey *btcec.PublicKey
if len(pubKeyBytes) == 32 {
// Handle 32-byte public key (x-coordinate only)
pubKey, err = btcec.ParsePubKey(append([]byte{0x02}, pubKeyBytes...))
} else {
// Handle standard compressed or uncompressed public key
pubKey, err = btcec.ParsePubKey(pubKeyBytes)
}
if err != nil {
log.Printf("Error parsing public key: %v\n", err)
return false
}
verified := sig.Verify(hash[:], pubKey)
if !verified {
log.Printf("Signature verification failed for event ID: %s\n", evt.ID)
}
return verified
}

View File

@ -4,12 +4,14 @@ import (
"context" "context"
"fmt" "fmt"
server "grain/server/types"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
) )
func HandleEventKind0(ctx context.Context, evt Event, collection *mongo.Collection) error { func HandleKind0(ctx context.Context, evt server.Event, collection *mongo.Collection) error {
// Replace the existing event if it has the same pubkey // Replace the existing event if it has the same pubkey
filter := bson.M{"pubkey": evt.PubKey} filter := bson.M{"pubkey": evt.PubKey}
update := bson.M{ update := bson.M{

View File

@ -4,10 +4,12 @@ import (
"context" "context"
"fmt" "fmt"
server "grain/server/types"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
) )
func HandleEventKind1(ctx context.Context, evt Event, collection *mongo.Collection) error { func HandleKind1(ctx context.Context, evt server.Event, collection *mongo.Collection) error {
// Insert event into MongoDB // Insert event into MongoDB
_, err := collection.InsertOne(ctx, evt) _, err := collection.InsertOne(ctx, evt)
if err != nil { if err != nil {

View File

@ -3,11 +3,12 @@ package events
import ( import (
"context" "context"
"fmt" "fmt"
server "grain/server/types"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
) )
func HandleUnknownEvent(ctx context.Context, evt Event, collection *mongo.Collection) error { func HandleUnknownKind(ctx context.Context, evt server.Event, collection *mongo.Collection) error {
_, err := collection.InsertOne(ctx, evt) _, err := collection.InsertOne(ctx, evt)
if err != nil { if err != nil {
return fmt.Errorf("Error inserting unknown event into MongoDB: %v", err) return fmt.Errorf("Error inserting unknown event into MongoDB: %v", err)

View File

@ -4,21 +4,15 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"grain/server/db"
"grain/server/events" "grain/server/events"
server "grain/server/types" server "grain/server/types"
"grain/server/utils" "grain/server/utils"
"go.mongodb.org/mongo-driver/mongo"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
var subscriptions = make(map[string]server.Subscription) var subscriptions = make(map[string]server.Subscription)
var client *mongo.Client
func SetClient(mongoClient *mongo.Client) {
client = mongoClient
events.SetClient(mongoClient) // Ensure the events package has the MongoDB client
}
func Handler(ws *websocket.Conn) { func Handler(ws *websocket.Conn) {
defer ws.Close() defer ws.Close()
@ -80,19 +74,51 @@ func handleEvent(ws *websocket.Conn, message []interface{}) {
return return
} }
var evt events.Event var evt server.Event
err = json.Unmarshal(eventBytes, &evt) err = json.Unmarshal(eventBytes, &evt)
if err != nil { if err != nil {
fmt.Println("Error unmarshaling event data:", err) fmt.Println("Error unmarshaling event data:", err)
return return
} }
// Call the HandleEvent function from the events package // Call the HandleKind function from the events package
events.HandleEvent(context.TODO(), evt, ws) HandleKind(context.TODO(), evt, ws)
fmt.Println("Event processed:", evt.ID) fmt.Println("Event processed:", evt.ID)
} }
func HandleKind(ctx context.Context, evt server.Event, ws *websocket.Conn) {
if !utils.CheckSignature(evt) {
sendOKResponse(ws, evt.ID, false, "invalid: signature verification failed")
return
}
collection := db.GetCollection(evt.Kind)
var err error
switch evt.Kind {
case 0:
err = events.HandleKind0(ctx, evt, collection)
case 1:
err = events.HandleKind1(ctx, evt, collection)
default:
err = events.HandleUnknownKind(ctx, evt, collection)
}
if err != nil {
sendOKResponse(ws, evt.ID, false, fmt.Sprintf("error: %v", err))
return
}
sendOKResponse(ws, evt.ID, true, "")
}
func sendOKResponse(ws *websocket.Conn, eventID string, status bool, message string) {
response := []interface{}{"OK", eventID, status, message}
responseBytes, _ := json.Marshal(response)
websocket.Message.Send(ws, string(responseBytes))
}
func handleReq(ws *websocket.Conn, message []interface{}) { func handleReq(ws *websocket.Conn, message []interface{}) {
if len(message) < 3 { if len(message) < 3 {
fmt.Println("Invalid REQ message format") fmt.Println("Invalid REQ message format")
@ -129,7 +155,7 @@ func handleReq(ws *websocket.Conn, message []interface{}) {
fmt.Println("Subscription added:", subID) fmt.Println("Subscription added:", subID)
// Query the database with filters and send back the results // Query the database with filters and send back the results
queriedEvents, err := QueryEvents(filters, client, "grain", "event-kind1") queriedEvents, err := QueryEvents(filters, db.GetClient(), "grain", "event-kind1")
if err != nil { if err != nil {
fmt.Println("Error querying events:", err) fmt.Println("Error querying events:", err)
return return

73
server/utils/checkSig.go Normal file
View File

@ -0,0 +1,73 @@
package utils
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"log"
server "grain/server/types"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
)
func SerializeEvent(evt server.Event) []byte {
eventData := []interface{}{
0,
evt.PubKey,
evt.CreatedAt,
evt.Kind,
evt.Tags,
evt.Content,
}
serializedEvent, _ := json.Marshal(eventData)
return serializedEvent
}
func CheckSignature(evt server.Event) bool {
serializedEvent := SerializeEvent(evt)
hash := sha256.Sum256(serializedEvent)
eventID := hex.EncodeToString(hash[:])
if eventID != evt.ID {
log.Printf("Invalid ID: expected %s, got %s\n", eventID, evt.ID)
return false
}
sigBytes, err := hex.DecodeString(evt.Sig)
if err != nil {
log.Printf("Error decoding signature: %v\n", err)
return false
}
sig, err := schnorr.ParseSignature(sigBytes)
if err != nil {
log.Printf("Error parsing signature: %v\n", err)
return false
}
pubKeyBytes, err := hex.DecodeString(evt.PubKey)
if err != nil {
log.Printf("Error decoding public key: %v\n", err)
return false
}
var pubKey *btcec.PublicKey
if len(pubKeyBytes) == 32 {
// Handle 32-byte public key (x-coordinate only)
pubKey, err = btcec.ParsePubKey(append([]byte{0x02}, pubKeyBytes...))
} else {
// Handle standard compressed or uncompressed public key
pubKey, err = btcec.ParsePubKey(pubKeyBytes)
}
if err != nil {
log.Printf("Error parsing public key: %v\n", err)
return false
}
verified := sig.Verify(hash[:], pubKey)
if !verified {
log.Printf("Signature verification failed for event ID: %s\n", evt.ID)
}
return verified
}