/*
 comdefs.cpp

 XIDEK: Extensible Script Language Development Kit
 Common code.
 Modified to support complex arithmetic.

 Copyright (c) 1996-2002 Parsifal Software.
 All Rights Reserved.

 For further information about this program or the AnaGram
 parser generator, please contact:

    Parsifal Software
    http://www.parsifalsoft.com
    info@parsifalsoft.com
    +1-800-879-2577, Voice/Fax +1-508-358-2564
    P.O. Box 219
    Wayland, MA 01778
    USA
*/

#include "comdefs.h"

#include <math.h>
#include <stdio.h>

// Texts of the diagnostics raised by the code in this file.
// Both the characters and the pointers are const: nothing reassigns them.
static const char *const argRangeMsg      = "Argument out of range";
static const char *const hashUndefinedMsg = "Hash is not defined";
static const char *const notIntegerMsg    = "Operand is not an integer";
static const char *const notLvalueMsg     = "Not an lvalue";
static const char *const notPointerMsg    = "Operand is not a pointer";
static const char *const notScalarMsg     = "Operand is not scalar";
static const char *const notStringMsg     = "Operand is not a string";
static const char *const notSupportedMsg  = "Function not supported";
static const char *const opUndefinedMsg   = "Operation not defined";
static const char *const uninitMsg        = "Operand is not initialized";
// Text returned by Value::asString() for a value without a printable type
static const char *const uninitValueMsg   = "Uninitialized value";
static const char *const wrongCountMsg    = "Incorrect number of arguments";
static const char *const zeroDivisorMsg   = "Divide by zero";
static const char *const notComplexMsg    = "Operand is not complex";
static const char *const notRealMsg       = "Operand is not a real value";

// Construct an ErrorMessage whose msg member is initialized from m.
ErrorMessage::ErrorMessage(const AgString &m) : msg(m) {}
// Return the message text, converting msg to a C string.
const char *ErrorMessage::message() const {
  const char *text = static_cast<const char *>(msg);
  return text;
}

// Construct an ErrorDiagnostic whose msg member is initialized from m.
ErrorDiagnostic::ErrorDiagnostic(const AgString &m) : msg(m) {}

// Return the diagnostic text, converting msg to a C string.
const char *ErrorDiagnostic::message() const {
  const char *text = static_cast<const char *>(msg);
  return text;
}

/*** Value class ************************************************************/

// Constructors

// Default constructor: type is uninitType, value and imaginary part are zero.
Value::Value()
  : type(uninitType),
    value(0)
{
  part.imag = 0;
}

// Construct an integer value from an int; the imaginary part is zero.
Value::Value(int x)
  : type(integerType),
    value(x)
{
  part.imag = 0;
}

// Construct an integer value from a long; the imaginary part is zero.
Value::Value(long x)
  : type(integerType),
    value(x)
{
  part.imag = 0;
}

// Construct a numeric value of type t from a double, with a zero imaginary
// part. Throws ErrorMessage(notComplexMsg) if t is not integer, real or
// complex.
Value::Value(double x, Type t)
  : type(t),
    value(x)
{
  part.imag = 0;
  assertComplex();
}

// Construct a string value. The string is held in a ValueWrapper<AgString>
// constructed in place at the address of the value member.
Value::Value(const AgString &s)
  : type(stringType),
    value(0)
{
  new(&value) ValueWrapper<AgString>(s);
}

// Construct a pointer value that refers to the Value at p.
Value::Value(Value *p)
  : type(pointerType),
    pointer(p)
{
  // Nothing to do
}

// Construct a complex value with real part x and imaginary part y.
Value::Value(double x, double y)
  : type(complexType),
    value(x)
{
  part.imag = y;
}

// Copy constructor: start from zeroed storage, then let setValue() copy
// the type and contents of v.
Value::Value(const Value &v)
  : type(v.type),
    value(0)
{
  setValue(v);
}

// Copy the type and contents of v into this value and return *this.
// Any string previously held here is not released by this function.
// Types other than integer, real, complex, string and pointer copy only
// the type tag.
Value &Value::setValue(const Value &v) {
  part.imag = 0;
  type = v.type;
  if (type == stringType) {
    // Build a fresh copy of the string in place
    new(&value) ValueWrapper<AgString>(v.getString());
  }
  else if (type == pointerType) {
    pointer = v.pointer;
  }
  else if (type == integerType || type == realType || type == complexType) {
    // Only complex values carry an imaginary part
    if (type == complexType) part.imag = v.part.imag;
    value = v.value;
  }
  return *this;
}

