// Proxies.sx Peer SDK - Community Edition (Go)
// =====================================
// Open-source community tool. Provided "AS IS", WITHOUT WARRANTY OF ANY KIND,
// express or implied. MIT License. Maintained by community contributors via
// github.com/bolivian-peru; it is NOT a warranted product of Proxies.sx, and
// the operators of Proxies.sx accept no liability for its use. Run it at your
// own discretion. See the Peer Partner Agreement for the terms of participation.
// =====================================
// A drop-in Go peer that implements the COMPLETE relay protocol, 1:1 with the
// canonical Node reference (agents.proxies.sx/peer/reference-sdk.js, v1.3.1):
//   - Registers with the platform (reuses the same device identity across restarts)
//   - Connects to the nearest relay over WebSocket
//   - Heartbeats every 30s
//   - Handles `tunnel_connect` - opens a local TCP socket to the target
//   - Forwards bytes both directions (fast binary frames)
//   - Reports `tunnel_open_failed` so customers fail fast
//   - Honors server-driven nearest-relay routing (`relay_redirect`)
//   - Opens several parallel WebSockets for throughput
//
// This file mirrors the JS reference function-for-function so the two read
// side-by-side. One dependency: gorilla/websocket.
//
// USAGE:
//   go run reference-sdk.go -key=psx_...
//   API_KEY=psx_... go run reference-sdk.go
//   go build -o proxies-peer reference-sdk.go && ./proxies-peer -key=psx_...
//
// Each process holds wsConnections WebSockets to one relay and serves customer
// tunnels. Run one process per peer device.
package main

import (
	"bytes"
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/gorilla/websocket"
)

// sdkVersion identifies this SDK build. It is sent in the register request
// body and in every device_info announcement.
const (
	sdkVersion = "1.3.3"
)

// ---------------------------------------------------------------------
// CONFIG - flags whose defaults come from environment variables, so the
// same binary works with `-key=...` flags OR API_KEY=... env (matching the
// Node reference, which is env-driven). No config framework: flags + env.
// ---------------------------------------------------------------------
// Runtime configuration, populated by setupConfig from flags (whose defaults
// come from environment variables). agentName may additionally be filled in
// later by acquireIdentity when no -name / AGENT_NAME was supplied.
var (
	// Platform and relay endpoints.
	apiBase   string
	relayFlag string

	// Identity and payout.
	apiKey    string
	agentName string
	wallet    string

	// Informational metadata reported to the platform.
	country          string
	carrier          string
	connectionMethod string

	// Local behaviour.
	wsConnections int
	stateFile     string
	verbose       bool
)

// envOr returns the value of environment variable key, or dflt when the
// variable is unset or set to the empty string.
func envOr(key, dflt string) string {
	value := os.Getenv(key)
	if value == "" {
		return dflt
	}
	return value
}

