Logo Search packages:      
Sourcecode: p4fftwgel version File versions  Download package

fftw_mpi.c

/*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 */

#include <stdio.h>
#include <math.h>

#include <fftw_mpi.h>
#include <fftw-int.h>

/************************** Twiddle Factors *****************************/

/* To conserve space, we share twiddle factor arrays between forward and
   backward plans and plans of the same size (just as in the uniprocessor
   transforms). */

/* Head of a singly linked list (chained through ->next) of every live MPI
   twiddle table.  Entries are reference counted: fftw_mpi_create_twiddle
   either bumps the count of a matching entry or prepends a new one, and
   fftw_mpi_destroy_twiddle unlinks and frees an entry once its count drops
   to zero.  No locking is done around this list. */
static fftw_mpi_twiddle *fftw_mpi_twiddles = NULL;

static fftw_mpi_twiddle *fftw_mpi_create_twiddle(int rows, int rowstart,
                                     int cols, int n)
{
     fftw_mpi_twiddle *tw = fftw_mpi_twiddles;

     while (tw && (tw->rows != rows || tw->rowstart != rowstart ||
               tw->cols != cols || tw->n != n))
        tw = tw->next;

     if (tw) {
        tw->refcount++;
        return tw;
     }

     tw = (fftw_mpi_twiddle *) fftw_malloc(sizeof(fftw_mpi_twiddle));
     tw->rows = rows;
     tw->rowstart = rowstart;
     tw->cols = cols;
     tw->n = n;
     tw->refcount = 1;
     tw->next = fftw_mpi_twiddles;

     {
        fftw_complex *W = (fftw_complex *) fftw_malloc(sizeof(fftw_complex) *
                                           rows * (cols - 1));
        int j, i;
        FFTW_TRIG_REAL twoPiOverN = FFTW_K2PI / (FFTW_TRIG_REAL) n;

        for (j = 0; j < rows; ++j)
             for (i = 1; i < cols; ++i) {
                int k = (j * (cols - 1) - 1) + i;
                FFTW_TRIG_REAL
                   ij = (FFTW_TRIG_REAL) (i * (j + rowstart));
                c_re(W[k]) = FFTW_TRIG_COS(twoPiOverN * ij);
                c_im(W[k]) = FFTW_FORWARD * FFTW_TRIG_SIN(twoPiOverN * ij);
             }
        
        tw->W = W;
     }

     fftw_mpi_twiddles = tw;

     return tw;
}

/* Drop one reference to a twiddle table obtained from
   fftw_mpi_create_twiddle.  A NULL tw is ignored.  When the last reference
   goes away, the entry is unlinked from fftw_mpi_twiddles and both the
   table (tw->W) and the entry are freed.  Calls fftw_mpi_die if tw is
   not found in the cache list. */
static void fftw_mpi_destroy_twiddle(fftw_mpi_twiddle *tw)
{
     fftw_mpi_twiddle *prev;

     if (tw == NULL)
          return;

     if (--tw->refcount != 0)
          return;          /* still shared by another plan */

     if (tw == fftw_mpi_twiddles) {
          /* tw is the list head: just advance the head */
          fftw_mpi_twiddles = tw->next;
     } else {
          prev = fftw_mpi_twiddles;
          if (prev == NULL)
               fftw_mpi_die("unexpected empty MPI twiddle list");

          /* walk to the node whose successor is tw (or to the tail) */
          for (; prev->next != NULL && prev->next != tw; prev = prev->next)
               ;

          if (prev->next != tw)
               fftw_mpi_die("tried to destroy unknown MPI twiddle");
          prev->next = tw->next;
     }

     fftw_free(tw->W);
     fftw_free(tw);
}

/* Multiply one row of local data, in place, by row cur_row of the twiddle
   table tw.

   d holds tw->cols groups of n_fields interleaved complex values; group j
   starts at d[j * n_fields].  Group 0 is left unchanged because the table
   stores no column-0 entry (that factor is 1).  Every value in group j >= 1
   is multiplied by w = tw->W[cur_row * (tw->cols - 1) + (j - 1)] when
   dir == FFTW_FORWARD, and by conj(w) for any other direction.  An
   n_fields below 1 is treated as 1. */
