// Copyright CloudQuery Authors
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.

package client

import (
	"context"
	"errors"
	"fmt"
	"os"
	"strings"
	"sync"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/log"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/billing/armbilling"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage"
	"github.com/cloudquery/cloudquery/plugins/source/azure/client/spec"
	"github.com/cloudquery/plugin-sdk/v4/schema"
	"github.com/rs/zerolog"
	"github.com/samber/lo"
	"github.com/thoas/go-funk"
	"golang.org/x/sync/errgroup"
)

// NewClientFn is the signature of a constructor that turns a context, a logger and the
// plugin spec into a schema.ClientMeta (New in this file has this signature).
type NewClientFn func(ctx context.Context, logger zerolog.Logger, s *spec.Spec) (schema.ClientMeta, error)

// Client is the schema.ClientMeta implementation returned by New. It holds the
// credential (Creds), the ARM client options shared by every Azure SDK client it builds
// (Options: cloud configuration and retry settings filled in by New), the plugin spec,
// and the metadata cached by DiscoverMetadata. The with* helpers return shallow copies
// scoped to a subscription, resource group, location, billing account/profile/period
// or quota resource provider.
type Client struct {
	// subscriptions holds the subscription IDs to sync: the de-duplicated list from the
	// spec (see New) or, when none were configured, the enabled subscriptions found by
	// discoverSubscriptions. DiscoverMetadata removes spec.SkipSubscriptions from it.
	subscriptions []string

	// SubscriptionsObjects is to cache full objects returned from ListSubscriptions on initialisation
	// (all returned subscriptions, including disabled ones).
	SubscriptionsObjects []*armsubscriptions.Subscription

	// ResourceGroups is to cache full objects returned from ListResourceGroups on initialisation,
	// as a map from subscription ID to list of resource groups.
	ResourceGroups map[string][]*armresources.ResourceGroup

	// SubscriptionsLocations is to cache full objects returned from ListLocations on initialisation,
	// as a map from subscription ID to list of locations.
	SubscriptionsLocations map[string][]*armsubscriptions.Location

	// logger is the base logger; the with* helpers add scope fields to it.
	// registeredNamespaces maps a subscription ID to the set of lower-cased namespaces of
	// resource providers whose state is "Registered" (filled by discoverResourceGroups).
	logger               zerolog.Logger
	registeredNamespaces map[string]map[string]bool
	// this is set by table client multiplexer (via withSubscription)
	SubscriptionId string
	// this is set by table client multiplexer (SubscriptionResourceGroupMultiplexRegisteredNamespace),
	// via withResourceGroup
	ResourceGroup string
	Creds         azcore.TokenCredential
	Options       *arm.ClientOptions

	// Spec is the plugin spec passed to New. BillingAccounts and BillingPeriods (keyed by
	// subscription ID) are filled during DiscoverMetadata; their discovery failures are only
	// logged, so either may be left nil. BillingAccount, BillingProfile and BillingPeriod are
	// set on scoped copies by withBillingAccount, withBillingProfile and withBillingPeriod.
	Spec            *spec.Spec
	BillingAccounts []*armbilling.Account
	BillingAccount  *armbilling.Account
	BillingProfile  *armbilling.Profile
	BillingPeriods  map[string][]*armbilling.Period
	BillingPeriod   *armbilling.Period

	// QuotaResourceProvider and Location are set on scoped copies by
	// withQuotaResourceProvider and withLocation.
	QuotaResourceProvider string
	Location              string

	// storageAccountKeys backs GetStorageAccountKey, keyed by storage account name. It is a
	// pointer, so every copy made from this client (with*, Duplicate) shares the same map.
	storageAccountKeys *sync.Map
}

