Răsfoiți Sursa

feat(web): 添加 WebSocket 支持

- 新增 WebSocket 配置相关结构体和接口
- 实现 WebSocket 配置的创建、编辑和启用禁用功能
- 在 Web 转发服务中集成 WebSocket 支持
- 添加获取 CDN Web 配置 ID 的功能
fusu 3 săptămâni în urmă
părinte
comite
474f3c7489

+ 13 - 0
api/v1/cdn.go

@@ -216,3 +216,16 @@ type CcConfig struct {
 }
 }
 
 
 
 
+type WebSocket struct {
+	WebsocketId          int64  `json:"websocketId" form:"websocketId"`
+	HandshakeTimeoutJSON []byte `json:"handshakeTimeoutJSON" form:"handshakeTimeoutJSON"`
+	AllowAllOrigins      bool   `json:"allowAllOrigins" form:"allowAllOrigins"`
+	AllowedOrigins       []string `json:"allowedOrigins" form:"allowedOrigins"`
+	RequestSameOrigin    bool   `json:"requestSameOrigin" form:"requestSameOrigin"`
+	RequestOrigin        string `json:"requestOrigin" form:"requestOrigin"`
+}
+
+type HandshakeTimeoutJSON struct {
+	Unit  string `json:"unit" form:"unit"`
+	Count int    `json:"count" form:"count"`
+}

+ 1 - 0
cmd/server/wire/wire.go