// Value class destructor
// Only string values own anything: delete the string wrapper that lives
// at the address of the value member.
Value::~Value() {
  if (type != stringType) return;
  ValueWrapper<AgString> *held = (ValueWrapper<AgString> *) &value;
  delete held;
}

// Value class hash function
// Only used for constant values

// Combine this value into hashcode using agABTHash and return the result.
// Defined for integer, real and string values; any other type throws
// ErrorMessage(hashUndefinedMsg).
int Value::hash(int hashcode) const {
  if (type == integerType || type == realType) {
    hashcode = agABTHash(value, hashcode);
    return hashcode;
  }
  if (type == stringType) {
    hashcode = agABTHash( (AgString) *((ValueWrapper<AgString> *) &value), hashcode);
    return hashcode;
  }
  throw ErrorMessage(hashUndefinedMsg);
}

// Error checking
// Throw ErrorMessage(uninitMsg) if this value has type uninitType.
void Value::assertInitialized() const {
  if (type == uninitType) throw ErrorMessage(uninitMsg);
}

// Throw ErrorMessage(notScalarMsg) unless this value is integer or real.
void Value::assertScalar() const {
  const bool scalar = (type == integerType || type == realType);
  if (!scalar) throw ErrorMessage(notScalarMsg);
}

// Throw ErrorMessage(notStringMsg) unless this value is a string.
void Value::assertString() const {
  if (type != stringType) throw ErrorMessage(notStringMsg);
}

// Throw ErrorMessage(notPointerMsg) unless this value is a pointer.
void Value::assertPointer() const {
  if (type != pointerType) throw ErrorMessage(notPointerMsg);
}

// Throw ErrorMessage(notIntegerMsg) unless this value is an integer.
void Value::assertInteger() const {
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
}

// Replace a pointer value with a copy of the value it points at and
// return *this. Any other type throws ErrorMessage(notLvalueMsg).
Value &Value::deref() {
  if (type != pointerType) throw ErrorMessage(notLvalueMsg);
  setValue(*pointer);
  return *this;
}

// Truth test: a string is true when it is not empty, a number when it is
// not zero (for complex, when either part is nonzero). Any other type
// throws ErrorMessage(opUndefinedMsg).
int Value::isTrue() const {
  if (type == stringType) return (getString().size() != 0);
  if (type == integerType || type == realType) return (value != 0);
  if (type == complexType) return (part.real != 0 || part.imag != 0);
  throw ErrorMessage(opUndefinedMsg);
}

// Type conversions
// Convert a scalar to an integer by truncating the stored number to long.
// Throws ErrorMessage(notScalarMsg) if the value is not integer or real.
Value &Value::makeInteger() {
  assertScalar();
  const long truncated = (long) value;
  value = truncated;
  type = integerType;
  return *this;
}

// Mark a scalar as real; the stored number is left unchanged.
// Throws ErrorMessage(notScalarMsg) if the value is not integer or real.
Value &Value::makeReal() {
  assertScalar();
  type = realType;
  return *this;
}

// Render the value as text.
// Pointers render the value they point at, strings are returned as they
// are, complex values go through complexToString(), integers use "%ld"
// and reals "%G". Any other type yields the uninitValueMsg text.
AgString Value::asString() const {
  if (type == pointerType) return pointer->asString();
  if (type == stringType) return getString();
  if (type == complexType) return complexToString();
  if (type != integerType && type != realType) return AgString(uninitValueMsg);

  char text[40];
  if (type == integerType) sprintf(text, "%ld", (long) value);
  else sprintf(text, "%G", value);
  return AgString(text);
}