// setupConfig defines every command-line flag (each one defaulting to its
// matching environment variable), parses the command line, and validates
// the result. It exits the process when no API key was supplied and clamps
// -connections into the 1..8 range.
func setupConfig() {
	home, homeErr := os.UserHomeDir()
	if homeErr != nil {
		home = "."
	}
	stateDefault := filepath.Join(home, ".proxies-peer-state.json")
	// An empty name is deliberate: acquireIdentity then reuses the name saved
	// in the state file, or generates one ONCE and persists it (v1.3.2), so
	// omitting -name no longer creates a new device on every launch.
	nameDefault := envOr("AGENT_NAME", "")
	connDefault := atoiOr(envOr("WS_CONNECTIONS", "4"), 4)

	flag.StringVar(&apiBase, "api", envOr("API_BASE", "https://api.proxies.sx/v1"), "Platform API base URL")
	flag.StringVar(&relayFlag, "relay", os.Getenv("RELAY_URL"), "Pin a relay URL (empty = let the server pick the nearest)")
	flag.StringVar(&apiKey, "key", os.Getenv("API_KEY"), "Your Proxies.sx API key (required)")
	flag.StringVar(&agentName, "name", nameDefault, "Friendly agent name shown in your dashboard")
	flag.StringVar(&wallet, "wallet", os.Getenv("WALLET"), "Payout wallet address (optional)")
	flag.StringVar(&country, "country", envOr("COUNTRY", "US"), "Your ISO-2 country hint (server still verifies the real IP)")
	flag.StringVar(&carrier, "carrier", envOr("CARRIER", "unknown"), "Your ISP / carrier name (informational)")
	flag.StringVar(&connectionMethod, "method", envOr("CONNECTION_METHOD", "agent"), "Connection-type badge (agent/vps/docker)")
	flag.IntVar(&wsConnections, "connections", connDefault, "Parallel relay sockets (1-8)")
	flag.StringVar(&stateFile, "state", envOr("PEER_STATE_FILE", stateDefault), "Path to the saved identity file")
	flag.BoolVar(&verbose, "verbose", os.Getenv("VERBOSE") == "true", "Log every tunnel open/close")
	flag.Parse()

	if apiKey == "" {
		log.Fatalln("[FATAL] API key required. Use -key=psx_... or set API_KEY=psx_...")
	}
	switch {
	case wsConnections < 1:
		wsConnections = 1
	case wsConnections > 8:
		wsConnections = 8
	}
}

// atoiOr parses s as a base-10 integer. It returns dflt when s is not a
// valid integer, including the empty string.
func atoiOr(s string, dflt int) int {
	parsed, parseErr := strconv.Atoi(s)
	if parseErr != nil {
		return dflt
	}
	return parsed
}

// ---------------------------------------------------------------------
// sharedState holds values that every socket goroutine reads and that can be
// replaced while the process runs:
//   - token + deviceID: set at startup and again after an auth-failure
//     re-register;
//   - activeRelay: set at startup and changed by relay_redirect handling.
//
// A single RWMutex guards all three fields.
// ---------------------------------------------------------------------
type sharedState struct {
	mu          sync.RWMutex
	token       string
	deviceID    string
	activeRelay string
}

// setIdentity replaces the token and device ID together under the lock.
func (s *sharedState) setIdentity(token, deviceID string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.token = token
	s.deviceID = deviceID
}

// getIdentity returns a consistent (token, deviceID) snapshot.
func (s *sharedState) getIdentity() (string, string) {
	s.mu.RLock()
	tok, id := s.token, s.deviceID
	s.mu.RUnlock()
	return tok, id
}

// setRelay records the relay URL that future connections should dial.
func (s *sharedState) setRelay(url string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.activeRelay = url
}

// getRelay returns the relay URL currently in effect.
func (s *sharedState) getRelay() string {
	s.mu.RLock()
	relay := s.activeRelay
	s.mu.RUnlock()
	return relay
}

// Process-wide mutable state shared by every socket goroutine.
var (
	// state carries the current token, device ID and relay URL.
	state = &sharedState{}

	// relayPinned is set in main when -relay / RELAY_URL was given;
	// handleRedirect ignores redirects while it is true.
	relayPinned bool

	// openSockets tracks live relay connections (*websocket.Conn keys) so a
	// redirect can close all of them at once.
	openSockets sync.Map

	// lastRedirectAt is the unix-millisecond time of the last accepted
	// redirect (anti-flap). Accessed only via sync/atomic.
	lastRedirectAt int64

	// reregisterScheduled is 1 while a re-register is pending; CAS-guarded.
	reregisterScheduled int32

	// ourRelayRe accepts only wss:// hosts under proxies.sx as redirect targets.
	ourRelayRe = regexp.MustCompile(`(?i)^wss://[a-z0-9.-]+\.proxies\.sx(/|$)`)
)