@@ -83,6 +83,7 @@ var serviceSet = wire.NewSet(
 	service.NewAllowAndDenyIpService,
 	service.NewAllowAndDenyIpService,
 	service.NewProxyService,
 	service.NewProxyService,
 	service.NewSslCertService,
 	service.NewSslCertService,
+	service.NewWebsocketService,
 )
 )
 
 
 var handlerSet = wire.NewSet(
 var handlerSet = wire.NewSet(

+ 3 - 2
cmd/server/wire/wire_gen.go

@@ -74,7 +74,8 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 	proxyRepository := repository.NewProxyRepository(repositoryRepository)
 	proxyRepository := repository.NewProxyRepository(repositoryRepository)
 	proxyService := service.NewProxyService(serviceService, proxyRepository, cdnService)
 	proxyService := service.NewProxyService(serviceService, proxyRepository, cdnService)
 	sslCertService := service.NewSslCertService(serviceService, webForwardingRepository, cdnService)
 	sslCertService := service.NewSslCertService(serviceService, webForwardingRepository, cdnService)
-	webForwardingService := service.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gateWayGroupIpRepository, gatewayGroupRepository, globalLimitRepository, cdnService, proxyService, sslCertService)
+	websocketService := service.NewWebsocketService(serviceService, cdnService, webForwardingRepository)
+	webForwardingService := service.NewWebForwardingService(serviceService, requiredService, webForwardingRepository, crawlerService, parserService, wafFormatterService, aoDunService, rabbitMQ, gateWayGroupIpRepository, gatewayGroupRepository, globalLimitRepository, cdnService, proxyService, sslCertService, websocketService)
 	webForwardingHandler := handler.NewWebForwardingHandler(handlerHandler, webForwardingService)
 	webForwardingHandler := handler.NewWebForwardingHandler(handlerHandler, webForwardingService)
 	webLimitRepository := repository.NewWebLimitRepository(repositoryRepository)
 	webLimitRepository := repository.NewWebLimitRepository(repositoryRepository)
 	webLimitService := service.NewWebLimitService(serviceService, webLimitRepository, requiredService, parserService, crawlerService, hostService)
 	webLimitService := service.NewWebLimitService(serviceService, webLimitRepository, requiredService, parserService, crawlerService, hostService)
@@ -112,7 +113,7 @@ func NewWire(viperViper *viper.Viper, logger *log.Logger) (*app.App, func(), err
 
 
 var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewAdminRepository, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewAllowAndDenyIpRepository, repository.NewProxyRepository)
 var repositorySet = wire.NewSet(repository.NewDB, repository.NewRedis, repository.NewCasbinEnforcer, repository.NewMongoClient, repository.NewMongoDB, repository.NewRabbitMQ, repository.NewRepository, repository.NewTransaction, repository.NewAdminRepository, repository.NewUserRepository, repository.NewGameShieldRepository, repository.NewGameShieldPublicIpRepository, repository.NewWebForwardingRepository, repository.NewTcpforwardingRepository, repository.NewUdpForWardingRepository, repository.NewGameShieldUserIpRepository, repository.NewWebLimitRepository, repository.NewTcpLimitRepository, repository.NewUdpLimitRepository, repository.NewGameShieldBackendRepository, repository.NewGameShieldSdkIpRepository, repository.NewHostRepository, repository.NewGlobalLimitRepository, repository.NewGatewayGroupRepository, repository.NewGateWayGroupIpRepository, repository.NewCdnRepository, repository.NewAllowAndDenyIpRepository, repository.NewProxyRepository)
 
 
-var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewUserService, service.NewAdminService, service.NewGameShieldService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, service.NewWebForwardingService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewGameShieldUserIpService, service.NewWebLimitService, service.NewTcpLimitService, service.NewUdpLimitService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, service.NewGlobalLimitService, service.NewGatewayGroupService, service.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, service.NewCdnService, service.NewAllowAndDenyIpService, service.NewProxyService, service.NewSslCertService)
+var serviceSet = wire.NewSet(service.NewService, service.NewAoDunService, service.NewUserService, service.NewAdminService, service.NewGameShieldService, service.NewGameShieldPublicIpService, service.NewDuedateService, service.NewFormatterService, service.NewParserService, service.NewRequiredService, service.NewCrawlerService, service.NewWebForwardingService, service.NewTcpforwardingService, service.NewUdpForWardingService, service.NewGameShieldUserIpService, service.NewWebLimitService, service.NewTcpLimitService, service.NewUdpLimitService, service.NewGameShieldBackendService, service.NewGameShieldSdkIpService, service.NewHostService, service.NewGlobalLimitService, service.NewGatewayGroupService, service.NewWafFormatterService, service.NewGateWayGroupIpService, service.NewRequestService, service.NewCdnService, service.NewAllowAndDenyIpService, service.NewProxyService, service.NewSslCertService, service.NewWebsocketService)
 
 
 var handlerSet = wire.NewSet(handler.NewHandler, handler.NewUserHandler, handler.NewAdminHandler, handler.NewGameShieldHandler, handler.NewGameShieldPublicIpHandler, handler.NewWebForwardingHandler, handler.NewTcpforwardingHandler, handler.NewUdpForWardingHandler, handler.NewGameShieldUserIpHandler, handler.NewWebLimitHandler, handler.NewTcpLimitHandler, handler.NewUdpLimitHandler, handler.NewGameShieldBackendHandler, handler.NewGameShieldSdkIpHandler, handler.NewHostHandler, handler.NewGlobalLimitHandler, handler.NewGatewayGroupHandler, handler.NewGateWayGroupIpHandler, handler.NewAllowAndDenyIpHandler)
 var handlerSet = wire.NewSet(handler.NewHandler, handler.NewUserHandler, handler.NewAdminHandler, handler.NewGameShieldHandler, handler.NewGameShieldPublicIpHandler, handler.NewWebForwardingHandler, handler.NewTcpforwardingHandler, handler.NewUdpForWardingHandler, handler.NewGameShieldUserIpHandler, handler.NewWebLimitHandler, handler.NewTcpLimitHandler, handler.NewUdpLimitHandler, handler.NewGameShieldBackendHandler, handler.NewGameShieldSdkIpHandler, handler.NewHostHandler, handler.NewGlobalLimitHandler, handler.NewGatewayGroupHandler, handler.NewGateWayGroupIpHandler, handler.NewAllowAndDenyIpHandler)
 
 

+ 12 - 0
internal/repository/webforwarding.go

@@ -36,6 +36,8 @@ type WebForwardingRepository interface {
 	GetWebConfigId(ctx context.Context, id int64) (int64, error)
 	GetWebConfigId(ctx context.Context, id int64) (int64, error)
 	// 获取域名
 	// 获取域名
 	GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]v1.Domain, error)
 	GetDomainByHostIdPort(ctx context.Context, hostId int64, port string) ([]v1.Domain, error)