// discoverSubscriptions lists every subscription visible to the configured credentials.
// Every returned subscription object (including disabled ones) is appended to
// c.SubscriptionsObjects. c.subscriptions is reset and then filled with the IDs of the
// enabled subscriptions only, with the "/subscriptions/" prefix stripped.
// Entries whose State or ID pointer is nil are skipped rather than dereferenced.
func (c *Client) discoverSubscriptions(ctx context.Context) error {
	c.subscriptions = make([]string, 0)
	subscriptionClient, err := armsubscriptions.NewClient(c.Creds, c.Options)
	if err != nil {
		return fmt.Errorf("failed to create subscriptions client: %w", err)
	}
	pager := subscriptionClient.NewListPager(nil)
	for pager.More() {
		page, err := pager.NextPage(ctx)
		if err != nil {
			return fmt.Errorf("failed to list subscriptions: %w", err)
		}
		// we record all returned values, even disabled
		c.SubscriptionsObjects = append(c.SubscriptionsObjects, page.Value...)
		for _, sub := range page.Value {
			// State and ID are optional pointer fields in the SDK model; a nil one would
			// otherwise panic and abort the whole discovery.
			if sub == nil || sub.State == nil || sub.ID == nil {
				continue
			}
			if *sub.State == armsubscriptions.SubscriptionStateEnabled {
				c.subscriptions = append(c.subscriptions, strings.TrimPrefix(*sub.ID, "/subscriptions/"))
			}
		}
	}

	return nil
}

// getResourceGroupsForSubscription walks every page of the resource-group listing for
// the given subscription and returns all groups in the order they were listed.
func (c *Client) getResourceGroupsForSubscription(ctx context.Context, subscriptionId string) ([]*armresources.ResourceGroup, error) {
	rgClient, err := armresources.NewResourceGroupsClient(subscriptionId, c.Creds, c.Options)
	if err != nil {
		return nil, fmt.Errorf("failed to create resource group client: %w", err)
	}

	var result []*armresources.ResourceGroup
	for pager := rgClient.NewListPager(&armresources.ResourceGroupsClientListOptions{}); pager.More(); {
		page, pageErr := pager.NextPage(ctx)
		if pageErr != nil {
			return nil, fmt.Errorf("failed to list resource groups: %w", pageErr)
		}
		// appending an empty page is a no-op, so no special case is needed
		result = append(result, page.Value...)
	}
	return result, nil
}

// getRegisteredProvidersForSubscription returns the resource providers of the given
// subscription whose RegistrationState is "Registered". Providers without a
// RegistrationState are left out.
func (c *Client) getRegisteredProvidersForSubscription(ctx context.Context, subscriptionId string) ([]*armresources.Provider, error) {
	client, err := armresources.NewProvidersClient(subscriptionId, c.Creds, c.Options)
	if err != nil {
		return nil, fmt.Errorf("failed to create provider client: %w", err)
	}

	var registered []*armresources.Provider
	pager := client.NewListPager(nil)
	for pager.More() {
		page, err := pager.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to list providers: %w", err)
		}
		for _, provider := range page.Value {
			if state := provider.RegistrationState; state == nil || *state != "Registered" {
				continue
			}
			registered = append(registered, provider)
		}
	}
	return registered, nil
}

// getLocationsForSubscription returns every location reported for the given
// subscription, collected across all pages of the ListLocations API.
func (c *Client) getLocationsForSubscription(ctx context.Context, subscriptionId string) ([]*armsubscriptions.Location, error) {
	subsClient, err := armsubscriptions.NewClient(c.Creds, c.Options)
	if err != nil {
		return nil, fmt.Errorf("failed to create subscription client: %w", err)
	}

	var all []*armsubscriptions.Location
	for pager := subsClient.NewListLocationsPager(subscriptionId, nil); pager.More(); {
		resp, err := pager.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to list locations: %w", err)
		}
		all = append(all, resp.Value...)
	}
	return all, nil
}

// discoverBillingAccounts lists every billing account visible to the credentials,
// expanding soldTo, billingProfiles and billingProfiles/invoiceSections, and stores the
// result (a non-nil, possibly empty slice) in c.BillingAccounts. On any error
// c.BillingAccounts is left untouched.
func (c *Client) discoverBillingAccounts(ctx context.Context) error {
	accountsClient, err := armbilling.NewAccountsClient(c.Creds, c.Options)
	if err != nil {
		return err
	}

	opts := &armbilling.AccountsClientListOptions{
		Expand: to.Ptr("soldTo,billingProfiles,billingProfiles/invoiceSections"),
	}
	found := make([]*armbilling.Account, 0)
	for pager := accountsClient.NewListPager(opts); pager.More(); {
		page, err := pager.NextPage(ctx)
		if err != nil {
			return err
		}
		found = append(found, page.Value...)
	}

	c.BillingAccounts = found
	return nil
}