// Registration is the JSON body returned by POST /peer/agents/register. The
// same struct is persisted verbatim to the state file by saveState and read
// back by loadState, which is how a restart resumes the same device.
type Registration struct {
	DeviceId string `json:"deviceId"`

	// AgentName is persisted (v1.3.2) so a restart reuses the SAME name even
	// when -name was omitted — re-registration then yields the SAME deviceId
	// (SHA-256(apiKey+name)) instead of a fresh random one. This is what stops
	// a fleet ballooning into duplicate devices when the launch command has no
	// -name. Resolved in acquireIdentity().
	AgentName string `json:"agentName"`

	// Either Jwt or Token carries the relay auth token; acquireIdentity
	// prefers Jwt and falls back to Token. EarningsPerGB is kept raw because
	// this SDK never interprets it.
	Token         string          `json:"token"`
	Jwt           string          `json:"jwt"`
	RefreshToken  string          `json:"refreshToken"`
	Relay         string          `json:"relay"`
	EarningsPerGB json.RawMessage `json:"earningsPerGB"`
}

// WSMessage is the JSON envelope carried in every text (control-plane) frame
// exchanged with the relay. Payload stays raw so each message type can be
// decoded into its own shape, and it is left out of the encoding when empty.
type WSMessage struct {
	Type    string          `json:"type"`
	Payload json.RawMessage `json:"payload,omitempty"`
}

// ts returns the current local time formatted for log-line prefixes.
func ts() string {
	const layout = "2006-01-02 15:04:05"
	return time.Now().Format(layout)
}

// logf writes a log line prefixed with a bracketed local timestamp.
func logf(f string, a ...interface{}) {
	log.Printf("["+ts()+"] "+f, a...)
}

// vlogf is logf gated on -verbose / VERBOSE=true.
func vlogf(f string, a ...interface{}) {
	if !verbose {
		return
	}
	logf(f, a...)
}

// main parses configuration, installs a Ctrl+C / SIGTERM handler, obtains a
// device identity, selects the initial relay, and then runs wsConnections
// self-reconnecting socket loops for the life of the process.
func main() {
	setupConfig()
	relayPinned = relayFlag != ""

	ctx, cancel := context.WithCancel(context.Background())
	stopSignals := make(chan os.Signal, 1)
	signal.Notify(stopSignals, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-stopSignals
		logf("Stopping. Goodbye.")
		cancel()
		os.Exit(0)
	}()

	// Resume the saved identity when possible, otherwise register afresh.
	token, deviceID, assignedRelay, err := acquireIdentity()
	if err != nil {
		log.Fatalf("[FATAL] register failed: %v", err)
	}
	state.setIdentity(token, deviceID)

	// Relay precedence: manual pin > server geo-assignment > built-in default.
	initialRelay := "wss://relay.proxies.sx"
	if relayPinned {
		initialRelay = relayFlag
	} else if assignedRelay != "" {
		initialRelay = assignedRelay
	}
	state.setRelay(initialRelay)

	logf("[REGISTERED] device=%s relay=%s sockets=%d", deviceID, state.getRelay(), wsConnections)
	logf("Sharing bandwidth as %q. Leave this running. Ctrl+C to stop.", agentName)

	// A fixed pool of socket goroutines. Each one reconnects by itself, so
	// nothing here ever spawns replacements (a recursive re-register would
	// leak a new generation of goroutines on every auth failure).
	var wg sync.WaitGroup
	wg.Add(wsConnections)
	for remaining := wsConnections; remaining > 0; remaining-- {
		go func() {
			defer wg.Done()
			connectLoop(ctx)
		}()
	}
	wg.Wait()
}

// connectLoop drives ONE relay socket for the whole process lifetime. Each
// iteration dials with the current shared identity and serves until the
// socket closes. If the close looks like an auth failure, it requests a
// single shared re-register. It then waits a jittered 5-7s before the next
// attempt. It returns only once ctx is cancelled.
func connectLoop(ctx context.Context) {
	for ctx.Err() == nil {
		token, deviceID := state.getIdentity()
		if err := connect(ctx, token, deviceID); isAuthClose(err) {
			scheduleReregister()
		}

		backoff := 5*time.Second + time.Duration(rand.Intn(2000))*time.Millisecond
		pause := time.NewTimer(backoff)
		select {
		case <-ctx.Done():
			pause.Stop()
			return
		case <-pause.C:
		}
	}
}