static void fftw_mpi_mult_twiddles(fftw_complex *d, int n_fields,
                           int cur_row,
                           fftw_mpi_twiddle *tw,
                           fftw_direction dir)
{
     const int ncols = tw->cols;
     const int nf = (n_fields > 1) ? n_fields : 1;
     const int conjugate = (dir != FFTW_FORWARD);
     fftw_complex *wrow = tw->W + cur_row * (ncols - 1);
     int col, fld;

     for (col = 1; col < ncols; ++col) {
          fftw_real wr = c_re(wrow[col - 1]);
          fftw_real wi = c_im(wrow[col - 1]);
          fftw_complex *grp = d + col * nf;

          for (fld = 0; fld < nf; ++fld) {
               fftw_real xr = c_re(grp[fld]);
               fftw_real xi = c_im(grp[fld]);

               if (conjugate) {
                    /* x * conj(w) */
                    c_re(grp[fld]) = wr * xr + wi * xi;
                    c_im(grp[fld]) = wr * xi - wi * xr;
               } else {
                    /* x * w */
                    c_re(grp[fld]) = wr * xr - wi * xi;
                    c_im(grp[fld]) = wr * xi + wi * xr;
               }
          }
     }
}

/***************************** Plan Creation ****************************/

/* Return a divisor of n close to sqrt(n).

   With i = round(sqrt(n)), candidates are tried in the order
   i, i-1, i+1, i-2, i+2, ... and the first divisor of n is returned.
   i is tried first because it is never farther from sqrt(n) than i-1.
   (The previous version tested i-1 first, so e.g. n = 4 or 6 yielded 1.)
   The scan always stops by reaching 1, which divides every n.
   Returns 1 for n <= 1, including non-positive n.

   round(sqrt(n)) is computed in integer arithmetic.  This avoids
   floating-point rounding, an undefined float->int conversion for negative
   n, and any dependency on libm. */
static int find_sqrt_factor(int n)
{
     int root, above, below;

     if (n <= 1)
          return 1;

     /* root = floor(sqrt(n)).  "root+1 <= n/(root+1)" is equivalent to
        (root+1)^2 <= n for positive integers, but cannot overflow. */
     root = 1;
     while (root + 1 <= n / (root + 1))
          ++root;

     /* Round to nearest: sqrt(n) >= root + 1/2  <=>  n > root * (root + 1). */
     if (n - root * root > root)
          ++root;

     for (above = root, below = root - 1; below > 0; ++above, --below) {
          if (n % above == 0)
               return above;
          if (n % below == 0)
               return below;
     }
     return 1; /* only reached when root == 1, i.e. n is 2 */
}

/* Choose the factor r (n = r * (n / r)) used to split the 1-D transform.

   Let n_pes be the number of processes in comm.  In order of preference:
     - if n_pes divides n and also n / n_pes, r is n_pes times a divisor of
       n / n_pes^2 near its square root, so r and n / r are both multiples
       of n_pes (best load balance);
     - else if n_pes divides n, r is n_pes times a divisor of n / n_pes near
       its square root, so r is a multiple of n_pes;
     - otherwise r is a divisor of n near sqrt(n). */
static int find_best_r(int n, MPI_Comm comm)
{
     int npes, per_pe;

     MPI_Comm_size(comm, &npes);

     if (n % npes != 0)
          return find_sqrt_factor(n);

     per_pe = n / npes;
     if (per_pe % npes != 0)
          return npes * find_sqrt_factor(per_pe);

     return npes * find_sqrt_factor(per_pe / npes);
}

/* Larger of a and b.  As a macro it evaluates the selected argument twice,
   so pass only side-effect-free expressions (it is used here with plain
   int variables). */
#define MAX2(a,b) ((a) > (b) ? (a) : (b))

/* Create a plan for a distributed 1-D complex transform of size n over comm.

   n is factored as n = r * m.  The factor chosen by find_best_r becomes r
   for FFTW_FORWARD and m otherwise.  Swapping the roles keeps the local
   sizes of forward and backward plans compatible, so both can share one
   twiddle table.

   FFTW_IN_PLACE is always added to flags before they are stored in the plan
   and used for the two 1-D sub-plans (sizes r and m).  The plan also owns
   two distributed transposes, a reference to a shared twiddle table, and a
   max(m, r)-element scratch buffer.  Release it with fftw_mpi_destroy_plan. */
fftw_mpi_plan fftw_mpi_create_plan(MPI_Comm comm,
                           int n, fftw_direction dir, int flags)
{
     fftw_mpi_plan plan;
     int split, r, m;

     plan = (fftw_mpi_plan) fftw_malloc(sizeof(struct fftw_mpi_plan_struct));

     split = find_best_r(n, comm);
     if (dir == FFTW_FORWARD) {
          r = split;
          m = n / r;
     } else {
          m = split;
          r = n / m;
     }

     plan->n = n;
     plan->r = r;
     plan->m = m;
     plan->flags = flags | FFTW_IN_PLACE;
     plan->dir = dir;

     /* 1-D sub-transforms, both executed in place on local rows */
     plan->pr = fftw_create_plan(r, dir, plan->flags);
     plan->pm = fftw_create_plan(m, dir, plan->flags);

     plan->p_transpose = transpose_mpi_create_plan(m, r, comm);
     plan->p_transpose_inv = transpose_mpi_create_plan(r, m, comm);

     /* this process's share of the r rows and of the m rows */
     transpose_mpi_get_local_size(r,
                                  plan->p_transpose_inv->my_pe,
                                  plan->p_transpose_inv->n_pes,
                                  &plan->local_r,
                                  &plan->local_r_start);
     transpose_mpi_get_local_size(m,
                                  plan->p_transpose->my_pe,
                                  plan->p_transpose->n_pes,
                                  &plan->local_m,
                                  &plan->local_m_start);

     plan->tw = (dir == FFTW_FORWARD)
          ? fftw_mpi_create_twiddle(plan->local_r, plan->local_r_start, m, n)
          : fftw_mpi_create_twiddle(plan->local_m, plan->local_m_start, r, n);

     plan->fft_work = (fftw_complex *) fftw_malloc(sizeof(fftw_complex) *
                                                   MAX2(m, r));

     return plan;
}

