/*
 * Copyright (C) 2002 Bartosz Lis <bartoszl@ics.p.lodz.pl>
 * This module sets a user's password through PAM.
 */

#define _GNU_SOURCE
#include "config.h"

#include <security/_pam_types.h>
#include <security/pam_appl.h>
#include <security/pam_misc.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "setpass.h"

/*
 * One registered prompt: its canonified text, the password stage it
 * belongs to, and the link to the next record of a singly linked list.
 * `str` is a trailing array; records are over-allocated so that the whole
 * NUL-terminated text fits after the fixed part.
 */
struct prompt_t
{
  struct prompt_t *next;   /* following record, NULL at the end of the list */
  int              stage;  /* SETPASS_STAGE_* value this prompt answers */
  char             str[1]; /* NUL-terminated prompt text (over-allocated) */
};

/* Head of the module-wide prompt list; empty until prompts are loaded. */
static struct prompt_t *prompts = NULL;

/*
 * Trim a prompt without modifying its text.
 *
 * On entry *str points at a NUL-terminated string.  On return *str has
 * been advanced past any leading WS characters, and *len is the number of
 * characters from there up to and including the last character that is
 * not in WSCRNL.  Separator runs between words (including CR/NL) are kept,
 * so [*str, *str + *len) is the prompt with surrounding blanks removed.
 * WS and WSCRNL are character-set strings from the project headers.
 */
static void
canonify(const char **str, size_t *len)
{
   size_t l_ws=0, l_nws;
   *str+=strspn(*str,WS);
   *len=0;
   /* alternately consume a separator run and the word that follows it */
   while ((l_nws=strcspn(*str+*len+l_ws,WSCRNL))!=0)
     {
        *len+=l_ws+l_nws;
        /* measure the separators that follow the word just consumed;
           they only count if another word comes after them */
        l_ws=strspn(*str+*len,WSCRNL);
     }
}

/*
 * Create a prompt record holding a copy of the first `len` characters of
 * `prompt`, tagged with `stage`, and push it onto the front of *prompts.
 * Allocation failure is reported through error_msg() with code 2
 * (presumably error_msg terminates the program for that code — confirm).
 */
static void
new_prompt(struct prompt_t **prompts, int stage, const char *prompt, size_t len)
{
   /* sizeof(*node) already includes str[1], which holds the terminator */
   struct prompt_t *node=malloc(sizeof(*node)+len);
   if (node==NULL)
     error_msg(stderr,"memory too short",0,2);
   node->next=*prompts;
   node->stage=stage;
   strncpy(node->str,prompt,len);
   node->str[len]='\0';
   *prompts=node;
}

/*
 * Find the stage registered for a prompt received from PAM.
 *
 * The prompt is canonified in the same way load_prompts() canonifies the
 * prompt text read from the configuration file.  It is then compared with
 * each registered prompt.  Both the length and the content must match, so
 * an empty or truncated prompt cannot match a longer registered one and
 * get a password as the answer.
 *
 * Returns the stage of the matching record, or 0 when nothing matches,
 * when `prompt` is NULL or when it is empty after canonification.
 */
static int
which_stage(struct prompt_t *prompts, const char *prompt)
{
   struct prompt_t *p_prompt;
   size_t           len;
   if (!prompt) return 0;
   canonify(&prompt,&len);
   if (!len) return 0;
   for (p_prompt=prompts; p_prompt; p_prompt=p_prompt->next)
     if (strlen(p_prompt->str)==len && !strncmp(prompt,p_prompt->str,len))
       return p_prompt->stage;
   return 0;
}

#define BUFF_SIZE 1025

/*
 * Read prompt definitions from `file` and push them onto the module-wide
 * `prompts` list.
 *
 * After leading WS is skipped, lines that are empty or start with '#' are
 * ignored.  Every other line must have the form
 *
 *     <tag> <prompt text>
 *
 * The tag is the text up to the first WSCRNL character.  It is compared,
 * using strncmp over the tag's own length, with SETPASS_STAGE_OLD_TAG,
 * SETPASS_STAGE_NEW_TAG and SETPASS_STAGE_NEW_AGAIN_TAG in that order, so
 * a tag that is a prefix of one of them is accepted.  The prompt text is
 * canonified before it is stored and must not be empty.
 *
 * The following are reported through error_msg() with code 4:
 *   - a file that cannot be opened,
 *   - a line longer than BUFF_SIZE-2 characters (excluding its newline),
 *   - an unparsable line,
 *   - a read error.
 *
 * NOTE(review): the original flow relies on error_msg() not returning for
 * code 4.  The explicit `return`s below keep this function from touching a
 * NULL or already-closed FILE if that ever does not hold.
 */