// Render the value as it would be written in source text.
// Non-string values are rendered by asString(). A string is wrapped in
// double quotes, with quotes, question marks, backslashes and the control
// characters \a \b \f \n \r \t \v replaced by their escape sequences.
AgString Value::asLiteral() const {
  if (type != stringType) return asString();
  AgString quoted;
  quoted.append('"');
  for (const char *cursor = getString(); *cursor; ) {
    unsigned char c = *cursor++;
    // Escape sequence for c, or null when c is copied unchanged
    const char *escape = 0;
    switch (c) {
      case '\'':   escape = "\\\'"; break;
      case '\"':   escape = "\\\""; break;
      case '\?':   escape = "\\?";  break;
      case '\\':   escape = "\\\\"; break;
      case '\a':   escape = "\\a";  break;
      case '\b':   escape = "\\b";  break;
      case '\f':   escape = "\\f";  break;
      case '\n':   escape = "\\n";  break;
      case '\r':   escape = "\\r";  break;
      case '\t':   escape = "\\t";  break;
      case '\v':   escape = "\\v";  break;
      default:     break;
    }
    if (escape) quoted.append(escape);
    else quoted.append(c);
  }
  quoted.append('"');
  return quoted;
}

// Render a numeric value in the form "re + im*i" or "re - im*i".
// When the real part is zero only "im*i" is written; when the real part
// is nonzero and the imaginary part is zero only the real part is written.
// Throws ErrorMessage(notComplexMsg) for non-numeric values.
AgString Value::complexToString() const {
  char text[40];
  assertComplex();
  if (part.real == 0) {
    sprintf(text, "%G*i", part.imag);
    return AgString(text);
  }
  // Write the real part, then append the signed imaginary part after it
  char *tail = text + sprintf(text, "%G", part.real);
  if (part.imag > 0) sprintf(tail, " + %G*i", part.imag);
  else if (part.imag < 0) sprintf(tail, " - %G*i", -part.imag);
  return AgString(text);
}

// Real division: divide a copy of this value by divisor and return the
// result. An integer dividend is first marked real, so that operator /=
// does not take its integer-by-integer path.
Value Value::rdiv(const Value &divisor) {
  Value dividend(*this);
  if (dividend.type == integerType) dividend.makeReal();
  return dividend /= divisor;
}

// Integer division: both operands must be integers, otherwise
// ErrorMessage(notIntegerMsg) is thrown. This value is not modified; the
// quotient is computed on a copy.
Value Value::idiv(const Value &divisor) {
  assertInteger();
  divisor.assertInteger();
  Value dividend(*this);
  return dividend /= divisor;
}

// Operator overloads
// Assignment.
// If the destination is a pointer, the assignment is forwarded to the
// value it points at. Otherwise any string currently held is released and
// the type and contents of v are copied in by setValue().
const Value &Value::operator = (const Value &v) {
  if (type == pointerType) return *pointer = v;
  // Self-assignment must leave the value alone: releasing our own string,
  // or letting setValue() clear part.imag, before copying from v would
  // destroy the very data being copied.
  if (this == &v) return *this;
  if (type == stringType) {
    ValueWrapper<AgString> *wrapper = (ValueWrapper<AgString> *) &value;
    delete wrapper;
  }
  setValue(v);
  return *this;
}
// Addition / concatenation.
// A pointer forwards to the value it points at. A string appends the text
// form of v. Numbers add component-wise after promoting to the higher of
// the two types; v must be numeric or ErrorMessage(notComplexMsg) is
// thrown. Any other destination throws ErrorMessage(opUndefinedMsg).
const Value &Value::operator += (const Value &v) {
  if (type == pointerType) return *pointer += v;
  if (type == stringType) {
    getString().append(v.asString());
    return *this;
  }
  if (!isComplex()) throw ErrorMessage(opUndefinedMsg);
  v.assertComplex();
  if (type < v.type) type = v.type;
  part.real += v.part.real;
  part.imag += v.part.imag;
  return *this;
}

// Subtraction.
// A pointer forwards to the value it points at. Numbers subtract
// component-wise after promoting to the higher of the two types; v must
// be numeric or ErrorMessage(notComplexMsg) is thrown. Any other
// destination throws ErrorMessage(notScalarMsg).
const Value &Value::operator -= (const Value &v) {
  if (type == pointerType) return *pointer -= v;
  if (!isComplex()) throw ErrorMessage(notScalarMsg);
  v.assertComplex();
  if (type < v.type) type = v.type;
  part.real -= v.part.real;
  part.imag -= v.part.imag;
  return *this;
}
// Multiplication.
// A pointer forwards to the value it points at. Numbers are multiplied as
// complex numbers after promoting to the higher of the two types; v must
// be numeric or ErrorMessage(notComplexMsg) is thrown. Any other
// destination throws ErrorMessage(notScalarMsg).
const Value &Value::operator *= (const Value &v) {
  if (type == pointerType) return *pointer *= v;
  if (!isComplex()) throw ErrorMessage(notScalarMsg);
  v.assertComplex();
  if (type < v.type) type = v.type;
  // Compute both parts before storing either: each uses the old values
  const double re = part.real * v.part.real - part.imag * v.part.imag;
  const double im = part.real * v.part.imag + part.imag * v.part.real;
  part.real = re;
  part.imag = im;
  return *this;
}