// isAuthClose reports whether err is a relay close with code 4001 or 4002,
// which connectLoop treats as an auth failure. A real *websocket.CloseError
// is checked first. As a fallback, any error whose text mentions either code
// also counts.
// NOTE(review): the text fallback can also match unrelated errors that merely
// contain "4001"/"4002" (e.g. a port number inside a dial error).
func isAuthClose(err error) bool {
	switch {
	case err == nil:
		return false
	case websocket.IsCloseError(err, 4001, 4002):
		return true
	}
	text := err.Error()
	return strings.Contains(text, "4001") || strings.Contains(text, "4002")
}

// connect opens one WebSocket to the current relay, announces this device,
// and serves tunnels until the socket closes. It returns the error that
// ended it: the dial error, or the read error / close frame from the relay.
//
// Frame formats:
//   - text frames carry JSON WSMessage control messages;
//   - binary frames carry tunnel data as [action][sidLen][sid][payload],
//     where action 0x01 = bytes for the target and 0x03 = close the tunnel.
//
// gorilla/websocket allows only one concurrent writer. Every write from this
// function, the heartbeat goroutine and the per-tunnel goroutines therefore
// goes through writeMu.
func connect(ctx context.Context, token, deviceID string) error {
	url := state.getRelay()
	dialer := websocket.Dialer{
		Subprotocols:     []string{"token." + token}, // JWT in Sec-WebSocket-Protocol, never the URL
		HandshakeTimeout: 10 * time.Second,
	}
	ws, resp, err := dialer.Dial(url, nil)
	if err != nil {
		// Log dial failures. connectLoop retries without logging, so without
		// this an unreachable relay or a rejected handshake left no trace.
		if resp != nil {
			logf("[DIAL_FAIL] relay=%s HTTP %d", url, resp.StatusCode)
			return fmt.Errorf("dial failed: HTTP %d", resp.StatusCode)
		}
		logf("[DIAL_FAIL] relay=%s %v", url, err)
		return err
	}
	defer ws.Close()
	logf("[CONNECTED] device=%s relay=%s", deviceID, url)
	openSockets.Store(ws, true)
	defer openSockets.Delete(ws)

	var writeMu sync.Mutex // gorilla forbids concurrent writes; one mutex per socket
	send := func(m WSMessage) error {
		b, _ := json.Marshal(m)
		writeMu.Lock()
		defer writeMu.Unlock()
		return ws.WriteMessage(websocket.TextMessage, b)
	}
	sendBinary := func(frame []byte) error {
		writeMu.Lock()
		defer writeMu.Unlock()
		return ws.WriteMessage(websocket.BinaryMessage, frame)
	}

	tunnels := sync.Map{} // sessionId -> net.Conn
	// Close every target socket still open when this relay socket goes away.
	defer tunnels.Range(func(_, v interface{}) bool {
		if c, ok := v.(net.Conn); ok {
			c.Close()
		}
		return true
	})

	// Announce capabilities + initial state.
	_ = send(WSMessage{Type: "device_info", Payload: json.RawMessage(fmt.Sprintf(
		`{"country":%q,"carrier":%q,"currentIp":"auto","protocol":"binary-v1","supportsRelayRedirect":true,"sdkVersion":%q}`,
		country, carrier, sdkVersion))})

	// Heartbeat every 30s; stop it when this socket closes.
	hbCtx, hbCancel := context.WithCancel(ctx)
	defer hbCancel()
	go func() {
		t := time.NewTicker(30 * time.Second)
		defer t.Stop()
		for {
			select {
			case <-hbCtx.Done():
				return
			case <-t.C:
				// A heartbeat that cannot be written means the socket is
				// broken. Close it so the ReadMessage below returns and
				// connectLoop reconnects. gorilla permits Close concurrently
				// with a reader.
				if send(WSMessage{Type: "heartbeat"}) != nil {
					ws.Close()
					return
				}
			}
		}
	}()

	for {
		msgType, raw, err := ws.ReadMessage()
		if err != nil {
			logf("[CLOSED] %v", err)
			return err
		}

		// Hot path: binary tunnel frames. [0x01 data | 0x03 close][sidLen][sid][payload]
		if msgType == websocket.BinaryMessage && len(raw) >= 2 {
			action, sidLen := raw[0], int(raw[1])
			if len(raw) < 2+sidLen {
				continue
			}
			sid := string(raw[2 : 2+sidLen])
			payload := raw[2+sidLen:]
			if v, ok := tunnels.Load(sid); ok {
				sock := v.(net.Conn)
				switch action {
				case 0x01:
					// NOTE: this write runs on the read loop. A target that
					// stops reading therefore stalls every tunnel on this socket.
					if _, werr := sock.Write(payload); werr != nil {
						// Target is unwritable, so tear the tunnel down. Closing
						// sock also ends handleTunnelConnect's read loop, which
						// then reports tunnel_closed to the relay.
						sock.Close()
						tunnels.Delete(sid)
					}
				case 0x03:
					sock.Close()
					tunnels.Delete(sid)
				}
			}
			continue
		}

		var msg WSMessage
		if json.Unmarshal(raw, &msg) != nil {
			continue
		}
		switch msg.Type {
		case "connected":
			logf("[ACK] relay confirmed connection")
		case "heartbeat_ack":
			// no-op
		case "relay_redirect":
			var p struct{ Relay, Reason string }
			if json.Unmarshal(msg.Payload, &p) == nil {
				handleRedirect(p.Relay, p.Reason)
			}
		case "tunnel_connect":
			var p struct {
				SessionId string      `json:"sessionId"`
				Host      string      `json:"host"`
				Port      interface{} `json:"port"`
			}
			if json.Unmarshal(msg.Payload, &p) == nil {
				go handleTunnelConnect(send, sendBinary, &tunnels, p.SessionId, p.Host, portToString(p.Port))
			}
		case "tunnel_close":
			var p struct{ SessionId string `json:"sessionId"` }
			if json.Unmarshal(msg.Payload, &p) == nil {
				if v, ok := tunnels.Load(p.SessionId); ok {
					v.(net.Conn).Close()
					tunnels.Delete(p.SessionId)
				}
			}
		case "rotate_ip_request":
			// Desktop/server peers can't toggle a carrier IP; report unsupported.
			var p struct{ RequestId string `json:"requestId"` }
			if json.Unmarshal(msg.Payload, &p) == nil {
				_ = send(WSMessage{Type: "rotation_complete", Payload: json.RawMessage(
					fmt.Sprintf(`{"requestId":%q,"success":false,"error":"not_supported_on_server_peer"}`, p.RequestId))})
			}
		default:
			// Unknown control messages are ignored, never fatal.
		}
	}
}