// discoverBillingPeriods fetches the billing periods of every subscription in
// c.subscriptions. The fetches run concurrently, bounded by Spec.DiscoveryConcurrency,
// and the results are stored in c.BillingPeriods keyed by subscription ID. If any fetch
// fails, the errgroup's error is returned and c.BillingPeriods is not modified.
func (c *Client) discoverBillingPeriods(ctx context.Context) error {
	result := make(map[string][]*armbilling.Period, len(c.subscriptions))
	var mu sync.Mutex // guards result

	g, gctx := errgroup.WithContext(ctx)
	g.SetLimit(c.Spec.DiscoveryConcurrency)

	for _, subID := range c.subscriptions {
		subID := subID // capture the per-iteration value for the goroutine
		g.Go(func() error {
			periodsClient, err := armbilling.NewPeriodsClient(subID, c.Creds, c.Options)
			if err != nil {
				return err
			}
			periods := make([]*armbilling.Period, 0)
			for pager := periodsClient.NewListPager(nil); pager.More(); {
				page, err := pager.NextPage(gctx)
				if err != nil {
					return err
				}
				periods = append(periods, page.Value...)
			}

			mu.Lock()
			result[subID] = periods
			mu.Unlock()
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return err
	}
	c.BillingPeriods = result
	return nil
}

// discoverResourceGroups fills three caches, each keyed by subscription ID, for every
// subscription in c.subscriptions:
//   - c.ResourceGroups: all resource groups;
//   - c.registeredNamespaces: lower-cased namespaces of registered resource providers;
//   - c.SubscriptionsLocations: available locations.
//
// The three lookups per subscription run as separate errgroup goroutines bounded by
// Spec.DiscoveryConcurrency, and each cache has its own mutex. Once one lookup fails,
// the shared context is cancelled and that error, annotated with the subscription ID,
// is returned.
func (c *Client) discoverResourceGroups(ctx context.Context) error {
	c.ResourceGroups = make(map[string][]*armresources.ResourceGroup, len(c.subscriptions))
	c.registeredNamespaces = make(map[string]map[string]bool, len(c.subscriptions))
	c.SubscriptionsLocations = make(map[string][]*armsubscriptions.Location, len(c.subscriptions))

	groupsLock, namespacesLock, locationsLock := sync.Mutex{}, sync.Mutex{}, sync.Mutex{}

	errorGroup, gtx := errgroup.WithContext(ctx)
	errorGroup.SetLimit(c.Spec.DiscoveryConcurrency)
	for _, subID := range c.subscriptions {
		subID := subID
		errorGroup.Go(func() error {
			groups, err := c.getResourceGroupsForSubscription(gtx, subID)
			if err != nil {
				return fmt.Errorf("subscription %s: %w", subID, err)
			}
			groupsLock.Lock()
			defer groupsLock.Unlock()
			c.ResourceGroups[subID] = groups

			return nil
		})

		errorGroup.Go(func() error {
			providers, err := c.getRegisteredProvidersForSubscription(gtx, subID)
			if err != nil {
				return fmt.Errorf("subscription %s: %w", subID, err)
			}

			// Build the set outside the lock; only the shared map write needs protection.
			namespaces := make(map[string]bool, len(providers))
			for _, p := range providers {
				// Namespace is an optional pointer in the SDK model; skip providers without
				// one instead of panicking on the dereference.
				if p.Namespace == nil {
					continue
				}
				namespaces[strings.ToLower(*p.Namespace)] = true
			}

			namespacesLock.Lock()
			defer namespacesLock.Unlock()
			c.registeredNamespaces[subID] = namespaces

			return nil
		})

		errorGroup.Go(func() error {
			locations, err := c.getLocationsForSubscription(gtx, subID)
			if err != nil {
				return fmt.Errorf("subscription %s: %w", subID, err)
			}

			locationsLock.Lock()
			defer locationsLock.Unlock()
			c.SubscriptionsLocations[subID] = locations

			return nil
		})
	}
	return errorGroup.Wait()
}

// logAuthErrors routes azidentity authentication log events to logger.
//
// NewDefaultAzureCredential builds a chain of credentials and reports errors via the
// log listener; this is currently the only way to surface them to the user. A credential
// with errors is ignored and the next one in the chain is tried, so every
// "NewDefaultAzureCredential failed" message is forwarded. They are logged at Info level
// because we cannot tell which credential the user intended to use.
func logAuthErrors(logger zerolog.Logger) {
	const failurePrefix = "NewDefaultAzureCredential failed"

	listener := func(_ log.Event, msg string) {
		if !strings.HasPrefix(msg, failurePrefix) {
			return
		}
		logger.Info().Str("azure-sdk-for-go", "azidentity").Msg(msg)
	}

	log.SetEvents(azidentity.EventAuthentication)
	log.SetListener(listener)
}

// logRetries subscribes to azcore retry-policy log events and forwards each message
// to logger at Debug level.
func logRetries(logger zerolog.Logger) {
	forward := func(_ log.Event, message string) {
		logger.Debug().
			Str("azure-sdk-for-go", "azcore").
			Msg(message)
	}

	log.SetEvents(log.EventRetryPolicy)
	log.SetListener(forward)
}

// getCredential picks the credential used by every Azure SDK client, in order of precedence:
//   - s.OIDCToken set: a client-assertion credential that presents that token, with the
//     tenant and client IDs read from AZURE_TENANT_ID and AZURE_CLIENT_ID (both required);
//   - s.Auth set: the custom credential built by NewCustomCredential from
//     s.Auth.GetTokenCommand (Command and Args);
//   - otherwise: azidentity's DefaultAzureCredential chain.
//
// options (cloud configuration and retry settings) is forwarded to both azidentity
// credentials, so token requests go to the configured cloud's authority host.
func getCredential(logger zerolog.Logger, s *spec.Spec, options azcore.ClientOptions) (azcore.TokenCredential, error) {
	if s.OIDCToken != "" {
		tenantID := os.Getenv("AZURE_TENANT_ID")
		clientID := os.Getenv("AZURE_CLIENT_ID")
		if tenantID == "" {
			return nil, errors.New("AZURE_TENANT_ID is empty")
		}
		if clientID == "" {
			return nil, errors.New("AZURE_CLIENT_ID is empty")
		}

		oidcToken := s.OIDCToken
		getAssertion := func(context.Context) (string, error) {
			return oidcToken, nil
		}
		// Pass the client options here as well: with nil options the credential falls back to
		// the public-cloud authority and default retries, ignoring a configured CloudName.
		return azidentity.NewClientAssertionCredential(tenantID, clientID, getAssertion,
			&azidentity.ClientAssertionCredentialOptions{ClientOptions: options})
	}

	if s.Auth != nil {
		return NewCustomCredential(logger, s.Auth.GetTokenCommand.Command, s.Auth.GetTokenCommand.Args...), nil
	}

	return azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{ClientOptions: options})
}

// New builds a Client from the plugin spec. It:
//   - de-duplicates the configured subscriptions;
//   - applies the optional cloud configuration and the retry settings to the ARM client
//     options shared by every SDK client;
//   - creates the credential via getCredential.
//
// While the credential is being created, azidentity authentication errors are logged.
// When New returns after that step, the log listener is switched to retry events.
// No metadata is discovered here; see DiscoverMetadata.
func New(ctx context.Context, logger zerolog.Logger, s *spec.Spec) (schema.ClientMeta, error) {
	subs := funk.Uniq(s.Subscriptions).([]string)
	options := &arm.ClientOptions{}

	if s.CloudName != "" {
		cloud, err := s.CloudConfig()
		if err != nil {
			return nil, err
		}
		options.Cloud = cloud
	}

	// fill in the retry settings
	s.RetryOptions.FillIn(&options.Retry)

	// Auth errors matter only while the credential is built; afterwards the deferred call
	// switches the listener to retry events.
	logAuthErrors(logger)
	defer logRetries(logger)

	creds, err := getCredential(logger, s, options.ClientOptions)
	if err != nil {
		return nil, err
	}

	return &Client{
		logger:             logger,
		subscriptions:      subs,
		Spec:               s,
		storageAccountKeys: &sync.Map{},
		Options:            options,
		Creds:              creds,
	}, nil
}

// DiscoverMetadata resolves the subscriptions to sync and pre-loads the metadata the
// tables rely on.
//
// When no subscriptions were configured, every enabled subscription visible to the
// credentials is discovered. Subscriptions listed in s.SkipSubscriptions are then removed.
// It returns an error if discovering subscriptions fails, if no subscriptions remain, or
// if resource-group/provider/location discovery fails. Billing accounts and billing
// periods are best-effort: their failures are logged as warnings and do not abort
// discovery.
func (c *Client) DiscoverMetadata(ctx context.Context, s *spec.Spec) error {
	if len(c.subscriptions) == 0 {
		c.logger.Info().Msg("discovering subscriptions")
		if err := c.discoverSubscriptions(ctx); err != nil {
			return err
		}
	}

	// Drop the subscriptions the user asked CloudQuery to skip.
	c.subscriptions = funk.LeftJoinString(c.subscriptions, s.SkipSubscriptions)
	if len(c.subscriptions) == 0 {
		return errors.New("no subscriptions found")
	}

	c.logger.Info().Msg("discovering resource groups")
	if err := c.discoverResourceGroups(ctx); err != nil {
		return err
	}

	bestEffort := []struct {
		name string
		run  func(context.Context) error
	}{
		{"billing accounts", c.discoverBillingAccounts},
		{"billing periods", c.discoverBillingPeriods},
	}
	for _, step := range bestEffort {
		c.logger.Info().Msg("discovering " + step.name)
		if err := step.run(ctx); err != nil {
			c.logger.Warn().Err(err).Msg("failed to discover " + step.name + " (skipping)")
		}
	}

	return nil
}

// Logger returns the logger stored in ctx, enriched with whichever scope fields are set
// on this client. Fields are added in this order: subscription_id, resource_group,
// location, billing_account, billing_profile, billing_period, quota_resource_provider.
// String fields are skipped when empty, and billing fields are skipped when their
// object is nil.
func (c *Client) Logger(ctx context.Context) *zerolog.Logger {
	lctx := zerolog.Ctx(ctx).With()
	addIfSet := func(key, value string) {
		if value != "" {
			lctx = lctx.Str(key, value)
		}
	}

	addIfSet("subscription_id", c.SubscriptionId)
	addIfSet("resource_group", c.ResourceGroup)
	addIfSet("location", c.Location)
	if acc := c.BillingAccount; acc != nil {
		lctx = lctx.Str("billing_account", *acc.Name)
	}
	if prof := c.BillingProfile; prof != nil {
		lctx = lctx.Str("billing_profile", *prof.Name)
	}
	if period := c.BillingPeriod; period != nil {
		lctx = lctx.Str("billing_period", *period.Name)
	}
	addIfSet("quota_resource_provider", c.QuotaResourceProvider)

	return lo.ToPtr(lctx.Logger())
}

// ID returns a path-like identifier for the client's multiplexing scope. The first
// matching case below decides the format: resource group, billing profile (under its
// billing account), billing account, billing period, quota provider with location,
// location, and finally the bare subscription.
//
// NOTE(review): the billing-profile case also dereferences BillingAccount.Name. It
// assumes the multiplexer always sets BillingAccount together with BillingProfile.
func (c *Client) ID() string {
	sub := c.SubscriptionId
	switch {
	case c.ResourceGroup != "":
		return fmt.Sprintf("subscriptions/%s/resourceGroups/%s", sub, c.ResourceGroup)
	case c.BillingProfile != nil:
		return fmt.Sprintf("billingAccounts/%s/billingProfiles/%s", *c.BillingAccount.Name, *c.BillingProfile.Name)
	case c.BillingAccount != nil:
		return fmt.Sprintf("billingAccounts/%s", *c.BillingAccount.Name)
	case c.BillingPeriod != nil:
		return fmt.Sprintf("subscriptions/%s/billingPeriods/%s", sub, *c.BillingPeriod.Name)
	case c.QuotaResourceProvider != "" && c.Location != "":
		return fmt.Sprintf("subscriptions/%s/providers/%s/locations/%s", sub, c.QuotaResourceProvider, c.Location)
	case c.Location != "":
		return fmt.Sprintf("subscriptions/%s/locations/%s", sub, c.Location)
	default:
		return fmt.Sprintf("subscriptions/%s", sub)
	}
}

// Duplicate returns a shallow copy of the client. Pointer, map and slice fields (for
// example Options, the discovered metadata caches and storageAccountKeys) are shared
// with the original.
func (c *Client) Duplicate() *Client {
	clone := new(Client)
	*clone = *c
	return clone
}

// withSubscription returns a copy of the client scoped to subscriptionId. The copy has
// SubscriptionId set, and its logger carries a "subscription_id" field. The multiplexer
// uses it to create one client per subscription.
func (c *Client) withSubscription(subscriptionId string) *Client {
	scoped := *c
	scoped.SubscriptionId = subscriptionId
	scoped.logger = c.logger.With().Str("subscription_id", subscriptionId).Logger()
	return &scoped
}

// withResourceGroup returns a copy of the client scoped to resourceGroup. The copy has
// ResourceGroup set, and its logger carries a "resource_group" field.
func (c *Client) withResourceGroup(resourceGroup string) *Client {
	scoped := *c
	scoped.ResourceGroup = resourceGroup
	scoped.logger = c.logger.With().Str("resource_group", resourceGroup).Logger()
	return &scoped
}

// withLocation returns a copy of the client scoped to location. The copy has Location
// set, and its logger carries a "location" field.
func (c *Client) withLocation(location string) *Client {
	scoped := *c
	scoped.Location = location
	scoped.logger = c.logger.With().Str("location", location).Logger()
	return &scoped
}

// withBillingAccount returns a copy of the client scoped to billingAccount. The copy has
// BillingAccount set, and its logger carries a "billing_account" field holding the
// account's resource ID.
func (c *Client) withBillingAccount(billingAccount *armbilling.Account) *Client {
	scoped := *c
	scoped.logger = c.logger.With().Str("billing_account", *billingAccount.ID).Logger()
	scoped.BillingAccount = billingAccount
	return &scoped
}

// withBillingProfile returns a copy of the client scoped to billingProfile. The copy has
// BillingProfile set, and its logger carries a "billing_profile" field holding the
// profile's resource ID.
func (c *Client) withBillingProfile(billingProfile *armbilling.Profile) *Client {
	scoped := *c
	scoped.logger = c.logger.With().Str("billing_profile", *billingProfile.ID).Logger()
	scoped.BillingProfile = billingProfile
	return &scoped
}

// withBillingPeriod returns a copy of the client scoped to billingPeriod. The copy has
// BillingPeriod set, and its logger carries a "billing_period" field holding the
// period's resource ID.
func (c *Client) withBillingPeriod(billingPeriod *armbilling.Period) *Client {
	scoped := *c
	scoped.logger = c.logger.With().Str("billing_period", *billingPeriod.ID).Logger()
	scoped.BillingPeriod = billingPeriod
	return &scoped
}

// withQuotaResourceProvider returns a copy of the client scoped to the given quota
// resource provider. The copy has QuotaResourceProvider set, and its logger carries a
// "quota_resource_provider" field.
func (c *Client) withQuotaResourceProvider(provider string) *Client {
	scoped := *c
	scoped.QuotaResourceProvider = provider
	scoped.logger = c.logger.With().Str("quota_resource_provider", provider).Logger()
	return &scoped
}

var (
	// ErrNoStorageKeysFound is returned by GetStorageAccountKey when the ListKeys
	// response for a storage account yields no usable key.
	ErrNoStorageKeysFound = errors.New("no storage keys found")
)

// GetStorageAccountKey returns an access key for the given storage account.
//
// The lookup goes through loadOrStore on c.storageAccountKeys, keyed by the account
// name. loadOrStore is defined elsewhere and presumably caches the first successful
// result; the map pointer is shared by every copy of this client. To fetch a key, the
// account's resource group is parsed from acc.ID and ListKeys is called. The first
// returned key that actually carries a value is used, which is normally the first key.
// ErrNoStorageKeysFound is produced when no such key exists.
func (c *Client) GetStorageAccountKey(ctx context.Context, acc *armstorage.Account) (string, error) {
	key, err := loadOrStore(c.storageAccountKeys, *acc.Name, func() (any, error) {
		svc, err := armstorage.NewAccountsClient(c.SubscriptionId, c.Creds, c.Options)
		if err != nil {
			return nil, err
		}

		group, err := ParseResourceGroup(*acc.ID)
		if err != nil {
			return nil, err
		}

		keysResponse, err := svc.ListKeys(ctx, group, *acc.Name, nil)
		if err != nil {
			return nil, err
		}

		// Keys and their Value fields are pointers in the SDK model. Pick the first key
		// that has a value instead of blindly dereferencing Keys[0].Value, which could panic.
		for _, k := range keysResponse.Keys {
			if k != nil && k.Value != nil {
				return *k.Value, nil
			}
		}
		return nil, ErrNoStorageKeysFound
	})
	if key == nil {
		return "", err
	}
	return key.(string), err
}
