// Copyright CloudQuery Authors
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.

package client

import (
	"context"
	"errors"
	"fmt"

	"cloud.google.com/go/firestore"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/memory"
	"github.com/cloudquery/plugin-sdk/v4/message"
	"github.com/cloudquery/plugin-sdk/v4/plugin"
	"github.com/cloudquery/plugin-sdk/v4/premium"
	"github.com/cloudquery/plugin-sdk/v4/schema"
	"github.com/cloudquery/plugin-sdk/v4/types"
	"golang.org/x/sync/errgroup"
	"google.golang.org/api/iterator"
)

// Sync selects the tables to sync, emits a SyncMigrateTable message for each
// of them, and then syncs all selected tables using a context wrapped by
// premium.WithCancelOnQuotaExceeded. It fails immediately when the client was
// created with the NoConnection option.
func (c Client) Sync(ctx context.Context, options plugin.SyncOptions, res chan<- message.SyncMessage) error {
	if c.options.NoConnection {
		return errors.New("no connection")
	}

	tables, err := c.tables.FilterDfs(options.Tables, options.SkipTables, options.SkipDependentTables)
	if err != nil {
		return err
	}

	// Every migrate message is sent before any table data is read.
	for i := range tables {
		res <- &message.SyncMigrateTable{Table: tables[i]}
	}

	quotaCtx, err := premium.WithCancelOnQuotaExceeded(ctx, c.usage)
	if err != nil {
		return fmt.Errorf("failed to configure quota monitor: %w", err)
	}
	return c.syncTables(quotaCtx, tables, res)
}

// OnBeforeSend increases the usage count for every message. If some messages should not be counted,
// they can be ignored here.
// OnBeforeSend increases the usage count by the number of rows in every
// SyncInsert message. Other message kinds are passed through uncounted; if
// further messages should not be counted, they can be ignored here.
//
// The message is always returned unchanged. An error is returned only when
// the usage client fails to record the increase.
func (c *Client) OnBeforeSend(_ context.Context, msg message.SyncMessage) (message.SyncMessage, error) {
	// OnSyncFinish treats a nil usage client as "nothing to report"; do the
	// same here instead of dereferencing it.
	if c.usage == nil {
		return msg, nil
	}
	si, ok := msg.(*message.SyncInsert)
	if !ok {
		return msg, nil
	}
	if err := c.usage.Increase(uint32(si.Record.NumRows())); err != nil {
		return msg, fmt.Errorf("failed to increase usage: %w", err)
	}
	return msg, nil
}

// OnSyncFinish is used to ensure the final usage count gets reported
// OnSyncFinish is used to ensure the final usage count gets reported. It
// closes the usage client when one is configured and is a no-op otherwise.
func (c *Client) OnSyncFinish(_ context.Context) error {
	if c.usage == nil {
		return nil
	}
	return c.usage.Close()
}