// portToString normalizes the "port" field of a tunnel_connect message into
// the string form net.JoinHostPort expects. The field may arrive as a JSON
// number (decoded as float64, truncated to an int) or as a string. Anything
// else is rendered with %v.
func portToString(p interface{}) string {
	if s, isString := p.(string); isString {
		return s
	}
	if f, isNumber := p.(float64); isNumber {
		return strconv.Itoa(int(f))
	}
	return fmt.Sprintf("%v", p)
}

// handleTunnelConnect opens a raw TCP socket to host:port for session sid.
// It then pipes target->relay bytes as binary data frames until either side
// closes. For HTTPS (443) this is a PLAIN TCP relay - TLS stays end-to-end
// between the customer and the target. Do NOT terminate TLS here.
//
// On dial failure it reports tunnel_open_failed with a coarse reason from
// failReason, then returns. On success it registers the socket in tunnels
// and reports tunnel_connected. On exit it closes the socket, removes it from
// tunnels and reports tunnel_closed.
//
// sendBinary must not retain the frame slice after it returns: one buffer is
// reused for every read, so no 32 KiB frame is allocated per chunk.
func handleTunnelConnect(send func(WSMessage) error, sendBinary func([]byte) error, tunnels *sync.Map, sid, host, port string) {
	vlogf("[TUNNEL_CONNECT] %s %s:%s", sid, host, port)
	start := time.Now()
	sock, err := net.DialTimeout("tcp", net.JoinHostPort(host, port), 30*time.Second)
	if err != nil {
		vlogf("[TUNNEL_FAIL] %s %v", sid, err)
		// Encode with encoding/json, not %q/strconv.Quote. The error text
		// embeds the customer-supplied host, and Go quoting can emit escapes
		// (\x01, \a, \v) that are not valid JSON. That would make the relay
		// drop the failure notice. Marshal of this struct cannot fail.
		payload, _ := json.Marshal(struct {
			SessionID  string `json:"sessionId"`
			Reason     string `json:"reason"`
			Detail     string `json:"detail"`
			DurationMs int64  `json:"durationMs"`
		}{sid, failReason(err), err.Error(), time.Since(start).Milliseconds()})
		_ = send(WSMessage{Type: "tunnel_open_failed", Payload: json.RawMessage(payload)})
		return
	}
	if tcp, ok := sock.(*net.TCPConn); ok {
		_ = tcp.SetNoDelay(true)
	}
	tunnels.Store(sid, sock)
	vlogf("[TCP_OPEN_OK] %s", sid)
	_ = send(WSMessage{Type: "tunnel_connected", Payload: json.RawMessage(fmt.Sprintf(`{"sessionId":%q}`, sid))})

	// Forward target->peer back to the relay until the socket closes.
	defer func() {
		sock.Close()
		tunnels.Delete(sid)
		_ = send(WSMessage{Type: "tunnel_closed", Payload: json.RawMessage(fmt.Sprintf(`{"sessionId":%q}`, sid))})
	}()

	// One reusable frame holds the fixed [0x01][sidLen][sid] header followed
	// by the read area. Each chunk is read straight into place and sent, with
	// no per-read allocation or copy. The length byte can only describe a sid
	// of up to 255 bytes, the same limit as the inbound frame format.
	hdrLen := 2 + len(sid)
	frame := make([]byte, hdrLen+32*1024)
	frame[0] = 0x01
	frame[1] = byte(len(sid))
	copy(frame[2:hdrLen], sid)
	for {
		n, err := sock.Read(frame[hdrLen:])
		if n > 0 {
			if sendBinary(frame[:hdrLen+n]) != nil {
				return
			}
		}
		if err != nil {
			return
		}
	}
}