void
load_prompts(const char *file)
{
   FILE       *f;
   char        buffer[BUFF_SIZE], *str;
   const char *prompt;
   int         ret, l=0;
   size_t      len, p_len;
   if (!(f=fopen(file,"r")))
     {
        /* snprintf: `file` may be longer than the message buffer */
        snprintf(buffer,sizeof buffer,"cannot open %s",file);
        error_msg(stderr,buffer,0,4);
        return;
     }
   while (fgets(buffer,sizeof buffer,f))
     {
        ++l;
        /* a completely filled buffer without a newline means the line was cut */
        if ((strlen(buffer)==BUFF_SIZE-1) && buffer[BUFF_SIZE-2]!='\n')
          {
             fclose(f);
             snprintf(buffer,sizeof buffer,"line %d in %s too long",l,file);
             error_msg(stderr,buffer,0,4);
             return;
          }
        switch (*(str=buffer+strspn(buffer,WS)))
          {
           case '#':          /* comment line */
           case '\r':         /* blank line */
           case '\n':
           case 0:
             break;
           default:
             /* the tag ends at the first separator; the prompt follows it */
             len=strcspn(str,WSCRNL);
             prompt=str+len;
             canonify(&prompt,&p_len);
             if (p_len)
               {
                  if (!strncmp(str,SETPASS_STAGE_OLD_TAG,len))
                    new_prompt(&prompts,SETPASS_STAGE_OLD,prompt,p_len);
                  else if (!strncmp(str,SETPASS_STAGE_NEW_TAG,len))
                    new_prompt(&prompts,SETPASS_STAGE_NEW,prompt,p_len);
                  else if (!strncmp(str,SETPASS_STAGE_NEW_AGAIN_TAG,len))
                    new_prompt(&prompts,SETPASS_STAGE_NEW_AGAIN,prompt,p_len);
                  else p_len=0;   /* unknown tag */
               }
             if (!p_len)
               {
                  fclose(f);
                  snprintf(buffer,sizeof buffer,"cannot parse line %d in %s",l,file);
                  error_msg(stderr,buffer,0,4);
                  return;
               }
          }
     }
   ret=ferror(f);
   fclose(f);
   if (ret)
     {
        snprintf(buffer,sizeof buffer,"error reading %s",file);
        error_msg(stderr,buffer,0,4);
     }
}

void
free_prompts(void)
{
   struct prompt_t *prompt;
   while (prompts)
     {
        prompts=(prompt=prompts)->next;
        free(prompt);
     }
}

/*
 * Per-call state that set_pass() hands to conversation() through
 * pam_conv.appdata_ptr.
 */
struct conv_data_t
{
  /* callback that produces the answer to a password prompt;
     NULL from it aborts the conversation */
  ask_pass_t       ask_pass;
  /* registered prompts used to classify PAM messages */
  struct prompt_t *prompts;
  /* lowest stage accepted for the next prompt; reset to 0 on some errors */
  int              stage;
  /* opaque pointer passed back as ask_pass's first argument */
  void            *data;
};

static int
conversation(int num_msg, const struct pam_message **msg,
             struct pam_response **resp, void *appdata_ptr)
{
   int                 ret, i, stage;
   struct conv_data_t *conv_data=(struct conv_data_t *)appdata_ptr;
   char               *answer;
   if (!appdata_ptr || !num_msg) return PAM_CONV_ERR;
   if (!(*resp=(struct pam_response *)calloc(num_msg,
                                             sizeof(struct pam_response))))
     return PAM_CONV_ERR;
   ret=PAM_SUCCESS;
   for (i=0; i<num_msg; ++i)
     {
        (*resp)[i].resp=NULL;
        (*resp)[i].resp_retcode=0;
     }
   for (i=0; (i<num_msg) && (ret==PAM_SUCCESS); ++i) 
     switch (msg[i]->msg_style)
       {
        case PAM_ERROR_MSG:
          error_msg(stderr,"error message got from PAM:",
                    0,-1);
          fprintf (stderr,"  %s\n",msg[i]->msg);
          ret=PAM_CONV_ERR;
          conv_data->stage=0;
          break;
        case PAM_TEXT_INFO:
          fprintf (stdout,"  %s\n",msg[i]->msg);
          break;
        default:
          if (!(stage=which_stage(conv_data->prompts,msg[i]->msg)))
            {
               if (conv_data->stage) 
                 error_msg(stderr,"unregistered password prompts got from PAM:",
                           0,-1);
               fprintf (stderr,"  %s\n",msg[i]->msg);
               ret=PAM_CONV_ERR;
               conv_data->stage=0;
            }
          else if ((stage<conv_data->stage) || !(answer=
                   (*conv_data->ask_pass)(conv_data->data,msg[i]->msg,stage)))
            ret=PAM_CONV_ERR;
          else if (ret==PAM_SUCCESS)
            {
              (*resp)[i].resp=(char *)x_strdup(answer);
              _pam_overwrite(answer);
              conv_data->stage=stage+1;
            }
       }
   return ret;
}

/*
 * Change the password of `user` through PAM, using the "passwd" service.
 *
 * Each password PAM asks for is obtained from ask_pass, which receives
 * `data` as its first argument.  The module-wide prompt list decides which
 * password a PAM prompt asks for.  The conversation starts at stage
 * SETPASS_STAGE_OLD.  `flags` is passed unchanged to pam_chauthtok().
 *
 * Returns the status of pam_start() if it fails, otherwise the status of
 * pam_chauthtok().
 */
int
set_pass(void *data, const char *user, ask_pass_t ask_pass, int flags)
{
  int                  ret;
  pam_handle_t        *handle=NULL;
  struct pam_conv      conv;
  struct conv_data_t   conv_data;
  conv_data.ask_pass=ask_pass;
  conv_data.prompts=prompts;
  conv_data.stage=SETPASS_STAGE_OLD;
  conv_data.data=data;
  conv.conv=&conversation;
  conv.appdata_ptr=&conv_data;
  ret=pam_start("passwd"/*PAM_SERVICE_NAME*/,user,&conv,&handle);
  /* without a successfully started handle there is nothing to pam_end() */
  if (ret!=PAM_SUCCESS)
    return ret;
  ret=pam_chauthtok(handle,flags);
  pam_end(handle,ret);
  return ret;
}