// Division.
// A pointer forwards to the value it points at. Integer divided by integer
// is done in long arithmetic. Everything else is divided as complex
// numbers after promoting to the higher of the two types.
// Throws ErrorMessage(zeroDivisorMsg) for a zero divisor and
// ErrorMessage(notComplexMsg) if either operand is not numeric.
const Value &Value::operator /= (const Value &v) {
  if (type == pointerType) return *pointer /= v;
  if (type == integerType && v.type == integerType) {
    if (v.value == 0) throw ErrorMessage(zeroDivisorMsg);
    value = (long) value / (long) v.value;
    return *this;
  }
  assertComplex();
  v.assertComplex();
  if (type < v.type) type = v.type;
  const double divisorSquared = v.magnitudeSquared();
  if (divisorSquared == 0) throw ErrorMessage(zeroDivisorMsg);
  // Multiply by the conjugate of v and divide by |v| squared
  const double re = (part.real * v.part.real + part.imag * v.part.imag)/divisorSquared;
  const double im = (-part.real * v.part.imag + part.imag * v.part.real)/divisorSquared;
  part.real = re;
  part.imag = im;
  return *this;
}

// Integer operations
// Remainder.
// A pointer forwards to the value it points at; any other non-integer
// destination throws ErrorMessage(notIntegerMsg). The divisor is obtained
// with v.getLong(); a zero divisor throws ErrorMessage(zeroDivisorMsg).
const Value &Value::operator %= (const Value &v) {
  if (type == pointerType) return *pointer %= v;
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  const long divisor = v.getLong();
  if (divisor == 0) throw ErrorMessage(zeroDivisorMsg);
  value = (long) value % divisor;
  return *this;
}

// Bitwise and.
// A pointer forwards to the value it points at; any other non-integer
// destination throws ErrorMessage(notIntegerMsg).
const Value &Value::operator &= (const Value &v) {
  if (type == pointerType) return *pointer &= v;
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  value = (long) value & v.getLong();
  return *this;
}

// Bitwise or.
// A pointer forwards to the value it points at; any other non-integer
// destination throws ErrorMessage(notIntegerMsg).
const Value &Value::operator |= (const Value &v) {
  if (type == pointerType) return *pointer |= v;
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  value = (long) value | v.getLong();
  return *this;
}

// Bitwise exclusive or.
// A pointer forwards to the value it points at; any other non-integer
// destination throws ErrorMessage(notIntegerMsg).
const Value &Value::operator ^= (const Value &v) {
  if (type == pointerType) return *pointer ^= v;
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  value = (long) value ^ v.getLong();
  return *this;
}

// Left shift.
// A pointer forwards to the value it points at; any other non-integer
// destination throws ErrorMessage(notIntegerMsg).
// The shift count is obtained with v.getLong(). A count that is negative
// or not smaller than the width of long throws ErrorMessage(argRangeMsg),
// because shifting by such a count is undefined behavior in C++.
const Value &Value::operator <<= (const Value &v) {
  if (type == pointerType) return *pointer <<= v;
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  const long count = v.getLong();
  // Width of long in bits, assuming 8-bit bytes
  const long width = (long) (8 * sizeof(long));
  if (count < 0 || count >= width) throw ErrorMessage(argRangeMsg);
  value = (long) value << count;
  return *this;
}

// Right shift.
// A pointer forwards to the value it points at; any other non-integer
// destination throws ErrorMessage(notIntegerMsg).
// The shift count is obtained with v.getLong(). A count that is negative
// or not smaller than the width of long throws ErrorMessage(argRangeMsg),
// because shifting by such a count is undefined behavior in C++.
const Value &Value::operator >>= (const Value &v) {
  if (type == pointerType) return *pointer >>= v;
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  const long count = v.getLong();
  // Width of long in bits, assuming 8-bit bytes
  const long width = (long) (8 * sizeof(long));
  if (count < 0 || count >= width) throw ErrorMessage(argRangeMsg);
  value = (long) value >> count;
  return *this;
}

