
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <libspe2.h>
#include <pthread.h>
#include <malloc.h>

#include "common.h"

/**
 * A small program to demonstrate DMA transfers to and from SPEs.
 *
 * The local buffer is split into n_threads equal slices, one per SPE.
 * For each SPE we:
 *
 * 1) populate a spe_args structure with the address and size of that
 *    SPE's slice of our local buffer
 * 2) start an SPE context, giving it the address of its spe_args struct
 * 3) the SPE DMAs the spe_args struct to its local store
 * 4) using the arguments provided, the SPE DMAs some data back into its
 *    slice of our local buffer
 * 5) we wait for all of the SPE threads to finish, then write the whole
 *    buffer out to a file.
 *
 * The SPE DMA engine is fussy about alignment (of both addresses and
 * transfer sizes), so we need to be careful with the addresses and sizes
 * of memory that we expect the SPE to DMA from/to.
 */

/* SPE program image loaded into every context; defined outside this file
 * (presumably the embedded SPE executable -- TODO confirm build setup). */
extern spe_program_handle_t spe_memset;

/*
 * Per-SPE bookkeeping: the libspe2 context, the host pthread that runs it,
 * and the argument block whose address is handed to the SPE program.
 */
struct spe_thread {
	spe_context_ptr_t ctx;		/* created/loaded in main(), run in spethread_fn() */
	pthread_t pthread;		/* host thread that calls spe_context_run() */
	/* The SPE DMAs this struct into its local store, so it must sit on
	 * an SPE_ALIGN boundary. */
	struct spe_args args __attribute__((aligned(SPE_ALIGN)));
};

/*
 * Simple helper function to write a block of data to a file.
 *
 * Writes @len bytes from @data to @filename, creating or truncating it.
 * A negative @len is rejected without touching the file. Errors (open
 * failure, short write, failed flush/close) are reported on stderr; the
 * function returns no status, so the file may be missing or truncated
 * after an error.
 */
void write_data(const char *filename, const uint8_t *data, int len)
{
	FILE *fp;

	/* a negative int would become a huge size_t and over-read @data */
	if (len < 0) {
		fprintf(stderr, "write_data: invalid length %d\n", len);
		return;
	}

	fp = fopen(filename, "wb");
	if (!fp) {
		perror("fopen");
		return;
	}

	/* Skip fwrite for empty data (data may legitimately be NULL then);
	 * otherwise fwrite may write fewer bytes than requested. */
	if (len > 0 && fwrite(data, 1, (size_t)len, fp) != (size_t)len)
		perror("fwrite");

	/* fclose flushes buffered output, so it can fail too */
	if (fclose(fp) == EOF)
		perror("fclose");
}

/*
 * pthread entry point: runs one SPE context until its program stops.
 *
 * @data: the struct spe_thread for this SPE, whose ctx has already been
 *        created and loaded with the SPE program.
 *
 * Always returns NULL; a failing spe_context_run() is reported on stderr.
 */
void *spethread_fn(void *data)
{
	struct spe_thread *spethread = data;
	uint32_t entry = SPE_DEFAULT_ENTRY;

	/* run the context, passing the address of our args structure to
	 * the 'argv' argument to main() */
	if (spe_context_run(spethread->ctx, &entry, 0,
			&spethread->args, NULL, NULL) < 0)
		perror("spe_context_run");

	return NULL;
}

/*
 * Parse @str as an int (strtol base 0: decimal, 0x hex or 0 octal).
 * The whole string must be consumed and the value must fit in an int.
 * Returns 0 and stores the value in *@out on success, -1 otherwise.
 */
static int parse_int_arg(const char *str, int *out)
{
	char *endp;
	long val;

	errno = 0;
	val = strtol(str, &endp, 0);
	if (endp == str || *endp != '\0' || errno == ERANGE ||
			val < INT_MIN || val > INT_MAX)
		return -1;

	*out = (int)val;
	return 0;
}

/*
 * Parse args (and do a few checks) into buf_size and n_threads.
 *
 * Expects argv[1] = buf_size and argv[2] = n_threads. Requirements:
 *  - buf_size > 0 and a multiple of 16
 *  - 1 <= n_threads <= 8
 *  - buf_size divisible by n_threads, with each share a multiple of 16
 *
 * Returns 0 and fills *buf_size / *n_threads on success. On failure it
 * prints a message to stderr, returns -1 and leaves the outputs untouched.
 */