+	// 获取CDN的web配置的id
+	GetWebId(ctx context.Context, serverId int64) (int64, error)
 }
 }
 
 
 func NewWebForwardingRepository(
 func NewWebForwardingRepository(
@@ -326,3 +328,13 @@ func (r *webForwardingRepository) GetDomainByHostIdPort(ctx context.Context, hos
 	return domains, nil
 	return domains, nil
 
 
 }
 }
+
+// 获取CDN的web配置的id
+func (r *webForwardingRepository) GetWebId(ctx context.Context, serverId int64) (int64, error) {
+	var webId int64
+	if err := r.DBWithName(ctx,"cdn").Table("cloud_servers").Where("id = ?", serverId).Select("webId").Scan(&webId).Error; err != nil {
+		return 0, err
+	}
+	return webId, nil
+
+}

+ 79 - 0
internal/service/cdn.go

@@ -46,6 +46,12 @@ type CdnService interface {
 	EditWebLog(ctx context.Context,webId int64, req v1.WebLog) error
 	EditWebLog(ctx context.Context,webId int64, req v1.WebLog) error
 	// 修改CC配置
 	// 修改CC配置
 	EditCcConfig(ctx context.Context,webId int64, req v1.CcConfig) error
 	EditCcConfig(ctx context.Context,webId int64, req v1.CcConfig) error
+	// 添加webSocket
+	AddWebSockets(ctx context.Context, req v1.WebSocket) (int64,error)
+	// 修改webSocket
+	EditWebSockets(ctx context.Context, req v1.WebSocket) error
+	// 启用webSocket
+	EditHTTPWebWebsocket(ctx context.Context,websocketId int64,websocketJSON []byte) error
 }
 }
 
 
 func NewCdnService(
 func NewCdnService(
@@ -849,4 +855,77 @@ func (s *cdnService) EditCcConfig(ctx context.Context,webId int64, req v1.CcConf
 		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
 		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
 	}
 	}
 	return nil
 	return nil
+}
+
+
+// 创建websockets配置
+func (s *cdnService) AddWebSockets(ctx context.Context, req v1.WebSocket) (int64,error) {
+	formData := map[string]interface{}{
+		"handshakeTimeoutJSON": req.HandshakeTimeoutJSON,
+		"allowAllOrigins"      : req.AllowAllOrigins,
+		"allowedOrigins"       : req.AllowedOrigins,
+		"requestSameOrigin"    : req.RequestSameOrigin,
+		"requestOrigin"        : req.RequestOrigin,
+	}
+	apiUrl := s.Url + "HTTPWebsocketService/createHTTPWebsocket"
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl)
+	if err != nil {
+		return 0,err
+	}
+	type WebSocket struct {
+		WebSocketId int64 `json:"websocketId"`
+	}
+	var res v1.GeneralResponse[WebSocket]
+	if err := json.Unmarshal(resBody, &res); err != nil {
+		return 0,fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+	}
+	if res.Code != 200 {
+		return 0,fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
+	}
+	return res.Data.WebSocketId,nil
+}
+
+func (s *cdnService) EditWebSockets(ctx context.Context,req v1.WebSocket) error {
+	formData := map[string]interface{}{
+		"websocketId"          : req.WebsocketId,
+		"handshakeTimeoutJSON": req.HandshakeTimeoutJSON,
+		"allowAllOrigins"      : req.AllowAllOrigins,
+		"allowedOrigins"       : req.AllowedOrigins,
+		"requestSameOrigin"    : req.RequestSameOrigin,
+		"requestOrigin"        : req.RequestOrigin,
+	}
+	apiUrl := s.Url + "HTTPWebsocketService/updateHTTPWebsocket"
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl)
+	if err != nil {
+		return err
+	}
+	var res v1.GeneralResponse[any]
+	if err := json.Unmarshal(resBody, &res); err != nil {
+		return fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+	}
+	if res.Code != 200 {
+		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
+	}
+	return nil
+}
+
+// 启用/禁用websockets
+func (s *cdnService) EditHTTPWebWebsocket(ctx context.Context,websocketId int64,websocketJSON []byte) error {
+	formData := map[string]interface{}{
+		"httpWebId"          : websocketId,
+		"websocketJSON":       websocketJSON,
+	}
+	apiUrl := s.Url + "HTTPWebService/updateHTTPWebWebsocket"
+	resBody, err := s.sendDataWithTokenRetry(ctx, formData, apiUrl)
+	if err != nil {
+		return err
+	}
+	var res v1.GeneralResponse[any]
+	if err := json.Unmarshal(resBody, &res); err != nil {
+		return fmt.Errorf("反序列化响应 JSON 失败 (内容: %s): %w", string(resBody), err)
+	}
+	if res.Code != 200 {
+		return fmt.Errorf("API 错误: code %d, msg '%s'", res.Code, res.Message)
+	}
+	return nil
 }
 }