// Less-than comparison.
// Strings compare with strings (v must be a string). Integer and real
// values compare numerically with a scalar v. A complex value is accepted
// only when its imaginary part is zero, otherwise
// ErrorMessage(notRealMsg) is thrown. Any other type throws
// ErrorMessage(opUndefinedMsg).
int Value::operator < (const Value &v) const {
  if (type == stringType) {
    v.assertString();
    return (getString() < v.getString());
  }
  if (type == complexType && part.imag != 0) throw ErrorMessage(notRealMsg);
  if (!isComplex()) throw ErrorMessage(opUndefinedMsg);
  v.assertScalar();
  return (value < v.value);
}

// Equality comparison. Returns nonzero when the operands are equal.
// Numeric values (integer, real, complex) compare by real and imaginary
// part, so the result does not depend on which operand is the complex
// one; scalars carry a zero imaginary part. Strings compare with strings.
// Every other combination of types is unequal.
int Value::operator == (const Value &v) const {
  switch (type) {
    case integerType:
    case realType:
    case complexType: {
      if (v.isComplex()) return part.real == v.part.real && part.imag == v.part.imag;
      // Must not fall into the string case: a number holds no string
      break;
    }
    case stringType: {
      if (v.type == stringType) return getString() == v.getString();
      break;
    }
  }
  return 0;
}

// Unary minus. Integer and real values keep their type; complex values
// have both parts negated. Any other type throws
// ErrorMessage(opUndefinedMsg).
Value Value::operator -() const {
  if (type == complexType) return Value(-part.real, -part.imag);
  if (type == integerType || type == realType) return Value(-value, type);
  throw ErrorMessage(opUndefinedMsg);
}

// Postincrement operator
// A pointer forwards to the value it points at; any other non-integer
// operand throws ErrorMessage(notIntegerMsg).
// Returns the number held before the increment. The current type is
// passed to the constructor explicitly, as unary minus does, so that the
// result is an integer rather than whatever type the default of
// Value(double, Type) would select.
Value Value::operator ++(int) {
  if (type == pointerType) return (*pointer)++;
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  return Value(value++, type);
}

// Postdecrement operator
// A pointer forwards to the value it points at; any other non-integer
// operand throws ErrorMessage(notIntegerMsg).
// Returns the number held before the decrement. The current type is
// passed to the constructor explicitly, as unary minus does, so that the
// result is an integer rather than whatever type the default of
// Value(double, Type) would select.
Value Value::operator --(int) {
  if (type == pointerType) return (*pointer)--;
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  return Value(value--, type);
}

// Preincrement operator
// A pointer forwards to the value it points at; any other non-integer
// operand throws ErrorMessage(notIntegerMsg).
// Returns the number held after the increment. The current type is passed
// to the constructor explicitly, as unary minus does, so that the result
// is an integer rather than whatever type the default of
// Value(double, Type) would select.
Value Value::operator ++() {
  if (type == pointerType) return ++(*pointer);
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  return Value(++value, type);
}

// Predecrement operator
// A pointer forwards to the value it points at; any other non-integer
// operand throws ErrorMessage(notIntegerMsg).
// Returns the number held after the decrement. The current type is passed
// to the constructor explicitly, as unary minus does, so that the result
// is an integer rather than whatever type the default of
// Value(double, Type) would select.
Value Value::operator --() {
  if (type == pointerType) return --(*pointer);
  if (type != integerType) throw ErrorMessage(notIntegerMsg);
  return Value(--value, type);
}

/*** Dataset class **********************************************************/

// Construct a Dataset
// The dataset starts with empty data and one predefined variable: "i",
// set to the complex value with real part 0 and imaginary part 1.
Dataset::Dataset(AgDictionary<AgString> &d) : dictionary(d), data() {
  Value &imaginaryUnit = value("i");
  imaginaryUnit = Value(0.0, 1.0);
}

// Copy constructor for Dataset
// Initializes dictionary and data from the corresponding members of d.
Dataset::Dataset(const Dataset &d) : dictionary(d.dictionary), data(d.data) {
  // Nothing to do
}

// Access the value of a variable by index
// Before indexing, data is padded with default-constructed Values until
// it has as many entries as the dictionary. x itself is not range checked
// here.
Value &Dataset::operator[] (int x) {
  for (; data.size() < dictionary.size(); ) {
    data.push(Value());
  }
  return data[x];
}