int parse_args(int argc, char **argv, int *buf_size, int *n_threads)
{
	int tmp_buf_size, tmp_n_threads;

	if (argc != 3) {
		fprintf(stderr, "Usage: %s <buf_size> <n_threads>\n",
				(argc > 0 && argv[0]) ? argv[0] : "<program>");
		return -1;
	}

	if (parse_int_arg(argv[1], &tmp_buf_size)) {
		fprintf(stderr, "invalid buf_size value: %s\n", argv[1]);
		return -1;
	}

	if (parse_int_arg(argv[2], &tmp_n_threads)) {
		fprintf(stderr, "invalid n_threads value: %s\n", argv[2]);
		return -1;
	}

	/* n_threads is used as a divisor below (and in main), so 0 must be
	 * rejected here */
	if (tmp_n_threads < 1 || tmp_n_threads > 8) {
		fprintf(stderr, "n_threads must be between 1 and 8\n");
		return -1;
	}

	/* a negative size would become a huge size_t once passed to memalign */
	if (tmp_buf_size <= 0) {
		fprintf(stderr, "buf_size must be greater than 0\n");
		return -1;
	}

	if ((tmp_buf_size % 16)) {
		fprintf(stderr, "buf_size must be a multiple of 16\n");
		return -1;
	}

	if ((tmp_buf_size % tmp_n_threads)) {
		fprintf(stderr, "buf_size must be a multiple of n_threads\n");
		return -1;
	}

	/* each SPE's slice must also be a valid DMA size/alignment */
	if ((tmp_buf_size / tmp_n_threads) % 16) {
		fprintf(stderr, "buf_size / n_threads must be a multiple "
				"of 16\n");
		return -1;
	}

	*buf_size = tmp_buf_size;
	*n_threads = tmp_n_threads;
	return 0;
}

/*
 * Fill in @t's SPE arguments, create and load its SPE context, and start a
 * pthread that runs it via spethread_fn().
 *
 * @t:     per-SPE state to initialise
 * @dst:   start of this SPE's slice of the host buffer
 * @size:  size of that slice in bytes
 * @index: SPE index, passed to the SPE program as args.c
 *
 * Returns 0 on success. On failure an error is printed, any context created
 * here is destroyed, and -1 is returned (no thread is left running).
 */
static int start_spe_thread(struct spe_thread *t, uint8_t *dst, int size,
		int index)
{
	int err;

	/* Set up the arguments passed to the SPE */
	t->args.buf_addr = (uint64_t)(unsigned long)dst;
	t->args.buf_size = size;
	t->args.c = index;

	t->ctx = spe_context_create(0, NULL);
	if (!t->ctx) {
		perror("spe_context_create");
		return -1;
	}

	if (spe_program_load(t->ctx, &spe_memset)) {
		perror("spe_program_load");
		goto err_destroy;
	}

	/* run the context in a new thread, passing the spe_thread
	 * struct as the first argument of spethread_fn */
	err = pthread_create(&t->pthread, NULL, spethread_fn, t);
	if (err) {
		/* pthread_create returns the error code instead of setting errno */
		fprintf(stderr, "pthread_create: %s\n", strerror(err));
		goto err_destroy;
	}

	return 0;

err_destroy:
	spe_context_destroy(t->ctx);
	t->ctx = NULL;
	return -1;
}

/*
 * Split a buf_size-byte buffer across n_threads SPEs, let each SPE fill its
 * slice via DMA, then write the buffer to "out.data".
 *
 * Returns 0 on success, -1 on bad arguments, allocation failure or when
 * not every SPE thread could be started (out.data is not written then).
 */
int main(int argc, char **argv)
{
	struct spe_thread *threads = NULL;
	int n_threads, buf_size, size_per_spe, n_started, i;
	uint8_t *buf = NULL;
	int rc = -1;

	if (parse_args(argc, argv, &buf_size, &n_threads))
		return -1;

	/* Allocate our local buffer, and argument struct, to sit on a
	 * SPE_ALIGN-byte boundary */
	buf = memalign(SPE_ALIGN, buf_size);
	threads = memalign(SPE_ALIGN, n_threads * sizeof(*threads));
	if (!buf || !threads) {
		perror("memalign");
		goto out_free;
	}

	size_per_spe = buf_size / n_threads;

	/* Start one SPE per slice. Stop at the first failure, but still
	 * reap the threads that were already started. */
	for (n_started = 0; n_started < n_threads; n_started++) {
		/* The offset within our main buffer for this SPE to copy
		 * its data into */
		uint8_t *c = buf + (n_started * size_per_spe);

		if (start_spe_thread(&threads[n_started], c, size_per_spe,
					n_started))
			break;
	}

	/* wait for all of the started threads to finish */
	for (i = 0; i < n_started; i++) {
		pthread_join(threads[i].pthread, NULL);
		spe_context_destroy(threads[i].ctx);
	}

	if (n_started == n_threads) {
		/* write the resulting data out to a file */
		write_data("out.data", buf, buf_size);
		rc = 0;
	}

out_free:
	/* memalign()ed memory is released with free() in glibc */
	free(threads);
	free(buf);
	return rc;
}