// failReason maps a dial error to the coarse reason code reported in
// tunnel_open_failed: "timeout", "connection_refused", "dns_failed" or
// "unknown". A net.Error whose Timeout() is true always maps to "timeout".
// Otherwise the error text is checked against the ordered rules below, and
// the first match wins.
func failReason(err error) string {
	if netErr, isNetErr := err.(net.Error); isNetErr && netErr.Timeout() {
		return "timeout"
	}
	text := err.Error()
	rules := [...]struct{ needle, reason string }{
		{"refused", "connection_refused"},
		{"no such host", "dns_failed"},
		{"timeout", "timeout"},
	}
	for _, rule := range rules {
		if strings.Contains(text, rule.needle) {
			return rule.reason
		}
	}
	return "unknown"
}

// handleRedirect switches the whole socket pool to a nearer relay when the
// server asks for it. It does nothing when:
//   - the relay is pinned;
//   - target is not a wss://<host>.proxies.sx URL;
//   - target is already the active relay;
//   - another redirect was accepted within the last 60s (anti-flap).
func handleRedirect(target, reason string) {
	if relayPinned || !ourRelayRe.MatchString(target) {
		return
	}
	current := state.getRelay()
	if target == current {
		return
	}
	now := time.Now().UnixMilli()
	prev := atomic.LoadInt64(&lastRedirectAt)
	// Every socket's read loop calls this, so several can receive the same
	// relay_redirect at once. The CAS lets exactly one of them claim it.
	// With a separate Load/Store, each could pass the anti-flap check and
	// close the pool again.
	if now-prev < 60_000 || !atomic.CompareAndSwapInt64(&lastRedirectAt, prev, now) {
		return
	}
	logf("[REDIRECT] %s -> %s (%s)", current, target, reason)
	state.setRelay(target)
	// Close every socket; each connectLoop reconnects to the new relay.
	openSockets.Range(func(k, _ interface{}) bool {
		if ws, ok := k.(*websocket.Conn); ok {
			_ = ws.WriteControl(websocket.CloseMessage,
				websocket.FormatCloseMessage(4100, "relay_redirect"), time.Now().Add(time.Second))
			ws.Close()
		}
		return true
	})
}