+ 21 - 1
internal/service/webforwarding.go

@@ -37,6 +37,7 @@ func NewWebForwardingService(
 	cdn CdnService,
 	cdn CdnService,
 	proxy ProxyService,
 	proxy ProxyService,
 	sslCert SslCertService,
 	sslCert SslCertService,
+	websocket WebsocketService,
 ) WebForwardingService {
 ) WebForwardingService {
 	return &webForwardingService{
 	return &webForwardingService{
 		Service:                 service,
 		Service:                 service,
@@ -53,6 +54,7 @@ func NewWebForwardingService(
 		globalLimitRep:          globalLimitRep,
 		globalLimitRep:          globalLimitRep,
 		proxy:                   proxy,
 		proxy:                   proxy,
 		sslCert:               	 sslCert,
 		sslCert:               	 sslCert,
+		websocket:               websocket,
 	}
 	}
 }
 }
 
 
@@ -78,6 +80,7 @@ type webForwardingService struct {
 	globalLimitRep          repository.GlobalLimitRepository
 	globalLimitRep          repository.GlobalLimitRepository
 	proxy                   ProxyService
 	proxy                   ProxyService
 	sslCert                 SslCertService
 	sslCert                 SslCertService
+	websocket               WebsocketService
 }
 }
 
 
 func (s *webForwardingService) require(ctx context.Context, req v1.GlobalRequire) (v1.GlobalRequire, error) {
 func (s *webForwardingService) require(ctx context.Context, req v1.GlobalRequire) (v1.GlobalRequire, error) {
@@ -343,12 +346,19 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 		return err
 		return err
 	}
 	}
 
 
+	var protocol string
+	if req.WebForwardingData.IsHttps == isHttps {
+		protocol = "https"
+	}else{
+		protocol = "http"
+	}
 	// 验证端口重复
 	// 验证端口重复
-	err = s.wafformatter.VerifyPort(ctx,"http", int64(req.WebForwardingData.Id), req.WebForwardingData.Port, int64(require.HostId), req.WebForwardingData.Domain)
+	err = s.wafformatter.VerifyPort(ctx, protocol, int64(req.WebForwardingData.Id), req.WebForwardingData.Port, int64(require.HostId), req.WebForwardingData.Domain)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
+
 	// 添加证书
 	// 添加证书
 	if req.WebForwardingData.IsHttps == isHttps {
 	if req.WebForwardingData.IsHttps == isHttps {
 		sslCertId, err := s.sslCert.AddSSLCert(ctx, v1.SSL{
 		sslCertId, err := s.sslCert.AddSSLCert(ctx, v1.SSL{
@@ -387,6 +397,16 @@ func (s *webForwardingService) AddWebForwarding(ctx context.Context, req *v1.Web
 	}
 	}
 
 
 
 
+	// 开启websocket
+	websocketId, err := s.websocket.AddWebsocket(ctx)
+	if err != nil {
+		return err
+	}
+	if err := s.websocket.EnableOrDisable(ctx, webId, websocketId, true, false); err != nil {
+		return err
+	}
+
+
 	// 添加源站
 	// 添加源站
 	cdnOriginIds := make(map[string]int64)
 	cdnOriginIds := make(map[string]int64)
 	for _, v := range req.WebForwardingData.BackendList {
 	for _, v := range req.WebForwardingData.BackendList {