// Access the value of a variable by name.
// The name is passed to dictionary.intern() and the result is used as the
// index into the dataset.
Value &Dataset::value(const AgString &n) {
  const int index = dictionary.intern(n);
  return (*this)[index];
}

/*** function objects and the function table ********************************/
void FunctionObject::notDefined() {
  throw ErrorMessage(wrongCountMsg);
}

// natural logarithm function
static struct LogObject : public FunctionObject {
  // Complex logarithm: the real part is the log of the magnitude of x,
  // the imaginary part is the angle atan2(imag, real).
  Value oneArg(const Value &x) {
    const double logMagnitude = log(x.magnitude());
    const double angle = atan2(x.getImag(), x.getReal());
    return Value(logMagnitude, angle);
  }
} logObject;

// exponential function
static struct ExpObject : public FunctionObject {
  // Complex exponential: exp(real) scaled by cos and sin of the imaginary
  // part.
  Value oneArg(const Value &x) {
    const double scale = exp(x.getReal());
    const double re = scale*cos(x.getImag());
    const double im = scale*sin(x.getImag());
    return Value(re, im);
  }
} expObject;

// square root function
static struct SqrtObject : public FunctionObject {
  // Complex square root in polar form: square root of the magnitude,
  // half of the angle.
  Value oneArg(const Value &x) {
    const double halfAngle = atan2(x.getImag(), x.getReal())/2;
    const double radius = sqrt(x.magnitude());
    return Value(radius*cos(halfAngle), radius*sin(halfAngle));
  }
} sqrtObject;

// sine function
static struct SinObject : public FunctionObject {
  // sin x = (exp(ix) - exp(-ix)) / 2i
  Value oneArg(const Value &x) {
    Value ix(-x.getImag(), x.getReal());
    Value forward = expObject.oneArg(ix);
    Value backward = expObject.oneArg(-ix);
    return (forward - backward)/Value(0.0,2.0);
  }
} sinObject;

// arc sin function
static struct AsinObject : public FunctionObject {
  // asin x = log(ix + sqrt(1 - x*x)) / i
  Value oneArg(const Value &x) {
    Value ix(-x.getImag(), x.getReal());
    Value root = sqrtObject.oneArg(Value(1) - x*x);
    Value logarithm = logObject.oneArg(ix + root);
    return logarithm/Value(0.0,1.0);
  }
} asinObject;

// cosine function
static struct CosObject : public FunctionObject {
  // cos x = (exp(ix) + exp(-ix)) / 2
  Value oneArg(const Value &x) {
    Value ix(-x.getImag(), x.getReal());
    Value forward = expObject.oneArg(ix);
    Value backward = expObject.oneArg(-ix);
    return (forward + backward)/Value(2);
  }
} cosObject;

// arc cosine function
static struct AcosObject : public FunctionObject {
  // acos x = log(x + sqrt(x*x - 1)) / i
  Value oneArg(const Value &x) {
    Value root = sqrtObject.oneArg(x*x - 1);
    Value logarithm = logObject.oneArg(x+root);
    return logarithm/Value(0.0,1.0);
  }
} acosObject;

// tangent function
static struct TanObject : public FunctionObject {
  // tan x = sin x / cos x, using the sine and cosine function objects
  Value oneArg(const Value &x) {
    Value sine = sinObject.oneArg(x);
    Value cosine = cosObject.oneArg(x);
    return sine/cosine;
  }
} tanObject;

// arc tangent function, takes one or two arguments
static struct AtanObject : public FunctionObject {
  // atan x = (i/2) * log((1 - ix)/(1 + ix))
  Value oneArg(const Value &x) {
    Value ix(-x.getImag(), x.getReal());
    Value ratio = (Value(1) - ix)/(Value(1) + ix);
    return logObject.oneArg(ratio) * Value(0, .5);
  }
  // Two-argument form: the arc tangent of the quotient x/y
  Value twoArgs(const Value &x, const Value &y) {
    Value z = x/y;
    return oneArg(z);
  }
} atanObject;

// Added function objects for complex arithmetic
static struct RealObject : public FunctionObject {
  // real(x): the result is built from x.getReal() alone
  Value oneArg(const Value &x) {
    const double re = x.getReal();
    return Value(re);
  }
} realObject;

static struct ImagObject : public FunctionObject {
  // imag(x): the result is built from x.getImag() alone
  Value oneArg(const Value &x) {
    const double im = x.getImag();
    return Value(im);
  }
} imagObject;