// scheduleReregister refreshes the shared identity at most once at a time.
// When N sockets close on an expired token, only the first caller wins the
// CAS and starts a background refresh; the rest return immediately. The
// refresh waits 5s, then calls acquireIdentity and publishes the new token,
// device ID and (unless pinned) relay. The guard is cleared when the refresh
// finishes, whether or not it succeeded.
func scheduleReregister() {
	claimed := atomic.CompareAndSwapInt32(&reregisterScheduled, 0, 1)
	if !claimed {
		return
	}
	go func() {
		defer atomic.StoreInt32(&reregisterScheduled, 0)
		time.Sleep(5 * time.Second)

		newToken, newDeviceID, newRelay, err := acquireIdentity()
		if err != nil {
			logf("[REREGISTER] failed: %v (will retry on next reconnect)", err)
			return
		}
		state.setIdentity(newToken, newDeviceID)
		if newRelay != "" && !relayPinned {
			state.setRelay(newRelay)
		}
		logf("[REREGISTER] refreshed identity device=%s", newDeviceID)
	}()
}

// ---------------------------------------------------------------------
// REGISTER / IDENTITY
// ---------------------------------------------------------------------
func acquireIdentity() (token, deviceID, relay string, err error) {
	saved := loadState()

	// Resolve a STABLE name when the user didn't pass -name / AGENT_NAME.
	// Reuse the name saved in the state file so a restart re-registers with the
	// SAME deviceId (SHA-256(apiKey+name)) instead of a fresh random one — even
	// if the refresh below fails (expired token, or the device was cleaned up
	// server-side). Generate-and-persist only on a truly first run. This is the
	// single most important guard against duplicate-device churn.
	if agentName == "" {
		if saved != nil && saved.AgentName != "" {
			agentName = saved.AgentName
			logf("[IDENTITY] reusing saved name %q (no -name given)", agentName)
		} else {
			agentName = fmt.Sprintf("go-%x", rand.New(rand.NewSource(time.Now().UnixNano())).Int31())
			logf("[IDENTITY] generated name %q and persisting it — pass -name to choose your own", agentName)
		}
	}

	if saved != nil && saved.DeviceId != "" && saved.RefreshToken != "" {
		if t := refreshSavedToken(saved.DeviceId, saved.RefreshToken); t != "" {
			logf("[RESUMED] device=%s (reused saved identity)", saved.DeviceId)
			// Backfill the name into state if an older file predates v1.3.2.
			if saved.AgentName == "" {
				saved.AgentName = agentName
				saveState(saved)
			}
			return t, saved.DeviceId, saved.Relay, nil
		}
	}
	r, err := register()
	if err != nil {
		return "", "", "", err
	}
	r.AgentName = agentName // persist the resolved name → future launches are stable
	token = r.Jwt
	if token == "" {
		token = r.Token
	}
	saveState(r)
	return token, r.DeviceId, r.Relay, nil
}

// sharedHTTP is the single HTTP client used for platform API calls
// (register and token refresh). Its 15s overall timeout bounds each request.
var sharedHTTP = &http.Client{
	Timeout: time.Second * 15,
}