/********************* Getting Local Size ***********************/

/* Report how the data of plan p is distributed on this process.

   *local_n / *local_start:  number of local elements and starting index
       before the transform.  The layout follows the local_r rows of m when
       FFTW_SCRAMBLED_INPUT is set, and the local_m rows of r otherwise.
   *local_n_after_transform / *local_start_after_transform:  the same after
       the transform.  The layout follows local_m rows of r when
       FFTW_SCRAMBLED_OUTPUT is set, and local_r rows of m otherwise.
   *total_local_size:  value of transpose_mpi_get_local_storage_size for
       p->p_transpose.  NOTE(review): presumably the number of elements
       local_data (and work) must hold; confirm.

   If p is NULL, nothing is written. */
void fftw_mpi_local_sizes(fftw_mpi_plan p,
                    int *local_n,
                    int *local_start,
                    int *local_n_after_transform,
                    int *local_start_after_transform,
                    int *total_local_size)
{
     int in_scrambled, out_scrambled;

     if (!p)
          return;

     in_scrambled = (p->flags & FFTW_SCRAMBLED_INPUT) != 0;
     out_scrambled = (p->flags & FFTW_SCRAMBLED_OUTPUT) != 0;

     *local_n = in_scrambled ? p->local_r * p->m
                             : p->local_m * p->r;
     *local_start = in_scrambled ? p->local_r_start * p->m
                                 : p->local_m_start * p->r;

     *local_n_after_transform = out_scrambled ? p->local_m * p->r
                                              : p->local_r * p->m;
     *local_start_after_transform = out_scrambled ? p->local_m_start * p->r
                                                  : p->local_r_start * p->m;

     *total_local_size =
          transpose_mpi_get_local_storage_size(p->p_transpose->nx,
                                               p->p_transpose->ny,
                                               p->p_transpose->my_pe,
                                               p->p_transpose->n_pes);
}

/* Write a human-readable description of plan p to stream f.  It prints a
   header, then the size-m sub-plan, then the size-r sub-plan. */
static void fftw_mpi_fprint_plan(FILE *f, fftw_mpi_plan p)
{
     fputs("mpi plan:\n", f);

     fprintf(f, "m = %d plan:\n", p->m);
     fftw_fprint_plan(f, p->pm);

     fprintf(f, "r = %d plan:\n", p->r);
     fftw_fprint_plan(f, p->pr);
}

/* Same as fftw_mpi_fprint_plan, writing to stdout. */
void fftw_mpi_print_plan(fftw_mpi_plan p)
{
     FILE *out = stdout;

     fftw_mpi_fprint_plan(out, p);
}

/********************** Plan Destruction ************************/

/* Free plan p and everything it owns.  The shared twiddle table is only
   released once no other plan references it.  A NULL p is ignored. */
void fftw_mpi_destroy_plan(fftw_mpi_plan p)
{
     if (p == NULL)
          return;

     /* 1-D sub-plans for the size-r and size-m passes */
     fftw_destroy_plan(p->pr);
     fftw_destroy_plan(p->pm);

     /* the two distributed transposes */
     transpose_mpi_destroy_plan(p->p_transpose);
     transpose_mpi_destroy_plan(p->p_transpose_inv);

     /* drop this plan's reference to the (possibly shared) twiddle table */
     fftw_mpi_destroy_twiddle(p->tw);

     fftw_free(p->fft_work);
     fftw_free(p);
}

/******************** Computing the Transform *******************/