static struct ConjObject : public FunctionObject {
  // conj(x): same real part, negated imaginary part
  Value oneArg(const Value &x) {
    const double re = x.getReal();
    const double im = -x.getImag();
    return Value(re, im);
  }
} conjObject;


// function table
const FunctionDescriptor functionTable[] = {
  // Each entry gives the function name, the argument count, and the
  // function object that implements it. idFunction() searches this table
  // for a matching name and argument count.
  {"log", 1, &logObject},
  {"exp", 1, &expObject},
  {"sqrt", 1, &sqrtObject},
  {"sin", 1, &sinObject},
  {"asin", 1, &asinObject},
  {"cos", 1, &cosObject},
  {"acos", 1, &acosObject},
  {"tan", 1, &tanObject},
  // atan is listed once for each argument count it accepts
  {"atan", 1, &atanObject},
  {"atan", 2, &atanObject},
  {"real", 1, &realObject},
  {"imag", 1, &imagObject},
  {"conj", 1, &conjObject},
  // End marker: the search in idFunction() stops at a null name
  {0,0,0}
};

int idFunction(const AgString &name, int argCount) {
  int i;
  for (i = 0; functionTable[i].name; i++) {
    const FunctionDescriptor &f = functionTable[i];
    if (name == f.name && argCount == f.argCount) break;
  }
  if (functionTable[i].name == NULL) throw ErrorMessage(notSupportedMsg);
  return i;
}

// Call a function by name: look up the table index for this name and the
// number of stacked arguments, then call by index.
Value callFunction(const AgString &name, AgStack<Value> &args) {
  const int index = idFunction(name, args.size());
  return callFunction(index, args);
}

// Call the function at index i of functionTable, popping its arguments
// from args. The index is not range checked here. An entry whose argument
// count is neither 1 nor 2 yields a default-constructed Value.
Value callFunction(int i, AgStack<Value> &args) {
  const FunctionDescriptor &entry = functionTable[i];
  if (entry.argCount == 1) {
    Value only(args.pop());
    return entry.function->oneArg(only);
  }
  if (entry.argCount == 2) {
    // The last argument is on top of the stack, so it is popped first
    Value second(args.pop());
    Value first(args.pop());
    return entry.function->twoArgs(first, second);
  }
  return Value();
}

/*** External Functions *****************************************************/

// Raise base to exponent as exp(exponent * log(base)).
// A base of zero magnitude returns the integer 0 whatever the exponent.
Value pow(const Value &base, const Value &exponent) {
  if (base.magnitudeSquared() == 0) return Value(0);
  return expObject.oneArg(exponent * logObject.oneArg(base));
}

// What was thought to be an octal number turns out to be decimal.
// Undo the octal representation and make it decimal.
// The argument is the number obtained by reading a digit string as octal;
// the result is the number obtained by reading the same digits as decimal.
// For example 83 (octal digits 1 2 3) becomes 123.
long makeDecimal(long octal) {
  // One slot per octal digit of a long (3 bits per digit, rounded up,
  // assuming 8-bit bytes)
  int digits[sizeof(long) * 8 / 3 + 1];
  int k = 0;
  // Peel off octal digits, least significant first
  do digits[k++] = (int) (octal % 8); while ((octal /= 8) != 0);
  long decimal = 0;
  // Reconvert as decimal, most significant digit first
  while (k--) decimal = 10*decimal + digits[k];
  return decimal;
}

// Added functions to support complex arithmetic
// Returns 1 for the numeric types (integer, real, complex), otherwise 0.
int Value::isComplex() const {
  const bool numeric =
    type == integerType || type == realType || type == complexType;
  return numeric ? 1 : 0;
}

// Throw ErrorMessage(notComplexMsg) unless isComplex() holds.
void Value::assertComplex() const {
  if (!isComplex()) throw ErrorMessage(notComplexMsg);
}

// Sum of the squares of the real and imaginary parts.
// Throws ErrorMessage(notComplexMsg) for non-numeric values.
double Value::magnitudeSquared() const {
  assertComplex();
  const double realSquared = part.real*part.real;
  const double imagSquared = part.imag*part.imag;
  return realSquared + imagSquared;
}

// Square root of magnitudeSquared().
double Value::magnitude() const {
  const double squared = magnitudeSquared();
  return sqrt(squared);
}