// syncTable reads every document of the Firestore collection named table.Name
// and sends its rows to res as SyncInsert messages.
//
// Documents are fetched in pages of c.spec.MaxBatchSize. They are ordered by
// c.spec.OrderBy (the document ID when empty), descending when
// c.spec.OrderDirection is "desc" and ascending otherwise. Rows are buffered
// into Arrow records of c.spec.RowsPerRecord rows, and any remainder is sent
// at the end.
//
// It assumes the table schema's first four columns are, in order: the
// document ID (string), the create and update times (timestamps), and the
// document data (JSON). The unchecked builder type assertions below panic
// otherwise.
func (c *Client) syncTable(ctx context.Context, table *schema.Table, res chan<- message.SyncMessage) error {
	maxBatchSize := c.spec.MaxBatchSize

	orderBy := firestore.DocumentID
	if c.spec.OrderBy != "" {
		orderBy = c.spec.OrderBy
	}
	dir := firestore.Asc
	if c.spec.OrderDirection == "desc" {
		dir = firestore.Desc
	}
	baseQuery := c.client.Collection(table.Name).Query.
		OrderBy(orderBy, dir).
		Limit(maxBatchSize)

	builder := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema())
	defer builder.Release()
	idField := builder.Field(0).(*array.StringBuilder)
	createdAtField := builder.Field(1).(*array.TimestampBuilder)
	updatedAtField := builder.Field(2).(*array.TimestampBuilder)
	dataField := builder.Field(3).(*types.JSONBuilder)
	rowsInRecord := 0

	// flush sends the buffered rows as one record. NewRecord resets the
	// builder for reuse. The send is abandoned when ctx is cancelled, so this
	// goroutine cannot block forever on res after the sync has been aborted.
	flush := func() error {
		record := builder.NewRecord()
		select {
		case res <- &message.SyncInsert{Record: record}:
			rowsInRecord = 0
			return nil
		case <-ctx.Done():
			record.Release()
			return ctx.Err()
		}
	}

	// lastDoc is the pagination cursor: the last snapshot returned by the
	// previous page, or nil before the first page.
	var lastDoc *firestore.DocumentSnapshot

	// readPage appends one page of query results to the builder. It returns
	// the number of documents returned and how many of them were skipped as
	// non-existent. The iterator is always stopped so its resources are freed
	// on every exit path.
	readPage := func(query firestore.Query) (int, int, error) {
		docIter := query.Documents(ctx)
		defer docIter.Stop()

		documentCount, skippedCount := 0, 0
		for {
			docSnap, err := docIter.Next()
			if errors.Is(err, iterator.Done) {
				return documentCount, skippedCount, nil
			}
			if err != nil {
				return documentCount, skippedCount, err
			}
			documentCount++
			// Advance the cursor for every returned document, including
			// skipped ones, so pagination always makes progress.
			lastDoc = docSnap
			if !docSnap.Exists() {
				skippedCount++
				continue
			}

			idField.AppendString(docSnap.Ref.ID)
			createdAtField.AppendTime(docSnap.CreateTime)
			updatedAtField.AppendTime(docSnap.UpdateTime)
			dataField.Append(docSnap.Data())
			rowsInRecord++
			if rowsInRecord >= c.spec.RowsPerRecord {
				if err := flush(); err != nil {
					return documentCount, skippedCount, err
				}
			}
		}
	}

	for {
		query := baseQuery
		if lastDoc != nil {
			c.logger.Info().Msgf("Starting after %s", lastDoc.Ref.ID)
			// A snapshot cursor is resolved against the query's own OrderBy
			// fields. A bare document ID is only a valid cursor when ordering
			// by document ID.
			query = baseQuery.StartAfter(lastDoc)
		}

		documentCount, skippedCount, err := readPage(query)
		if err != nil {
			return fmt.Errorf("failed to read documents from %s: %w", table.Name, err)
		}
		c.logger.Info().Msgf("Synced %d documents from %s", documentCount, table.Name)
		if skippedCount > 0 {
			c.logger.Info().Msgf("Skipped %d documents from %s", skippedCount, table.Name)
		}
		// A short or empty page means the collection is exhausted. The empty
		// check also prevents an endless loop when maxBatchSize <= 0.
		if documentCount == 0 || documentCount < maxBatchSize {
			break
		}
	}

	if rowsInRecord > 0 { // only send if there are some unsent rows
		return flush()
	}
	return nil
}

// syncTables runs syncTable for every table in its own goroutine of one
// errgroup. The first failure cancels the context shared by the remaining
// goroutines, and that first error is returned once all of them have finished.
func (c *Client) syncTables(ctx context.Context, tables schema.Tables, res chan<- message.SyncMessage) error {
	group, groupCtx := errgroup.WithContext(ctx)
	for i := range tables {
		tbl := tables[i] // per-goroutine copy for pre-Go 1.22 loop semantics
		group.Go(func() error {
			return c.syncTable(groupCtx, tbl, res)
		})
	}
	return group.Wait()
}