/* Execute plan p on this process's slab, overwriting local_data with the
   result.

   p          -- plan from fftw_mpi_create_plan; p->dir selects forward or
                 backward.
   n_fields   -- number of interleaved transforms performed together.
                 Element j of field f lives at local_data[j * n_fields + f].
                 If n_fields < 1 the call returns immediately.
   local_data -- local part of the array; its extent before and after the
                 transform is reported by fftw_mpi_local_sizes.
   work       -- optional scratch (may be NULL).  It is passed to every
                 transpose_mpi call.  When non-NULL it also replaces
                 p->fft_work as scratch for the 1-D FFTs.
                 NOTE(review): presumably it must hold total_local_size
                 elements from fftw_mpi_local_sizes; confirm.

   Steps:
     1. redistribute with p->p_transpose (skipped with FFTW_SCRAMBLED_INPUT);
     2. size-m FFTs on the p->local_r local rows; forward plans then apply
        the twiddle factors;
     3. redistribute with p->p_transpose_inv;
     4. backward plans apply the (conjugated) twiddle factors; then size-r
        FFTs on the p->local_m local rows;
     5. redistribute with p->p_transpose (skipped with
        FFTW_SCRAMBLED_OUTPUT). */
void fftw_mpi(fftw_mpi_plan p, int n_fields,
            fftw_complex *local_data, fftw_complex *work)
{
     int i;
     /* size of one (n_fields-wide) element in TRANSPOSE_EL_TYPE units */
     int el_size = (sizeof(fftw_complex) / sizeof(TRANSPOSE_EL_TYPE))
                   * n_fields;
     fftw_complex *fft_work;
     fftw_direction dir;
     fftw_mpi_twiddle *tw;

     if (n_fields < 1)
        return;

     /* step 1: skipped when the caller supplies already-transposed input */
     if (!(p->flags & FFTW_SCRAMBLED_INPUT))
        transpose_mpi(p->p_transpose, el_size,
                  (TRANSPOSE_EL_TYPE *) local_data,
                  (TRANSPOSE_EL_TYPE *) work);

     tw = p->tw;
     dir = p->dir;
     /* scratch for the in-place 1-D FFTs: caller's work if given, else the
        plan's own max(m, r)-element buffer */
     fft_work = work ? work : p->fft_work;

     /* For forward plans, we multiply by the twiddle factors here,
      before the second transpose.  For backward plans, we multiply
      by the twiddle factors after the second transpose.  We do
      this so that forward and backward transforms can share the
      same twiddle factor array (noting that m and r are swapped
      for the two directions so that the local sizes will be compatible). */

     /* step 2: size-m transforms along the local_r local rows */
     {
        int rows = p->local_r, cols = p->m;
        fftw_plan p_fft = p->pm;

        if (dir == FFTW_FORWARD) {
             /* row by row, because each row needs its own twiddle row i */
             for (i = 0; i < rows; ++i) {
                fftw_complex *d = local_data + i * (cols * n_fields);
                
                fftw(p_fft, n_fields, d, n_fields, 1, fft_work, 1, 0);
                fftw_mpi_mult_twiddles(d, n_fields, i, tw, FFTW_FORWARD);
             }
        }
        else {
             /* no twiddles yet: with one field, all rows are handled by a
                single call (howmany = rows, rows cols elements apart) */
             if (n_fields > 1)
                for (i = 0; i < rows; ++i)
                   fftw(p_fft, n_fields, local_data + i*(cols*n_fields),
                        n_fields, 1, fft_work, 1, 0);
             else
                fftw(p_fft, rows, local_data, 1, cols, fft_work, 1, 0);
        }
     }

     /* step 3 */
     transpose_mpi(p->p_transpose_inv, el_size,
               (TRANSPOSE_EL_TYPE *) local_data,
               (TRANSPOSE_EL_TYPE *) work);

     /* step 4: size-r transforms along the local_m local rows */
     {
        int rows = p->local_m, cols = p->r;
        fftw_plan p_fft = p->pr;

        if (dir == FFTW_BACKWARD) {
             /* twiddles (conjugated) go before the transform here */
             for (i = 0; i < rows; ++i) {
                fftw_complex *d = local_data + i * (cols * n_fields);
                
                fftw_mpi_mult_twiddles(d, n_fields, i, tw, FFTW_BACKWARD);
                fftw(p_fft, n_fields, d, n_fields, 1, fft_work, 1, 0);
             }
        }
        else {
             /* forward twiddles were already applied in step 2 */
             if (n_fields > 1)
                for (i = 0; i < rows; ++i)
                   fftw(p_fft, n_fields, local_data + i*(cols*n_fields),
                        n_fields, 1, fft_work, 1, 0);
             else
                fftw(p_fft, rows, local_data, 1, cols, fft_work, 1, 0);
        }
     }

     /* step 5: skipped when the caller accepts transposed output */
     if (!(p->flags & FFTW_SCRAMBLED_OUTPUT))
        transpose_mpi(p->p_transpose, el_size,
                  (TRANSPOSE_EL_TYPE *) local_data,
                  (TRANSPOSE_EL_TYPE *) work);

     /* Yes, we really had to do three transposes...sigh. */
}



Generated by  Doxygen 1.6.0   Back to index