// register creates (or re-creates) this device via POST
// {apiBase}/peer/agents/register and returns the decoded Registration.
// A non-2xx status comes back as an error carrying the status code and the
// response body. Read and decode failures are returned with context instead
// of being swallowed.
func register() (*Registration, error) {
	body := map[string]interface{}{
		"name":             agentName,
		"type":             "custom",
		"connectionMethod": connectionMethod,
		"sdkVersion":       sdkVersion, // reference sends this; do not omit
		"apiKey":           apiKey,
	}
	if wallet != "" {
		body["walletAddress"] = wallet // only when set: empty fails validation
	}
	// Marshalling a map of plain strings cannot fail.
	b, _ := json.Marshal(body)
	resp, err := sharedHTTP.Post(apiBase+"/peer/agents/register", "application/json", bytes.NewReader(b))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		// Previously ignored. A truncated body then surfaced as a misleading
		// "unexpected end of JSON input".
		return nil, fmt.Errorf("reading register response (HTTP %d): %w", resp.StatusCode, err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(raw))
	}
	var r Registration
	if err := json.Unmarshal(raw, &r); err != nil {
		return nil, fmt.Errorf("decoding register response: %w", err)
	}
	return &r, nil
}

// refreshSavedToken exchanges a saved refresh token for a fresh relay token
// via POST {apiBase}/peer/token/{deviceID}/refresh. It prefers the "token"
// field and falls back to "jwt". It returns "" on any failure (network error,
// non-2xx status, undecodable body, or no token), so the caller can fall back
// to a full register.
// NOTE(review): if the server rotates refresh tokens on use, a new one in
// this response is currently discarded — confirm against the API.
func refreshSavedToken(deviceID, refreshToken string) string {
	b, _ := json.Marshal(map[string]string{"refreshToken": refreshToken})
	resp, err := sharedHTTP.Post(fmt.Sprintf("%s/peer/token/%s/refresh", apiBase, deviceID), "application/json", bytes.NewReader(b))
	if err != nil {
		return ""
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return ""
	}
	// Decode into a struct, not map[string]string. With a map, any non-string
	// field in the response (a number such as an expiry, a bool, or a nested
	// object) makes Decode return an error even though the token decoded
	// fine. That would silently force a fresh register on every restart.
	var body struct {
		Token string `json:"token"`
		Jwt   string `json:"jwt"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		return ""
	}
	if body.Token != "" {
		return body.Token
	}
	return body.Jwt
}

// loadState reads the identity previously written to stateFile. It returns
// nil when the file is missing or unreadable or is not valid JSON; callers
// treat nil as "nothing to resume".
func loadState() *Registration {
	data, readErr := os.ReadFile(stateFile)
	if readErr != nil {
		return nil
	}
	saved := new(Registration)
	if decodeErr := json.Unmarshal(data, saved); decodeErr != nil {
		return nil
	}
	return saved
}

// saveState persists r as JSON to stateFile (mode 0600) so the next launch
// can resume the same device.
//
// The data is written to a temporary sibling file and renamed into place.
// A crash mid-write therefore cannot leave a truncated file, which loadState
// would reject, forcing a brand-new registration. If that path fails (e.g.
// an unwritable directory, or a single-file bind mount where rename is
// typically refused), it falls back to writing stateFile directly. Failures
// are logged but not fatal, because persistence is best-effort.
func saveState(r *Registration) {
	data, err := json.Marshal(r)
	if err != nil {
		log.Printf("[STATE] could not encode identity: %v", err)
		return
	}
	if writeFileAtomic(stateFile, data) == nil {
		return
	}
	if err := os.WriteFile(stateFile, data, 0o600); err != nil {
		log.Printf("[STATE] could not save identity to %s: %v", stateFile, err)
	}
}

// writeFileAtomic writes data to a temp file (created 0600) in path's
// directory, flushes it, and renames it over path. The temp file is removed
// on any failure.
func writeFileAtomic(path string, data []byte) error {
	tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		os.Remove(tmpName)
		return err
	}
	if err := tmp.Sync(); err != nil {
		tmp.Close()
		os.Remove(tmpName)
		return err
	}
	if err := tmp.Close(); err != nil {
		os.Remove(tmpName)
		return err
	}
	if err := os.Rename(tmpName, path); err != nil {
		os.Remove(tmpName)
		return err
	}
	return nil
}
